diff options
Diffstat (limited to 'webrtc/p2p/base')
82 files changed, 31129 insertions, 0 deletions
diff --git a/webrtc/p2p/base/asyncstuntcpsocket.cc b/webrtc/p2p/base/asyncstuntcpsocket.cc new file mode 100644 index 0000000000..444f06146a --- /dev/null +++ b/webrtc/p2p/base/asyncstuntcpsocket.cc @@ -0,0 +1,153 @@ +/* + * Copyright 2013 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/asyncstuntcpsocket.h" + +#include <string.h> + +#include "webrtc/p2p/base/stun.h" +#include "webrtc/base/common.h" +#include "webrtc/base/logging.h" + +namespace cricket { + +static const size_t kMaxPacketSize = 64 * 1024; + +typedef uint16_t PacketLength; +static const size_t kPacketLenSize = sizeof(PacketLength); +static const size_t kPacketLenOffset = 2; +static const size_t kBufSize = kMaxPacketSize + kStunHeaderSize; +static const size_t kTurnChannelDataHdrSize = 4; + +inline bool IsStunMessage(uint16_t msg_type) { + // The first two bits of a channel data message are 0b01. + return (msg_type & 0xC000) ? false : true; +} + +// AsyncStunTCPSocket +// Binds and connects |socket| and creates AsyncTCPSocket for +// it. Takes ownership of |socket|. Returns NULL if bind() or +// connect() fail (|socket| is destroyed in that case). +AsyncStunTCPSocket* AsyncStunTCPSocket::Create( + rtc::AsyncSocket* socket, + const rtc::SocketAddress& bind_address, + const rtc::SocketAddress& remote_address) { + return new AsyncStunTCPSocket(AsyncTCPSocketBase::ConnectSocket( + socket, bind_address, remote_address), false); +} + +AsyncStunTCPSocket::AsyncStunTCPSocket( + rtc::AsyncSocket* socket, bool listen) + : rtc::AsyncTCPSocketBase(socket, listen, kBufSize) { +} + +int AsyncStunTCPSocket::Send(const void *pv, size_t cb, + const rtc::PacketOptions& options) { + if (cb > kBufSize || cb < kPacketLenSize + kPacketLenOffset) { + SetError(EMSGSIZE); + return -1; + } + + // If we are blocking on send, then silently drop this packet + if (!IsOutBufferEmpty()) + return static_cast<int>(cb); + + int pad_bytes; + size_t expected_pkt_len = GetExpectedLength(pv, cb, &pad_bytes); + + // Accepts only complete STUN/ChannelData packets. + if (cb != expected_pkt_len) + return -1; + + AppendToOutBuffer(pv, cb); + + ASSERT(pad_bytes < 4); + char padding[4] = {0}; + AppendToOutBuffer(padding, pad_bytes); + + int res = FlushOutBuffer(); + if (res <= 0) { + // drop packet if we made no progress + ClearOutBuffer(); + return res; + } + + // We claim to have sent the whole thing, even if we only sent partial + return static_cast<int>(cb); +} + +void AsyncStunTCPSocket::ProcessInput(char* data, size_t* len) { + rtc::SocketAddress remote_addr(GetRemoteAddress()); + // STUN packet - First 4 bytes. Total header size is 20 bytes. + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // |0 0| STUN Message Type | Message Length | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + // TURN ChannelData + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Channel Number | Length | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + while (true) { + // We need at least 4 bytes to read the STUN or ChannelData packet length. + if (*len < kPacketLenOffset + kPacketLenSize) + return; + + int pad_bytes; + size_t expected_pkt_len = GetExpectedLength(data, *len, &pad_bytes); + size_t actual_length = expected_pkt_len + pad_bytes; + + if (*len < actual_length) { + return; + } + + SignalReadPacket(this, data, expected_pkt_len, remote_addr, + rtc::CreatePacketTime(0)); + + *len -= actual_length; + if (*len > 0) { + memmove(data, data + actual_length, *len); + } + } +} + +void AsyncStunTCPSocket::HandleIncomingConnection( + rtc::AsyncSocket* socket) { + SignalNewConnection(this, new AsyncStunTCPSocket(socket, false)); +} + +size_t AsyncStunTCPSocket::GetExpectedLength(const void* data, size_t len, + int* pad_bytes) { + *pad_bytes = 0; + PacketLength pkt_len = + rtc::GetBE16(static_cast<const char*>(data) + kPacketLenOffset); + size_t expected_pkt_len; + uint16_t msg_type = rtc::GetBE16(data); + if (IsStunMessage(msg_type)) { + // STUN message. + expected_pkt_len = kStunHeaderSize + pkt_len; + } else { + // TURN ChannelData message. + expected_pkt_len = kTurnChannelDataHdrSize + pkt_len; + // From RFC 5766 section 11.5 + // Over TCP and TLS-over-TCP, the ChannelData message MUST be padded to + // a multiple of four bytes in order to ensure the alignment of + // subsequent messages. The padding is not reflected in the length + // field of the ChannelData message, so the actual size of a ChannelData + // message (including padding) is (4 + Length) rounded up to the nearest + // multiple of 4. Over UDP, the padding is not required but MAY be + // included. + if (expected_pkt_len % 4) + *pad_bytes = 4 - (expected_pkt_len % 4); + } + return expected_pkt_len; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/asyncstuntcpsocket.h b/webrtc/p2p/base/asyncstuntcpsocket.h new file mode 100644 index 0000000000..3a15d4a399 --- /dev/null +++ b/webrtc/p2p/base/asyncstuntcpsocket.h @@ -0,0 +1,50 @@ +/* + * Copyright 2013 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_ASYNCSTUNTCPSOCKET_H_ +#define WEBRTC_P2P_BASE_ASYNCSTUNTCPSOCKET_H_ + +#include "webrtc/base/asynctcpsocket.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/socketfactory.h" + +namespace cricket { + +class AsyncStunTCPSocket : public rtc::AsyncTCPSocketBase { + public: + // Binds and connects |socket| and creates AsyncTCPSocket for + // it. Takes ownership of |socket|. Returns NULL if bind() or + // connect() fail (|socket| is destroyed in that case). + static AsyncStunTCPSocket* Create( + rtc::AsyncSocket* socket, + const rtc::SocketAddress& bind_address, + const rtc::SocketAddress& remote_address); + + AsyncStunTCPSocket(rtc::AsyncSocket* socket, bool listen); + virtual ~AsyncStunTCPSocket() {} + + virtual int Send(const void* pv, size_t cb, + const rtc::PacketOptions& options); + virtual void ProcessInput(char* data, size_t* len); + virtual void HandleIncomingConnection(rtc::AsyncSocket* socket); + + private: + // This method returns the message hdr + length written in the header. + // This method also returns the number of padding bytes needed/added to the + // turn message. |pad_bytes| should be used only when |is_turn| is true. + size_t GetExpectedLength(const void* data, size_t len, + int* pad_bytes); + + RTC_DISALLOW_COPY_AND_ASSIGN(AsyncStunTCPSocket); +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_ASYNCSTUNTCPSOCKET_H_ diff --git a/webrtc/p2p/base/asyncstuntcpsocket_unittest.cc b/webrtc/p2p/base/asyncstuntcpsocket_unittest.cc new file mode 100644 index 0000000000..22c1b26903 --- /dev/null +++ b/webrtc/p2p/base/asyncstuntcpsocket_unittest.cc @@ -0,0 +1,263 @@ +/* + * Copyright 2013 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/asyncstuntcpsocket.h" +#include "webrtc/base/asyncsocket.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/physicalsocketserver.h" +#include "webrtc/base/virtualsocketserver.h" + +namespace cricket { + +static unsigned char kStunMessageWithZeroLength[] = { + 0x00, 0x01, 0x00, 0x00, // length of 0 (last 2 bytes) + 0x21, 0x12, 0xA4, 0x42, + '0', '1', '2', '3', + '4', '5', '6', '7', + '8', '9', 'a', 'b', +}; + + +static unsigned char kTurnChannelDataMessageWithZeroLength[] = { + 0x40, 0x00, 0x00, 0x00, // length of 0 (last 2 bytes) +}; + +static unsigned char kTurnChannelDataMessage[] = { + 0x40, 0x00, 0x00, 0x10, + 0x21, 0x12, 0xA4, 0x42, + '0', '1', '2', '3', + '4', '5', '6', '7', + '8', '9', 'a', 'b', +}; + +static unsigned char kStunMessageWithInvalidLength[] = { + 0x00, 0x01, 0x00, 0x10, + 0x21, 0x12, 0xA4, 0x42, + '0', '1', '2', '3', + '4', '5', '6', '7', + '8', '9', 'a', 'b', +}; + +static unsigned char kTurnChannelDataMessageWithInvalidLength[] = { + 0x80, 0x00, 0x00, 0x20, + 0x21, 0x12, 0xA4, 0x42, + '0', '1', '2', '3', + '4', '5', '6', '7', + '8', '9', 'a', 'b', +}; + +static unsigned char kTurnChannelDataMessageWithOddLength[] = { + 0x40, 0x00, 0x00, 0x05, + 0x21, 0x12, 0xA4, 0x42, + '0', +}; + + +static const rtc::SocketAddress kClientAddr("11.11.11.11", 0); +static const rtc::SocketAddress kServerAddr("22.22.22.22", 0); + +class AsyncStunTCPSocketTest : public testing::Test, + public sigslot::has_slots<> { + protected: + AsyncStunTCPSocketTest() + : vss_(new rtc::VirtualSocketServer(NULL)), + ss_scope_(vss_.get()) { + } + + virtual void SetUp() { + CreateSockets(); + } + + void CreateSockets() { + rtc::AsyncSocket* server = vss_->CreateAsyncSocket( + kServerAddr.family(), SOCK_STREAM); + server->Bind(kServerAddr); + recv_socket_.reset(new AsyncStunTCPSocket(server, true)); + recv_socket_->SignalNewConnection.connect( + this, &AsyncStunTCPSocketTest::OnNewConnection); + + rtc::AsyncSocket* client = vss_->CreateAsyncSocket( + kClientAddr.family(), SOCK_STREAM); + send_socket_.reset(AsyncStunTCPSocket::Create( + client, kClientAddr, recv_socket_->GetLocalAddress())); + ASSERT_TRUE(send_socket_.get() != NULL); + vss_->ProcessMessagesUntilIdle(); + } + + void OnReadPacket(rtc::AsyncPacketSocket* socket, const char* data, + size_t len, const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + recv_packets_.push_back(std::string(data, len)); + } + + void OnNewConnection(rtc::AsyncPacketSocket* server, + rtc::AsyncPacketSocket* new_socket) { + listen_socket_.reset(new_socket); + new_socket->SignalReadPacket.connect( + this, &AsyncStunTCPSocketTest::OnReadPacket); + } + + bool Send(const void* data, size_t len) { + rtc::PacketOptions options; + size_t ret = send_socket_->Send( + reinterpret_cast<const char*>(data), len, options); + vss_->ProcessMessagesUntilIdle(); + return (ret == len); + } + + bool CheckData(const void* data, int len) { + bool ret = false; + if (recv_packets_.size()) { + std::string packet = recv_packets_.front(); + recv_packets_.pop_front(); + ret = (memcmp(data, packet.c_str(), len) == 0); + } + return ret; + } + + rtc::scoped_ptr<rtc::VirtualSocketServer> vss_; + rtc::SocketServerScope ss_scope_; + rtc::scoped_ptr<AsyncStunTCPSocket> send_socket_; + rtc::scoped_ptr<AsyncStunTCPSocket> recv_socket_; + rtc::scoped_ptr<rtc::AsyncPacketSocket> listen_socket_; + std::list<std::string> recv_packets_; +}; + +// Testing a stun packet sent/recv properly. +TEST_F(AsyncStunTCPSocketTest, TestSingleStunPacket) { + EXPECT_TRUE(Send(kStunMessageWithZeroLength, + sizeof(kStunMessageWithZeroLength))); + EXPECT_EQ(1u, recv_packets_.size()); + EXPECT_TRUE(CheckData(kStunMessageWithZeroLength, + sizeof(kStunMessageWithZeroLength))); +} + +// Verify sending multiple packets. +TEST_F(AsyncStunTCPSocketTest, TestMultipleStunPackets) { + EXPECT_TRUE(Send(kStunMessageWithZeroLength, + sizeof(kStunMessageWithZeroLength))); + EXPECT_TRUE(Send(kStunMessageWithZeroLength, + sizeof(kStunMessageWithZeroLength))); + EXPECT_TRUE(Send(kStunMessageWithZeroLength, + sizeof(kStunMessageWithZeroLength))); + EXPECT_TRUE(Send(kStunMessageWithZeroLength, + sizeof(kStunMessageWithZeroLength))); + EXPECT_EQ(4u, recv_packets_.size()); +} + +// Verifying TURN channel data message with zero length. +TEST_F(AsyncStunTCPSocketTest, TestTurnChannelDataWithZeroLength) { + EXPECT_TRUE(Send(kTurnChannelDataMessageWithZeroLength, + sizeof(kTurnChannelDataMessageWithZeroLength))); + EXPECT_EQ(1u, recv_packets_.size()); + EXPECT_TRUE(CheckData(kTurnChannelDataMessageWithZeroLength, + sizeof(kTurnChannelDataMessageWithZeroLength))); +} + +// Verifying TURN channel data message. +TEST_F(AsyncStunTCPSocketTest, TestTurnChannelData) { + EXPECT_TRUE(Send(kTurnChannelDataMessage, + sizeof(kTurnChannelDataMessage))); + EXPECT_EQ(1u, recv_packets_.size()); + EXPECT_TRUE(CheckData(kTurnChannelDataMessage, + sizeof(kTurnChannelDataMessage))); +} + +// Verifying TURN channel messages which needs padding handled properly. +TEST_F(AsyncStunTCPSocketTest, TestTurnChannelDataPadding) { + EXPECT_TRUE(Send(kTurnChannelDataMessageWithOddLength, + sizeof(kTurnChannelDataMessageWithOddLength))); + EXPECT_EQ(1u, recv_packets_.size()); + EXPECT_TRUE(CheckData(kTurnChannelDataMessageWithOddLength, + sizeof(kTurnChannelDataMessageWithOddLength))); +} + +// Verifying stun message with invalid length. +TEST_F(AsyncStunTCPSocketTest, TestStunInvalidLength) { + EXPECT_FALSE(Send(kStunMessageWithInvalidLength, + sizeof(kStunMessageWithInvalidLength))); + EXPECT_EQ(0u, recv_packets_.size()); + + // Modify the message length to larger value. + kStunMessageWithInvalidLength[2] = 0xFF; + kStunMessageWithInvalidLength[3] = 0xFF; + EXPECT_FALSE(Send(kStunMessageWithInvalidLength, + sizeof(kStunMessageWithInvalidLength))); + + // Modify the message length to smaller value. + kStunMessageWithInvalidLength[2] = 0x00; + kStunMessageWithInvalidLength[3] = 0x01; + EXPECT_FALSE(Send(kStunMessageWithInvalidLength, + sizeof(kStunMessageWithInvalidLength))); +} + +// Verifying TURN channel data message with invalid length. +TEST_F(AsyncStunTCPSocketTest, TestTurnChannelDataWithInvalidLength) { + EXPECT_FALSE(Send(kTurnChannelDataMessageWithInvalidLength, + sizeof(kTurnChannelDataMessageWithInvalidLength))); + // Modify the length to larger value. + kTurnChannelDataMessageWithInvalidLength[2] = 0xFF; + kTurnChannelDataMessageWithInvalidLength[3] = 0xF0; + EXPECT_FALSE(Send(kTurnChannelDataMessageWithInvalidLength, + sizeof(kTurnChannelDataMessageWithInvalidLength))); + + // Modify the length to smaller value. + kTurnChannelDataMessageWithInvalidLength[2] = 0x00; + kTurnChannelDataMessageWithInvalidLength[3] = 0x00; + EXPECT_FALSE(Send(kTurnChannelDataMessageWithInvalidLength, + sizeof(kTurnChannelDataMessageWithInvalidLength))); +} + +// Verifying a small buffer handled (dropped) properly. This will be +// a common one for both stun and turn. +TEST_F(AsyncStunTCPSocketTest, TestTooSmallMessageBuffer) { + char data[1]; + EXPECT_FALSE(Send(data, sizeof(data))); +} + +// Verifying a legal large turn message. +TEST_F(AsyncStunTCPSocketTest, TestMaximumSizeTurnPacket) { + // We have problem in getting the SignalWriteEvent from the virtual socket + // server. So increasing the send buffer to 64k. + // TODO(mallinath) - Remove this setting after we fix vss issue. + vss_->set_send_buffer_capacity(64 * 1024); + unsigned char packet[65539]; + packet[0] = 0x40; + packet[1] = 0x00; + packet[2] = 0xFF; + packet[3] = 0xFF; + EXPECT_TRUE(Send(packet, sizeof(packet))); +} + +// Verifying a legal large stun message. +TEST_F(AsyncStunTCPSocketTest, TestMaximumSizeStunPacket) { + // We have problem in getting the SignalWriteEvent from the virtual socket + // server. So increasing the send buffer to 64k. + // TODO(mallinath) - Remove this setting after we fix vss issue. + vss_->set_send_buffer_capacity(64 * 1024); + unsigned char packet[65552]; + packet[0] = 0x00; + packet[1] = 0x01; + packet[2] = 0xFF; + packet[3] = 0xFC; + EXPECT_TRUE(Send(packet, sizeof(packet))); +} + +// Investigate why WriteEvent is not signaled from VSS. +TEST_F(AsyncStunTCPSocketTest, DISABLED_TestWithSmallSendBuffer) { + vss_->set_send_buffer_capacity(1); + Send(kTurnChannelDataMessageWithOddLength, + sizeof(kTurnChannelDataMessageWithOddLength)); + EXPECT_EQ(1u, recv_packets_.size()); + EXPECT_TRUE(CheckData(kTurnChannelDataMessageWithOddLength, + sizeof(kTurnChannelDataMessageWithOddLength))); +} + +} // namespace cricket diff --git a/webrtc/p2p/base/basicpacketsocketfactory.cc b/webrtc/p2p/base/basicpacketsocketfactory.cc new file mode 100644 index 0000000000..697518da9d --- /dev/null +++ b/webrtc/p2p/base/basicpacketsocketfactory.cc @@ -0,0 +1,209 @@ +/* + * Copyright 2011 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/basicpacketsocketfactory.h" + +#include "webrtc/p2p/base/asyncstuntcpsocket.h" +#include "webrtc/p2p/base/stun.h" +#include "webrtc/base/asynctcpsocket.h" +#include "webrtc/base/asyncudpsocket.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/nethelpers.h" +#include "webrtc/base/physicalsocketserver.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/socketadapters.h" +#include "webrtc/base/ssladapter.h" +#include "webrtc/base/thread.h" + +namespace rtc { + +BasicPacketSocketFactory::BasicPacketSocketFactory() + : thread_(Thread::Current()), + socket_factory_(NULL) { +} + +BasicPacketSocketFactory::BasicPacketSocketFactory(Thread* thread) + : thread_(thread), + socket_factory_(NULL) { +} + +BasicPacketSocketFactory::BasicPacketSocketFactory( + SocketFactory* socket_factory) + : thread_(NULL), + socket_factory_(socket_factory) { +} + +BasicPacketSocketFactory::~BasicPacketSocketFactory() { +} + +AsyncPacketSocket* BasicPacketSocketFactory::CreateUdpSocket( + const SocketAddress& address, + uint16_t min_port, + uint16_t max_port) { + // UDP sockets are simple. + rtc::AsyncSocket* socket = + socket_factory()->CreateAsyncSocket( + address.family(), SOCK_DGRAM); + if (!socket) { + return NULL; + } + if (BindSocket(socket, address, min_port, max_port) < 0) { + LOG(LS_ERROR) << "UDP bind failed with error " + << socket->GetError(); + delete socket; + return NULL; + } + return new rtc::AsyncUDPSocket(socket); +} + +AsyncPacketSocket* BasicPacketSocketFactory::CreateServerTcpSocket( + const SocketAddress& local_address, + uint16_t min_port, + uint16_t max_port, + int opts) { + // Fail if TLS is required. + if (opts & PacketSocketFactory::OPT_TLS) { + LOG(LS_ERROR) << "TLS support currently is not available."; + return NULL; + } + + rtc::AsyncSocket* socket = + socket_factory()->CreateAsyncSocket(local_address.family(), + SOCK_STREAM); + if (!socket) { + return NULL; + } + + if (BindSocket(socket, local_address, min_port, max_port) < 0) { + LOG(LS_ERROR) << "TCP bind failed with error " + << socket->GetError(); + delete socket; + return NULL; + } + + // If using SSLTCP, wrap the TCP socket in a pseudo-SSL socket. + if (opts & PacketSocketFactory::OPT_SSLTCP) { + ASSERT(!(opts & PacketSocketFactory::OPT_TLS)); + socket = new rtc::AsyncSSLSocket(socket); + } + + // Set TCP_NODELAY (via OPT_NODELAY) for improved performance. + // See http://go/gtalktcpnodelayexperiment + socket->SetOption(rtc::Socket::OPT_NODELAY, 1); + + if (opts & PacketSocketFactory::OPT_STUN) + return new cricket::AsyncStunTCPSocket(socket, true); + + return new rtc::AsyncTCPSocket(socket, true); +} + +AsyncPacketSocket* BasicPacketSocketFactory::CreateClientTcpSocket( + const SocketAddress& local_address, const SocketAddress& remote_address, + const ProxyInfo& proxy_info, const std::string& user_agent, int opts) { + + rtc::AsyncSocket* socket = + socket_factory()->CreateAsyncSocket(local_address.family(), SOCK_STREAM); + if (!socket) { + return NULL; + } + + if (BindSocket(socket, local_address, 0, 0) < 0) { + LOG(LS_ERROR) << "TCP bind failed with error " + << socket->GetError(); + delete socket; + return NULL; + } + + // If using a proxy, wrap the socket in a proxy socket. + if (proxy_info.type == rtc::PROXY_SOCKS5) { + socket = new rtc::AsyncSocksProxySocket( + socket, proxy_info.address, proxy_info.username, proxy_info.password); + } else if (proxy_info.type == rtc::PROXY_HTTPS) { + socket = new rtc::AsyncHttpsProxySocket( + socket, user_agent, proxy_info.address, + proxy_info.username, proxy_info.password); + } + + // If using TLS, wrap the socket in an SSL adapter. + if (opts & PacketSocketFactory::OPT_TLS) { + ASSERT(!(opts & PacketSocketFactory::OPT_SSLTCP)); + + rtc::SSLAdapter* ssl_adapter = rtc::SSLAdapter::Create(socket); + if (!ssl_adapter) { + return NULL; + } + + socket = ssl_adapter; + + if (ssl_adapter->StartSSL(remote_address.hostname().c_str(), false) != 0) { + delete ssl_adapter; + return NULL; + } + + // If using SSLTCP, wrap the TCP socket in a pseudo-SSL socket. + } else if (opts & PacketSocketFactory::OPT_SSLTCP) { + ASSERT(!(opts & PacketSocketFactory::OPT_TLS)); + socket = new rtc::AsyncSSLSocket(socket); + } + + if (socket->Connect(remote_address) < 0) { + LOG(LS_ERROR) << "TCP connect failed with error " + << socket->GetError(); + delete socket; + return NULL; + } + + // Finally, wrap that socket in a TCP or STUN TCP packet socket. + AsyncPacketSocket* tcp_socket; + if (opts & PacketSocketFactory::OPT_STUN) { + tcp_socket = new cricket::AsyncStunTCPSocket(socket, false); + } else { + tcp_socket = new rtc::AsyncTCPSocket(socket, false); + } + + // Set TCP_NODELAY (via OPT_NODELAY) for improved performance. + // See http://go/gtalktcpnodelayexperiment + tcp_socket->SetOption(rtc::Socket::OPT_NODELAY, 1); + + return tcp_socket; +} + +AsyncResolverInterface* BasicPacketSocketFactory::CreateAsyncResolver() { + return new rtc::AsyncResolver(); +} + +int BasicPacketSocketFactory::BindSocket(AsyncSocket* socket, + const SocketAddress& local_address, + uint16_t min_port, + uint16_t max_port) { + int ret = -1; + if (min_port == 0 && max_port == 0) { + // If there's no port range, let the OS pick a port for us. + ret = socket->Bind(local_address); + } else { + // Otherwise, try to find a port in the provided range. + for (int port = min_port; ret < 0 && port <= max_port; ++port) { + ret = socket->Bind(rtc::SocketAddress(local_address.ipaddr(), + port)); + } + } + return ret; +} + +SocketFactory* BasicPacketSocketFactory::socket_factory() { + if (thread_) { + ASSERT(thread_ == Thread::Current()); + return thread_->socketserver(); + } else { + return socket_factory_; + } +} + +} // namespace rtc diff --git a/webrtc/p2p/base/basicpacketsocketfactory.h b/webrtc/p2p/base/basicpacketsocketfactory.h new file mode 100644 index 0000000000..5046e0f518 --- /dev/null +++ b/webrtc/p2p/base/basicpacketsocketfactory.h @@ -0,0 +1,58 @@ +/* + * Copyright 2011 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_BASICPACKETSOCKETFACTORY_H_ +#define WEBRTC_P2P_BASE_BASICPACKETSOCKETFACTORY_H_ + +#include "webrtc/p2p/base/packetsocketfactory.h" + +namespace rtc { + +class AsyncSocket; +class SocketFactory; +class Thread; + +class BasicPacketSocketFactory : public PacketSocketFactory { + public: + BasicPacketSocketFactory(); + explicit BasicPacketSocketFactory(Thread* thread); + explicit BasicPacketSocketFactory(SocketFactory* socket_factory); + ~BasicPacketSocketFactory() override; + + AsyncPacketSocket* CreateUdpSocket(const SocketAddress& local_address, + uint16_t min_port, + uint16_t max_port) override; + AsyncPacketSocket* CreateServerTcpSocket(const SocketAddress& local_address, + uint16_t min_port, + uint16_t max_port, + int opts) override; + AsyncPacketSocket* CreateClientTcpSocket(const SocketAddress& local_address, + const SocketAddress& remote_address, + const ProxyInfo& proxy_info, + const std::string& user_agent, + int opts) override; + + AsyncResolverInterface* CreateAsyncResolver() override; + + private: + int BindSocket(AsyncSocket* socket, + const SocketAddress& local_address, + uint16_t min_port, + uint16_t max_port); + + SocketFactory* socket_factory(); + + Thread* thread_; + SocketFactory* socket_factory_; +}; + +} // namespace rtc + +#endif // WEBRTC_P2P_BASE_BASICPACKETSOCKETFACTORY_H_ diff --git a/webrtc/p2p/base/candidate.h b/webrtc/p2p/base/candidate.h new file mode 100644 index 0000000000..3f0ea43cde --- /dev/null +++ b/webrtc/p2p/base/candidate.h @@ -0,0 +1,251 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_CANDIDATE_H_ +#define WEBRTC_P2P_BASE_CANDIDATE_H_ + +#include <limits.h> +#include <math.h> + +#include <algorithm> +#include <iomanip> +#include <sstream> +#include <string> + +#include "webrtc/p2p/base/constants.h" +#include "webrtc/base/basictypes.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/network.h" +#include "webrtc/base/socketaddress.h" + +namespace cricket { + +// Candidate for ICE based connection discovery. + +class Candidate { + public: + // TODO: Match the ordering and param list as per RFC 5245 + // candidate-attribute syntax. http://tools.ietf.org/html/rfc5245#section-15.1 + Candidate() + : id_(rtc::CreateRandomString(8)), + component_(0), + priority_(0), + network_type_(rtc::ADAPTER_TYPE_UNKNOWN), + generation_(0) {} + + Candidate(int component, + const std::string& protocol, + const rtc::SocketAddress& address, + uint32_t priority, + const std::string& username, + const std::string& password, + const std::string& type, + uint32_t generation, + const std::string& foundation) + : id_(rtc::CreateRandomString(8)), + component_(component), + protocol_(protocol), + address_(address), + priority_(priority), + username_(username), + password_(password), + type_(type), + network_type_(rtc::ADAPTER_TYPE_UNKNOWN), + generation_(generation), + foundation_(foundation) {} + + const std::string & id() const { return id_; } + void set_id(const std::string & id) { id_ = id; } + + int component() const { return component_; } + void set_component(int component) { component_ = component; } + + const std::string & protocol() const { return protocol_; } + void set_protocol(const std::string & protocol) { protocol_ = protocol; } + + // The protocol used to talk to relay. + const std::string& relay_protocol() const { return relay_protocol_; } + void set_relay_protocol(const std::string& protocol) { + relay_protocol_ = protocol; + } + + const rtc::SocketAddress & address() const { return address_; } + void set_address(const rtc::SocketAddress & address) { + address_ = address; + } + + uint32_t priority() const { return priority_; } + void set_priority(const uint32_t priority) { priority_ = priority; } + + // TODO(pthatcher): Remove once Chromium's jingle/glue/utils.cc + // doesn't use it. + // Maps old preference (which was 0.0-1.0) to match priority (which + // is 0-2^32-1) to to match RFC 5245, section 4.1.2.1. Also see + // https://docs.google.com/a/google.com/document/d/ + // 1iNQDiwDKMh0NQOrCqbj3DKKRT0Dn5_5UJYhmZO-t7Uc/edit + float preference() const { + // The preference value is clamped to two decimal precision. + return static_cast<float>(((priority_ >> 24) * 100 / 127) / 100.0); + } + + // TODO(pthatcher): Remove once Chromium's jingle/glue/utils.cc + // doesn't use it. + void set_preference(float preference) { + // Limiting priority to UINT_MAX when value exceeds uint32_t max. + // This can happen for e.g. when preference = 3. + uint64_t prio_val = static_cast<uint64_t>(preference * 127) << 24; + priority_ = static_cast<uint32_t>( + std::min(prio_val, static_cast<uint64_t>(UINT_MAX))); + } + + const std::string & username() const { return username_; } + void set_username(const std::string & username) { username_ = username; } + + const std::string & password() const { return password_; } + void set_password(const std::string & password) { password_ = password; } + + const std::string & type() const { return type_; } + void set_type(const std::string & type) { type_ = type; } + + const std::string & network_name() const { return network_name_; } + void set_network_name(const std::string & network_name) { + network_name_ = network_name; + } + + rtc::AdapterType network_type() const { return network_type_; } + void set_network_type(rtc::AdapterType network_type) { + network_type_ = network_type; + } + + // Candidates in a new generation replace those in the old generation. + uint32_t generation() const { return generation_; } + void set_generation(uint32_t generation) { generation_ = generation; } + const std::string generation_str() const { + std::ostringstream ost; + ost << generation_; + return ost.str(); + } + void set_generation_str(const std::string& str) { + std::istringstream ist(str); + ist >> generation_; + } + + const std::string& foundation() const { + return foundation_; + } + + void set_foundation(const std::string& foundation) { + foundation_ = foundation; + } + + const rtc::SocketAddress & related_address() const { + return related_address_; + } + void set_related_address( + const rtc::SocketAddress & related_address) { + related_address_ = related_address; + } + const std::string& tcptype() const { return tcptype_; } + void set_tcptype(const std::string& tcptype){ + tcptype_ = tcptype; + } + + // Determines whether this candidate is equivalent to the given one. + bool IsEquivalent(const Candidate& c) const { + // We ignore the network name, since that is just debug information, and + // the priority, since that should be the same if the rest is (and it's + // a float so equality checking is always worrisome). + return (component_ == c.component_) && (protocol_ == c.protocol_) && + (address_ == c.address_) && (username_ == c.username_) && + (password_ == c.password_) && (type_ == c.type_) && + (generation_ == c.generation_) && (foundation_ == c.foundation_) && + (related_address_ == c.related_address_); + } + + std::string ToString() const { + return ToStringInternal(false); + } + + std::string ToSensitiveString() const { + return ToStringInternal(true); + } + + uint32_t GetPriority(uint32_t type_preference, + int network_adapter_preference, + int relay_preference) const { + // RFC 5245 - 4.1.2.1. + // priority = (2^24)*(type preference) + + // (2^8)*(local preference) + + // (2^0)*(256 - component ID) + + // |local_preference| length is 2 bytes, 0-65535 inclusive. + // In our implemenation we will partion local_preference into + // 0 1 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | NIC Pref | Addr Pref | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // NIC Type - Type of the network adapter e.g. 3G/Wifi/Wired. + // Addr Pref - Address preference value as per RFC 3484. + // local preference = (NIC Type << 8 | Addr_Pref) - relay preference. + + int addr_pref = IPAddressPrecedence(address_.ipaddr()); + int local_preference = ((network_adapter_preference << 8) | addr_pref) + + relay_preference; + + return (type_preference << 24) | + (local_preference << 8) | + (256 - component_); + } + + private: + std::string ToStringInternal(bool sensitive) const { + std::ostringstream ost; + std::string address = sensitive ? address_.ToSensitiveString() : + address_.ToString(); + ost << "Cand[" << foundation_ << ":" << component_ << ":" + << protocol_ << ":" << priority_ << ":" + << address << ":" << type_ << ":" << related_address_ << ":" + << username_ << ":" << password_ << "]"; + return ost.str(); + } + + std::string id_; + int component_; + std::string protocol_; + std::string relay_protocol_; + rtc::SocketAddress address_; + uint32_t priority_; + std::string username_; + std::string password_; + std::string type_; + std::string network_name_; + rtc::AdapterType network_type_; + uint32_t generation_; + std::string foundation_; + rtc::SocketAddress related_address_; + std::string tcptype_; +}; + +// Used during parsing and writing to map component to channel name +// and back. This is primarily for converting old G-ICE candidate +// signalling to new ICE candidate classes. +class CandidateTranslator { + public: + virtual ~CandidateTranslator() {} + virtual bool GetChannelNameFromComponent( + int component, std::string* channel_name) const = 0; + virtual bool GetComponentFromChannelName( + const std::string& channel_name, int* component) const = 0; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_CANDIDATE_H_ diff --git a/webrtc/p2p/base/common.h b/webrtc/p2p/base/common.h new file mode 100644 index 0000000000..8a3178c801 --- /dev/null +++ b/webrtc/p2p/base/common.h @@ -0,0 +1,20 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_COMMON_H_ +#define WEBRTC_P2P_BASE_COMMON_H_ + +#include "webrtc/base/logging.h" + +// Common log description format for jingle messages +#define LOG_J(sev, obj) LOG(sev) << "Jingle:" << obj->ToString() << ": " +#define LOG_JV(sev, obj) LOG_V(sev) << "Jingle:" << obj->ToString() << ": " + +#endif // WEBRTC_P2P_BASE_COMMON_H_ diff --git a/webrtc/p2p/base/constants.cc b/webrtc/p2p/base/constants.cc new file mode 100644 index 0000000000..2a258718f4 --- /dev/null +++ b/webrtc/p2p/base/constants.cc @@ -0,0 +1,50 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/constants.h" + +#include <string> + +namespace cricket { + +const char CN_AUDIO[] = "audio"; +const char CN_VIDEO[] = "video"; +const char CN_DATA[] = "data"; +const char CN_OTHER[] = "main"; + +const char GROUP_TYPE_BUNDLE[] = "BUNDLE"; + +// Minimum ufrag length is 4 characters as per RFC5245. We chose 16 because +// some internal systems expect username to be 16 bytes. +const int ICE_UFRAG_LENGTH = 16; +// Minimum password length of 22 characters as per RFC5245. We chose 24 because +// some internal systems expect password to be multiple of 4. +const int ICE_PWD_LENGTH = 24; +const size_t ICE_UFRAG_MIN_LENGTH = 4; +const size_t ICE_PWD_MIN_LENGTH = 22; +const size_t ICE_UFRAG_MAX_LENGTH = 255; +const size_t ICE_PWD_MAX_LENGTH = 256; + +// TODO: This is media-specific, so might belong +// somewhere like media/base/constants.h +const int ICE_CANDIDATE_COMPONENT_RTP = 1; +const int ICE_CANDIDATE_COMPONENT_RTCP = 2; +const int ICE_CANDIDATE_COMPONENT_DEFAULT = 1; + +const char NS_JINGLE_RTP[] = "urn:xmpp:jingle:apps:rtp:1"; +const char NS_JINGLE_DRAFT_SCTP[] = "google:jingle:sctp"; + +// From RFC 4145, SDP setup attribute values. +const char CONNECTIONROLE_ACTIVE_STR[] = "active"; +const char CONNECTIONROLE_PASSIVE_STR[] = "passive"; +const char CONNECTIONROLE_ACTPASS_STR[] = "actpass"; +const char CONNECTIONROLE_HOLDCONN_STR[] = "holdconn"; + +} // namespace cricket diff --git a/webrtc/p2p/base/constants.h b/webrtc/p2p/base/constants.h new file mode 100644 index 0000000000..c3e1b781dc --- /dev/null +++ b/webrtc/p2p/base/constants.h @@ -0,0 +1,53 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_CONSTANTS_H_ +#define WEBRTC_P2P_BASE_CONSTANTS_H_ + +#include <string> + +namespace cricket { + +// CN_ == "content name". When we initiate a session, we choose the +// name, and when we receive a Gingle session, we provide default +// names (since Gingle has no content names). But when we receive a +// Jingle call, the content name can be anything, so don't rely on +// these values being the same as the ones received. +extern const char CN_AUDIO[]; +extern const char CN_VIDEO[]; +extern const char CN_DATA[]; +extern const char CN_OTHER[]; + +// GN stands for group name +extern const char GROUP_TYPE_BUNDLE[]; + +extern const int ICE_UFRAG_LENGTH; +extern const int ICE_PWD_LENGTH; +extern const size_t ICE_UFRAG_MIN_LENGTH; +extern const size_t ICE_PWD_MIN_LENGTH; +extern const size_t ICE_UFRAG_MAX_LENGTH; +extern const size_t ICE_PWD_MAX_LENGTH; + +extern const int ICE_CANDIDATE_COMPONENT_RTP; +extern const int ICE_CANDIDATE_COMPONENT_RTCP; +extern const int ICE_CANDIDATE_COMPONENT_DEFAULT; + +extern const char NS_JINGLE_RTP[]; +extern const char NS_JINGLE_DRAFT_SCTP[]; + +// RFC 4145, SDP setup attribute values. +extern const char CONNECTIONROLE_ACTIVE_STR[]; +extern const char CONNECTIONROLE_PASSIVE_STR[]; +extern const char CONNECTIONROLE_ACTPASS_STR[]; +extern const char CONNECTIONROLE_HOLDCONN_STR[]; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_CONSTANTS_H_ diff --git a/webrtc/p2p/base/dtlstransport.h b/webrtc/p2p/base/dtlstransport.h new file mode 100644 index 0000000000..e9a1ae2ada --- /dev/null +++ b/webrtc/p2p/base/dtlstransport.h @@ -0,0 +1,250 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_DTLSTRANSPORT_H_ +#define WEBRTC_P2P_BASE_DTLSTRANSPORT_H_ + +#include "webrtc/p2p/base/dtlstransportchannel.h" +#include "webrtc/p2p/base/transport.h" + +namespace rtc { +class SSLIdentity; +} + +namespace cricket { + +class PortAllocator; + +// Base should be a descendant of cricket::Transport and have a constructor +// that takes a transport name and PortAllocator. +// +// Everything in this class should be called on the worker thread. +template<class Base> +class DtlsTransport : public Base { + public: + DtlsTransport(const std::string& name, + PortAllocator* allocator, + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) + : Base(name, allocator), + certificate_(certificate), + secure_role_(rtc::SSL_CLIENT), + ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_10) {} + + ~DtlsTransport() { + Base::DestroyAllChannels(); + } + + void SetLocalCertificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override { + certificate_ = certificate; + } + bool GetLocalCertificate( + rtc::scoped_refptr<rtc::RTCCertificate>* certificate) override { + if (!certificate_) + return false; + + *certificate = certificate_; + return true; + } + + bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) override { + ssl_max_version_ = version; + return true; + } + + bool ApplyLocalTransportDescription(TransportChannelImpl* channel, + std::string* error_desc) override { + rtc::SSLFingerprint* local_fp = + Base::local_description()->identity_fingerprint.get(); + + if (local_fp) { + // Sanity check local fingerprint. + if (certificate_) { + rtc::scoped_ptr<rtc::SSLFingerprint> local_fp_tmp( + rtc::SSLFingerprint::Create(local_fp->algorithm, + certificate_->identity())); + ASSERT(local_fp_tmp.get() != NULL); + if (!(*local_fp_tmp == *local_fp)) { + std::ostringstream desc; + desc << "Local fingerprint does not match identity. Expected: "; + desc << local_fp_tmp->ToString(); + desc << " Got: " << local_fp->ToString(); + return BadTransportDescription(desc.str(), error_desc); + } + } else { + return BadTransportDescription( + "Local fingerprint provided but no identity available.", + error_desc); + } + } else { + certificate_ = nullptr; + } + + if (!channel->SetLocalCertificate(certificate_)) { + return BadTransportDescription("Failed to set local identity.", + error_desc); + } + + // Apply the description in the base class. + return Base::ApplyLocalTransportDescription(channel, error_desc); + } + + bool NegotiateTransportDescription(ContentAction local_role, + std::string* error_desc) override { + if (!Base::local_description() || !Base::remote_description()) { + const std::string msg = "Local and Remote description must be set before " + "transport descriptions are negotiated"; + return BadTransportDescription(msg, error_desc); + } + + rtc::SSLFingerprint* local_fp = + Base::local_description()->identity_fingerprint.get(); + rtc::SSLFingerprint* remote_fp = + Base::remote_description()->identity_fingerprint.get(); + + if (remote_fp && local_fp) { + remote_fingerprint_.reset(new rtc::SSLFingerprint(*remote_fp)); + + // From RFC 4145, section-4.1, The following are the values that the + // 'setup' attribute can take in an offer/answer exchange: + // Offer Answer + // ________________ + // active passive / holdconn + // passive active / holdconn + // actpass active / passive / holdconn + // holdconn holdconn + // + // Set the role that is most conformant with RFC 5763, Section 5, bullet 1 + // The endpoint MUST use the setup attribute defined in [RFC4145]. + // The endpoint that is the offerer MUST use the setup attribute + // value of setup:actpass and be prepared to receive a client_hello + // before it receives the answer. The answerer MUST use either a + // setup attribute value of setup:active or setup:passive. Note that + // if the answerer uses setup:passive, then the DTLS handshake will + // not begin until the answerer is received, which adds additional + // latency. setup:active allows the answer and the DTLS handshake to + // occur in parallel. Thus, setup:active is RECOMMENDED. Whichever + // party is active MUST initiate a DTLS handshake by sending a + // ClientHello over each flow (host/port quartet). + // IOW - actpass and passive modes should be treated as server and + // active as client. + ConnectionRole local_connection_role = + Base::local_description()->connection_role; + ConnectionRole remote_connection_role = + Base::remote_description()->connection_role; + + bool is_remote_server = false; + if (local_role == CA_OFFER) { + if (local_connection_role != CONNECTIONROLE_ACTPASS) { + return BadTransportDescription( + "Offerer must use actpass value for setup attribute.", + error_desc); + } + + if (remote_connection_role == CONNECTIONROLE_ACTIVE || + remote_connection_role == CONNECTIONROLE_PASSIVE || + remote_connection_role == CONNECTIONROLE_NONE) { + is_remote_server = (remote_connection_role == CONNECTIONROLE_PASSIVE); + } else { + const std::string msg = + "Answerer must use either active or passive value " + "for setup attribute."; + return BadTransportDescription(msg, error_desc); + } + // If remote is NONE or ACTIVE it will act as client. + } else { + if (remote_connection_role != CONNECTIONROLE_ACTPASS && + remote_connection_role != CONNECTIONROLE_NONE) { + return BadTransportDescription( + "Offerer must use actpass value for setup attribute.", + error_desc); + } + + if (local_connection_role == CONNECTIONROLE_ACTIVE || + local_connection_role == CONNECTIONROLE_PASSIVE) { + is_remote_server = (local_connection_role == CONNECTIONROLE_ACTIVE); + } else { + const std::string msg = + "Answerer must use either active or passive value " + "for setup attribute."; + return BadTransportDescription(msg, error_desc); + } + + // If local is passive, local will act as server. + } + + secure_role_ = is_remote_server ? rtc::SSL_CLIENT : + rtc::SSL_SERVER; + + } else if (local_fp && (local_role == CA_ANSWER)) { + return BadTransportDescription( + "Local fingerprint supplied when caller didn't offer DTLS.", + error_desc); + } else { + // We are not doing DTLS + remote_fingerprint_.reset(new rtc::SSLFingerprint( + "", NULL, 0)); + } + + // Now run the negotiation for the base class. + return Base::NegotiateTransportDescription(local_role, error_desc); + } + + DtlsTransportChannelWrapper* CreateTransportChannel(int component) override { + DtlsTransportChannelWrapper* channel = new DtlsTransportChannelWrapper( + this, Base::CreateTransportChannel(component)); + channel->SetSslMaxProtocolVersion(ssl_max_version_); + return channel; + } + + void DestroyTransportChannel(TransportChannelImpl* channel) override { + // Kind of ugly, but this lets us do the exact inverse of the create. + DtlsTransportChannelWrapper* dtls_channel = + static_cast<DtlsTransportChannelWrapper*>(channel); + TransportChannelImpl* base_channel = dtls_channel->channel(); + delete dtls_channel; + Base::DestroyTransportChannel(base_channel); + } + + bool GetSslRole(rtc::SSLRole* ssl_role) const override { + ASSERT(ssl_role != NULL); + *ssl_role = secure_role_; + return true; + } + + private: + bool ApplyNegotiatedTransportDescription(TransportChannelImpl* channel, + std::string* error_desc) override { + // Set ssl role. Role must be set before fingerprint is applied, which + // initiates DTLS setup. + if (!channel->SetSslRole(secure_role_)) { + return BadTransportDescription("Failed to set ssl role for the channel.", + error_desc); + } + // Apply remote fingerprint. + if (!channel->SetRemoteFingerprint(remote_fingerprint_->algorithm, + reinterpret_cast<const uint8_t*>( + remote_fingerprint_->digest.data()), + remote_fingerprint_->digest.size())) { + return BadTransportDescription("Failed to apply remote fingerprint.", + error_desc); + } + return Base::ApplyNegotiatedTransportDescription(channel, error_desc); + } + + rtc::scoped_refptr<rtc::RTCCertificate> certificate_; + rtc::SSLRole secure_role_; + rtc::SSLProtocolVersion ssl_max_version_; + rtc::scoped_ptr<rtc::SSLFingerprint> remote_fingerprint_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_DTLSTRANSPORT_H_ diff --git a/webrtc/p2p/base/dtlstransportchannel.cc b/webrtc/p2p/base/dtlstransportchannel.cc new file mode 100644 index 0000000000..0c063e0323 --- /dev/null +++ b/webrtc/p2p/base/dtlstransportchannel.cc @@ -0,0 +1,620 @@ +/* + * Copyright 2011 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/dtlstransportchannel.h" + +#include "webrtc/p2p/base/common.h" +#include "webrtc/base/buffer.h" +#include "webrtc/base/checks.h" +#include "webrtc/base/dscp.h" +#include "webrtc/base/messagequeue.h" +#include "webrtc/base/sslstreamadapter.h" +#include "webrtc/base/stream.h" +#include "webrtc/base/thread.h" + +namespace cricket { + +// We don't pull the RTP constants from rtputils.h, to avoid a layer violation. +static const size_t kDtlsRecordHeaderLen = 13; +static const size_t kMaxDtlsPacketLen = 2048; +static const size_t kMinRtpPacketLen = 12; + +// Maximum number of pending packets in the queue. Packets are read immediately +// after they have been written, so a capacity of "1" is sufficient. +static const size_t kMaxPendingPackets = 1; + +static bool IsDtlsPacket(const char* data, size_t len) { + const uint8_t* u = reinterpret_cast<const uint8_t*>(data); + return (len >= kDtlsRecordHeaderLen && (u[0] > 19 && u[0] < 64)); +} +static bool IsRtpPacket(const char* data, size_t len) { + const uint8_t* u = reinterpret_cast<const uint8_t*>(data); + return (len >= kMinRtpPacketLen && (u[0] & 0xC0) == 0x80); +} + +StreamInterfaceChannel::StreamInterfaceChannel(TransportChannel* channel) + : channel_(channel), + state_(rtc::SS_OPEN), + packets_(kMaxPendingPackets, kMaxDtlsPacketLen) { +} + +rtc::StreamResult StreamInterfaceChannel::Read(void* buffer, + size_t buffer_len, + size_t* read, + int* error) { + if (state_ == rtc::SS_CLOSED) + return rtc::SR_EOS; + if (state_ == rtc::SS_OPENING) + return rtc::SR_BLOCK; + + if (!packets_.ReadFront(buffer, buffer_len, read)) { + return rtc::SR_BLOCK; + } + + return rtc::SR_SUCCESS; +} + +rtc::StreamResult StreamInterfaceChannel::Write(const void* data, + size_t data_len, + size_t* written, + int* error) { + // Always succeeds, since this is an unreliable transport anyway. + // TODO: Should this block if channel_'s temporarily unwritable? + rtc::PacketOptions packet_options; + channel_->SendPacket(static_cast<const char*>(data), data_len, + packet_options); + if (written) { + *written = data_len; + } + return rtc::SR_SUCCESS; +} + +bool StreamInterfaceChannel::OnPacketReceived(const char* data, size_t size) { + // We force a read event here to ensure that we don't overflow our queue. + bool ret = packets_.WriteBack(data, size, NULL); + RTC_CHECK(ret) << "Failed to write packet to queue."; + if (ret) { + SignalEvent(this, rtc::SE_READ, 0); + } + return ret; +} + +DtlsTransportChannelWrapper::DtlsTransportChannelWrapper( + Transport* transport, + TransportChannelImpl* channel) + : TransportChannelImpl(channel->transport_name(), channel->component()), + transport_(transport), + worker_thread_(rtc::Thread::Current()), + channel_(channel), + downward_(NULL), + ssl_role_(rtc::SSL_CLIENT), + ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_10) { + channel_->SignalWritableState.connect(this, + &DtlsTransportChannelWrapper::OnWritableState); + channel_->SignalReadPacket.connect(this, + &DtlsTransportChannelWrapper::OnReadPacket); + channel_->SignalSentPacket.connect( + this, &DtlsTransportChannelWrapper::OnSentPacket); + channel_->SignalReadyToSend.connect(this, + &DtlsTransportChannelWrapper::OnReadyToSend); + channel_->SignalGatheringState.connect( + this, &DtlsTransportChannelWrapper::OnGatheringState); + channel_->SignalCandidateGathered.connect( + this, &DtlsTransportChannelWrapper::OnCandidateGathered); + channel_->SignalRoleConflict.connect(this, + &DtlsTransportChannelWrapper::OnRoleConflict); + channel_->SignalRouteChange.connect(this, + &DtlsTransportChannelWrapper::OnRouteChange); + channel_->SignalConnectionRemoved.connect(this, + &DtlsTransportChannelWrapper::OnConnectionRemoved); + channel_->SignalReceivingState.connect(this, + &DtlsTransportChannelWrapper::OnReceivingState); +} + +DtlsTransportChannelWrapper::~DtlsTransportChannelWrapper() { +} + +void DtlsTransportChannelWrapper::Connect() { + // We should only get a single call to Connect. + ASSERT(dtls_state() == DTLS_TRANSPORT_NEW); + channel_->Connect(); +} + +bool DtlsTransportChannelWrapper::SetLocalCertificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) { + if (dtls_active_) { + if (certificate == local_certificate_) { + // This may happen during renegotiation. + LOG_J(LS_INFO, this) << "Ignoring identical DTLS identity"; + return true; + } else { + LOG_J(LS_ERROR, this) << "Can't change DTLS local identity in this state"; + return false; + } + } + + if (certificate) { + local_certificate_ = certificate; + dtls_active_ = true; + } else { + LOG_J(LS_INFO, this) << "NULL DTLS identity supplied. Not doing DTLS"; + } + + return true; +} + +rtc::scoped_refptr<rtc::RTCCertificate> +DtlsTransportChannelWrapper::GetLocalCertificate() const { + return local_certificate_; +} + +bool DtlsTransportChannelWrapper::SetSslMaxProtocolVersion( + rtc::SSLProtocolVersion version) { + if (dtls_active_) { + LOG(LS_ERROR) << "Not changing max. protocol version " + << "while DTLS is negotiating"; + return false; + } + + ssl_max_version_ = version; + return true; +} + +bool DtlsTransportChannelWrapper::SetSslRole(rtc::SSLRole role) { + if (dtls_state() == DTLS_TRANSPORT_CONNECTED) { + if (ssl_role_ != role) { + LOG(LS_ERROR) << "SSL Role can't be reversed after the session is setup."; + return false; + } + return true; + } + + ssl_role_ = role; + return true; +} + +bool DtlsTransportChannelWrapper::GetSslRole(rtc::SSLRole* role) const { + *role = ssl_role_; + return true; +} + +bool DtlsTransportChannelWrapper::GetSslCipherSuite(int* cipher) { + if (dtls_state() != DTLS_TRANSPORT_CONNECTED) { + return false; + } + + return dtls_->GetSslCipherSuite(cipher); +} + +bool DtlsTransportChannelWrapper::SetRemoteFingerprint( + const std::string& digest_alg, + const uint8_t* digest, + size_t digest_len) { + rtc::Buffer remote_fingerprint_value(digest, digest_len); + + if (dtls_active_ && remote_fingerprint_value_ == remote_fingerprint_value && + !digest_alg.empty()) { + // This may happen during renegotiation. + LOG_J(LS_INFO, this) << "Ignoring identical remote DTLS fingerprint"; + return true; + } + + // Allow SetRemoteFingerprint with a NULL digest even if SetLocalCertificate + // hasn't been called. + if (dtls_ || (!dtls_active_ && !digest_alg.empty())) { + LOG_J(LS_ERROR, this) << "Can't set DTLS remote settings in this state."; + return false; + } + + if (digest_alg.empty()) { + LOG_J(LS_INFO, this) << "Other side didn't support DTLS."; + dtls_active_ = false; + return true; + } + + // At this point we know we are doing DTLS + remote_fingerprint_value_ = remote_fingerprint_value.Pass(); + remote_fingerprint_algorithm_ = digest_alg; + + if (!SetupDtls()) { + set_dtls_state(DTLS_TRANSPORT_FAILED); + return false; + } + + return true; +} + +bool DtlsTransportChannelWrapper::GetRemoteSSLCertificate( + rtc::SSLCertificate** cert) const { + if (!dtls_) { + return false; + } + + return dtls_->GetPeerCertificate(cert); +} + +bool DtlsTransportChannelWrapper::SetupDtls() { + StreamInterfaceChannel* downward = new StreamInterfaceChannel(channel_); + + dtls_.reset(rtc::SSLStreamAdapter::Create(downward)); + if (!dtls_) { + LOG_J(LS_ERROR, this) << "Failed to create DTLS adapter."; + delete downward; + return false; + } + + downward_ = downward; + + dtls_->SetIdentity(local_certificate_->identity()->GetReference()); + dtls_->SetMode(rtc::SSL_MODE_DTLS); + dtls_->SetMaxProtocolVersion(ssl_max_version_); + dtls_->SetServerRole(ssl_role_); + dtls_->SignalEvent.connect(this, &DtlsTransportChannelWrapper::OnDtlsEvent); + if (!dtls_->SetPeerCertificateDigest( + remote_fingerprint_algorithm_, + reinterpret_cast<unsigned char*>(remote_fingerprint_value_.data()), + remote_fingerprint_value_.size())) { + LOG_J(LS_ERROR, this) << "Couldn't set DTLS certificate digest."; + return false; + } + + // Set up DTLS-SRTP, if it's been enabled. + if (!srtp_ciphers_.empty()) { + if (!dtls_->SetDtlsSrtpCiphers(srtp_ciphers_)) { + LOG_J(LS_ERROR, this) << "Couldn't set DTLS-SRTP ciphers."; + return false; + } + } else { + LOG_J(LS_INFO, this) << "Not using DTLS-SRTP."; + } + + LOG_J(LS_INFO, this) << "DTLS setup complete."; + return true; +} + +bool DtlsTransportChannelWrapper::SetSrtpCiphers( + const std::vector<std::string>& ciphers) { + if (srtp_ciphers_ == ciphers) { + return true; + } + + if (dtls_state() == DTLS_TRANSPORT_CONNECTING) { + LOG(LS_WARNING) << "Ignoring new SRTP ciphers while DTLS is negotiating"; + return true; + } + + if (dtls_state() == DTLS_TRANSPORT_CONNECTED) { + // We don't support DTLS renegotiation currently. If new set of srtp ciphers + // are different than what's being used currently, we will not use it. + // So for now, let's be happy (or sad) with a warning message. + std::string current_srtp_cipher; + if (!dtls_->GetDtlsSrtpCipher(¤t_srtp_cipher)) { + LOG(LS_ERROR) << "Failed to get the current SRTP cipher for DTLS channel"; + return false; + } + const std::vector<std::string>::const_iterator iter = + std::find(ciphers.begin(), ciphers.end(), current_srtp_cipher); + if (iter == ciphers.end()) { + std::string requested_str; + for (size_t i = 0; i < ciphers.size(); ++i) { + requested_str.append(" "); + requested_str.append(ciphers[i]); + requested_str.append(" "); + } + LOG(LS_WARNING) << "Ignoring new set of SRTP ciphers, as DTLS " + << "renegotiation is not supported currently " + << "current cipher = " << current_srtp_cipher << " and " + << "requested = " << "[" << requested_str << "]"; + } + return true; + } + + if (!VERIFY(dtls_state() == DTLS_TRANSPORT_NEW)) { + return false; + } + + srtp_ciphers_ = ciphers; + return true; +} + +bool DtlsTransportChannelWrapper::GetSrtpCryptoSuite(std::string* cipher) { + if (dtls_state() != DTLS_TRANSPORT_CONNECTED) { + return false; + } + + return dtls_->GetDtlsSrtpCipher(cipher); +} + + +// Called from upper layers to send a media packet. +int DtlsTransportChannelWrapper::SendPacket( + const char* data, size_t size, + const rtc::PacketOptions& options, int flags) { + if (!dtls_active_) { + // Not doing DTLS. + return channel_->SendPacket(data, size, options); + } + + switch (dtls_state()) { + case DTLS_TRANSPORT_NEW: + // Can't send data until the connection is active. + // TODO(ekr@rtfm.com): assert here if dtls_ is NULL? + return -1; + case DTLS_TRANSPORT_CONNECTING: + // Can't send data until the connection is active. + return -1; + case DTLS_TRANSPORT_CONNECTED: + if (flags & PF_SRTP_BYPASS) { + ASSERT(!srtp_ciphers_.empty()); + if (!IsRtpPacket(data, size)) { + return -1; + } + + return channel_->SendPacket(data, size, options); + } else { + return (dtls_->WriteAll(data, size, NULL, NULL) == rtc::SR_SUCCESS) + ? static_cast<int>(size) + : -1; + } + case DTLS_TRANSPORT_FAILED: + case DTLS_TRANSPORT_CLOSED: + // Can't send anything when we're closed. + return -1; + default: + ASSERT(false); + return -1; + } +} + +// The state transition logic here is as follows: +// (1) If we're not doing DTLS-SRTP, then the state is just the +// state of the underlying impl() +// (2) If we're doing DTLS-SRTP: +// - Prior to the DTLS handshake, the state is neither receiving nor +// writable +// - When the impl goes writable for the first time we +// start the DTLS handshake +// - Once the DTLS handshake completes, the state is that of the +// impl again +void DtlsTransportChannelWrapper::OnWritableState(TransportChannel* channel) { + ASSERT(rtc::Thread::Current() == worker_thread_); + ASSERT(channel == channel_); + LOG_J(LS_VERBOSE, this) + << "DTLSTransportChannelWrapper: channel writable state changed to " + << channel_->writable(); + + if (!dtls_active_) { + // Not doing DTLS. + // Note: SignalWritableState fired by set_writable. + set_writable(channel_->writable()); + return; + } + + switch (dtls_state()) { + case DTLS_TRANSPORT_NEW: + // This should never fail: + // Because we are operating in a nonblocking mode and all + // incoming packets come in via OnReadPacket(), which rejects + // packets in this state, the incoming queue must be empty. We + // ignore write errors, thus any errors must be because of + // configuration and therefore are our fault. + // Note that in non-debug configurations, failure in + // MaybeStartDtls() changes the state to DTLS_TRANSPORT_FAILED. + VERIFY(MaybeStartDtls()); + break; + case DTLS_TRANSPORT_CONNECTED: + // Note: SignalWritableState fired by set_writable. + set_writable(channel_->writable()); + break; + case DTLS_TRANSPORT_CONNECTING: + // Do nothing. + break; + case DTLS_TRANSPORT_FAILED: + case DTLS_TRANSPORT_CLOSED: + // Should not happen. Do nothing. + break; + } +} + +void DtlsTransportChannelWrapper::OnReceivingState(TransportChannel* channel) { + ASSERT(rtc::Thread::Current() == worker_thread_); + ASSERT(channel == channel_); + LOG_J(LS_VERBOSE, this) + << "DTLSTransportChannelWrapper: channel receiving state changed to " + << channel_->receiving(); + if (!dtls_active_ || dtls_state() == DTLS_TRANSPORT_CONNECTED) { + // Note: SignalReceivingState fired by set_receiving. + set_receiving(channel_->receiving()); + } +} + +void DtlsTransportChannelWrapper::OnReadPacket( + TransportChannel* channel, const char* data, size_t size, + const rtc::PacketTime& packet_time, int flags) { + ASSERT(rtc::Thread::Current() == worker_thread_); + ASSERT(channel == channel_); + ASSERT(flags == 0); + + if (!dtls_active_) { + // Not doing DTLS. + SignalReadPacket(this, data, size, packet_time, 0); + return; + } + + switch (dtls_state()) { + case DTLS_TRANSPORT_NEW: + if (dtls_) { + // Drop packets received before DTLS has actually started. + LOG_J(LS_INFO, this) << "Dropping packet received before DTLS started."; + } else { + // Currently drop the packet, but we might in future + // decide to take this as evidence that the other + // side is ready to do DTLS and start the handshake + // on our end. + LOG_J(LS_WARNING, this) << "Received packet before we know if we are " + << "doing DTLS or not; dropping."; + } + break; + + case DTLS_TRANSPORT_CONNECTING: + case DTLS_TRANSPORT_CONNECTED: + // We should only get DTLS or SRTP packets; STUN's already been demuxed. + // Is this potentially a DTLS packet? + if (IsDtlsPacket(data, size)) { + if (!HandleDtlsPacket(data, size)) { + LOG_J(LS_ERROR, this) << "Failed to handle DTLS packet."; + return; + } + } else { + // Not a DTLS packet; our handshake should be complete by now. + if (dtls_state() != DTLS_TRANSPORT_CONNECTED) { + LOG_J(LS_ERROR, this) << "Received non-DTLS packet before DTLS " + << "complete."; + return; + } + + // And it had better be a SRTP packet. + if (!IsRtpPacket(data, size)) { + LOG_J(LS_ERROR, this) << "Received unexpected non-DTLS packet."; + return; + } + + // Sanity check. + ASSERT(!srtp_ciphers_.empty()); + + // Signal this upwards as a bypass packet. + SignalReadPacket(this, data, size, packet_time, PF_SRTP_BYPASS); + } + break; + case DTLS_TRANSPORT_FAILED: + case DTLS_TRANSPORT_CLOSED: + // This shouldn't be happening. Drop the packet. + break; + } +} + +void DtlsTransportChannelWrapper::OnSentPacket( + TransportChannel* channel, + const rtc::SentPacket& sent_packet) { + ASSERT(rtc::Thread::Current() == worker_thread_); + + SignalSentPacket(this, sent_packet); +} + +void DtlsTransportChannelWrapper::OnReadyToSend(TransportChannel* channel) { + if (writable()) { + SignalReadyToSend(this); + } +} + +void DtlsTransportChannelWrapper::OnDtlsEvent(rtc::StreamInterface* dtls, + int sig, int err) { + ASSERT(rtc::Thread::Current() == worker_thread_); + ASSERT(dtls == dtls_.get()); + if (sig & rtc::SE_OPEN) { + // This is the first time. + LOG_J(LS_INFO, this) << "DTLS handshake complete."; + if (dtls_->GetState() == rtc::SS_OPEN) { + // The check for OPEN shouldn't be necessary but let's make + // sure we don't accidentally frob the state if it's closed. + set_dtls_state(DTLS_TRANSPORT_CONNECTED); + set_writable(true); + } + } + if (sig & rtc::SE_READ) { + char buf[kMaxDtlsPacketLen]; + size_t read; + if (dtls_->Read(buf, sizeof(buf), &read, NULL) == rtc::SR_SUCCESS) { + SignalReadPacket(this, buf, read, rtc::CreatePacketTime(0), 0); + } + } + if (sig & rtc::SE_CLOSE) { + ASSERT(sig == rtc::SE_CLOSE); // SE_CLOSE should be by itself. + set_writable(false); + if (!err) { + LOG_J(LS_INFO, this) << "DTLS channel closed"; + set_dtls_state(DTLS_TRANSPORT_CLOSED); + } else { + LOG_J(LS_INFO, this) << "DTLS channel error, code=" << err; + set_dtls_state(DTLS_TRANSPORT_FAILED); + } + } +} + +bool DtlsTransportChannelWrapper::MaybeStartDtls() { + if (dtls_ && channel_->writable()) { + if (dtls_->StartSSLWithPeer()) { + LOG_J(LS_ERROR, this) << "Couldn't start DTLS handshake"; + set_dtls_state(DTLS_TRANSPORT_FAILED); + return false; + } + LOG_J(LS_INFO, this) + << "DtlsTransportChannelWrapper: Started DTLS handshake"; + set_dtls_state(DTLS_TRANSPORT_CONNECTING); + } + return true; +} + +// Called from OnReadPacket when a DTLS packet is received. +bool DtlsTransportChannelWrapper::HandleDtlsPacket(const char* data, + size_t size) { + // Sanity check we're not passing junk that + // just looks like DTLS. + const uint8_t* tmp_data = reinterpret_cast<const uint8_t*>(data); + size_t tmp_size = size; + while (tmp_size > 0) { + if (tmp_size < kDtlsRecordHeaderLen) + return false; // Too short for the header + + size_t record_len = (tmp_data[11] << 8) | (tmp_data[12]); + if ((record_len + kDtlsRecordHeaderLen) > tmp_size) + return false; // Body too short + + tmp_data += record_len + kDtlsRecordHeaderLen; + tmp_size -= record_len + kDtlsRecordHeaderLen; + } + + // Looks good. Pass to the SIC which ends up being passed to + // the DTLS stack. + return downward_->OnPacketReceived(data, size); +} + +void DtlsTransportChannelWrapper::OnGatheringState( + TransportChannelImpl* channel) { + ASSERT(channel == channel_); + SignalGatheringState(this); +} + +void DtlsTransportChannelWrapper::OnCandidateGathered( + TransportChannelImpl* channel, + const Candidate& c) { + ASSERT(channel == channel_); + SignalCandidateGathered(this, c); +} + +void DtlsTransportChannelWrapper::OnRoleConflict( + TransportChannelImpl* channel) { + ASSERT(channel == channel_); + SignalRoleConflict(this); +} + +void DtlsTransportChannelWrapper::OnRouteChange( + TransportChannel* channel, const Candidate& candidate) { + ASSERT(channel == channel_); + SignalRouteChange(this, candidate); +} + +void DtlsTransportChannelWrapper::OnConnectionRemoved( + TransportChannelImpl* channel) { + ASSERT(channel == channel_); + SignalConnectionRemoved(this); +} + +} // namespace cricket diff --git a/webrtc/p2p/base/dtlstransportchannel.h b/webrtc/p2p/base/dtlstransportchannel.h new file mode 100644 index 0000000000..41e081b7fe --- /dev/null +++ b/webrtc/p2p/base/dtlstransportchannel.h @@ -0,0 +1,239 @@ +/* + * Copyright 2011 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_DTLSTRANSPORTCHANNEL_H_ +#define WEBRTC_P2P_BASE_DTLSTRANSPORTCHANNEL_H_ + +#include <string> +#include <vector> + +#include "webrtc/p2p/base/transportchannelimpl.h" +#include "webrtc/base/buffer.h" +#include "webrtc/base/bufferqueue.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/sslstreamadapter.h" +#include "webrtc/base/stream.h" + +namespace cricket { + +// A bridge between a packet-oriented/channel-type interface on +// the bottom and a StreamInterface on the top. +class StreamInterfaceChannel : public rtc::StreamInterface { + public: + explicit StreamInterfaceChannel(TransportChannel* channel); + + // Push in a packet; this gets pulled out from Read(). + bool OnPacketReceived(const char* data, size_t size); + + // Implementations of StreamInterface + rtc::StreamState GetState() const override { return state_; } + void Close() override { state_ = rtc::SS_CLOSED; } + rtc::StreamResult Read(void* buffer, + size_t buffer_len, + size_t* read, + int* error) override; + rtc::StreamResult Write(const void* data, + size_t data_len, + size_t* written, + int* error) override; + + private: + TransportChannel* channel_; // owned by DtlsTransportChannelWrapper + rtc::StreamState state_; + rtc::BufferQueue packets_; + + RTC_DISALLOW_COPY_AND_ASSIGN(StreamInterfaceChannel); +}; + + +// This class provides a DTLS SSLStreamAdapter inside a TransportChannel-style +// packet-based interface, wrapping an existing TransportChannel instance +// (e.g a P2PTransportChannel) +// Here's the way this works: +// +// DtlsTransportChannelWrapper { +// SSLStreamAdapter* dtls_ { +// StreamInterfaceChannel downward_ { +// TransportChannelImpl* channel_; +// } +// } +// } +// +// - Data which comes into DtlsTransportChannelWrapper from the underlying +// channel_ via OnReadPacket() is checked for whether it is DTLS +// or not, and if it is, is passed to DtlsTransportChannelWrapper:: +// HandleDtlsPacket, which pushes it into to downward_. +// dtls_ is listening for events on downward_, so it immediately calls +// downward_->Read(). +// +// - Data written to DtlsTransportChannelWrapper is passed either to +// downward_ or directly to channel_, depending on whether DTLS is +// negotiated and whether the flags include PF_SRTP_BYPASS +// +// - The SSLStreamAdapter writes to downward_->Write() +// which translates it into packet writes on channel_. +class DtlsTransportChannelWrapper : public TransportChannelImpl { + public: + // The parameters here are: + // transport -- the DtlsTransport that created us + // channel -- the TransportChannel we are wrapping + DtlsTransportChannelWrapper(Transport* transport, + TransportChannelImpl* channel); + ~DtlsTransportChannelWrapper() override; + + void SetIceRole(IceRole role) override { channel_->SetIceRole(role); } + IceRole GetIceRole() const override { return channel_->GetIceRole(); } + bool SetLocalCertificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override; + rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const override; + + bool SetRemoteFingerprint(const std::string& digest_alg, + const uint8_t* digest, + size_t digest_len) override; + + // Returns false if no local certificate was set, or if the peer doesn't + // support DTLS. + bool IsDtlsActive() const override { return dtls_active_; } + + // Called to send a packet (via DTLS, if turned on). + int SendPacket(const char* data, + size_t size, + const rtc::PacketOptions& options, + int flags) override; + + // TransportChannel calls that we forward to the wrapped transport. + int SetOption(rtc::Socket::Option opt, int value) override { + return channel_->SetOption(opt, value); + } + bool GetOption(rtc::Socket::Option opt, int* value) override { + return channel_->GetOption(opt, value); + } + int GetError() override { return channel_->GetError(); } + bool GetStats(ConnectionInfos* infos) override { + return channel_->GetStats(infos); + } + const std::string SessionId() const override { return channel_->SessionId(); } + + virtual bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version); + + // Set up the ciphers to use for DTLS-SRTP. If this method is not called + // before DTLS starts, or |ciphers| is empty, SRTP keys won't be negotiated. + // This method should be called before SetupDtls. + bool SetSrtpCiphers(const std::vector<std::string>& ciphers) override; + + // Find out which DTLS-SRTP cipher was negotiated + bool GetSrtpCryptoSuite(std::string* cipher) override; + + bool GetSslRole(rtc::SSLRole* role) const override; + bool SetSslRole(rtc::SSLRole role) override; + + // Find out which DTLS cipher was negotiated + bool GetSslCipherSuite(int* cipher) override; + + // Once DTLS has been established, this method retrieves the certificate in + // use by the remote peer, for use in external identity verification. + bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override; + + // Once DTLS has established (i.e., this channel is writable), this method + // extracts the keys negotiated during the DTLS handshake, for use in external + // encryption. DTLS-SRTP uses this to extract the needed SRTP keys. + // See the SSLStreamAdapter documentation for info on the specific parameters. + bool ExportKeyingMaterial(const std::string& label, + const uint8_t* context, + size_t context_len, + bool use_context, + uint8_t* result, + size_t result_len) override { + return (dtls_.get()) ? dtls_->ExportKeyingMaterial(label, context, + context_len, + use_context, + result, result_len) + : false; + } + + // TransportChannelImpl calls. + Transport* GetTransport() override { return transport_; } + + TransportChannelState GetState() const override { + return channel_->GetState(); + } + void SetIceTiebreaker(uint64_t tiebreaker) override { + channel_->SetIceTiebreaker(tiebreaker); + } + void SetIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) override { + channel_->SetIceCredentials(ice_ufrag, ice_pwd); + } + void SetRemoteIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) override { + channel_->SetRemoteIceCredentials(ice_ufrag, ice_pwd); + } + void SetRemoteIceMode(IceMode mode) override { + channel_->SetRemoteIceMode(mode); + } + + void Connect() override; + + void MaybeStartGathering() override { channel_->MaybeStartGathering(); } + + IceGatheringState gathering_state() const override { + return channel_->gathering_state(); + } + + void AddRemoteCandidate(const Candidate& candidate) override { + channel_->AddRemoteCandidate(candidate); + } + + void SetIceConfig(const IceConfig& config) override { + channel_->SetIceConfig(config); + } + + // Needed by DtlsTransport. + TransportChannelImpl* channel() { return channel_; } + + private: + void OnReadableState(TransportChannel* channel); + void OnWritableState(TransportChannel* channel); + void OnReadPacket(TransportChannel* channel, const char* data, size_t size, + const rtc::PacketTime& packet_time, int flags); + void OnSentPacket(TransportChannel* channel, + const rtc::SentPacket& sent_packet); + void OnReadyToSend(TransportChannel* channel); + void OnReceivingState(TransportChannel* channel); + void OnDtlsEvent(rtc::StreamInterface* stream_, int sig, int err); + bool SetupDtls(); + bool MaybeStartDtls(); + bool HandleDtlsPacket(const char* data, size_t size); + void OnGatheringState(TransportChannelImpl* channel); + void OnCandidateGathered(TransportChannelImpl* channel, const Candidate& c); + void OnRoleConflict(TransportChannelImpl* channel); + void OnRouteChange(TransportChannel* channel, const Candidate& candidate); + void OnConnectionRemoved(TransportChannelImpl* channel); + + Transport* transport_; // The transport_ that created us. + rtc::Thread* worker_thread_; // Everything should occur on this thread. + // Underlying channel, owned by transport_. + TransportChannelImpl* const channel_; + rtc::scoped_ptr<rtc::SSLStreamAdapter> dtls_; // The DTLS stream + StreamInterfaceChannel* downward_; // Wrapper for channel_, owned by dtls_. + std::vector<std::string> srtp_ciphers_; // SRTP ciphers to use with DTLS. + bool dtls_active_ = false; + rtc::scoped_refptr<rtc::RTCCertificate> local_certificate_; + rtc::SSLRole ssl_role_; + rtc::SSLProtocolVersion ssl_max_version_; + rtc::Buffer remote_fingerprint_value_; + std::string remote_fingerprint_algorithm_; + + RTC_DISALLOW_COPY_AND_ASSIGN(DtlsTransportChannelWrapper); +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_DTLSTRANSPORTCHANNEL_H_ diff --git a/webrtc/p2p/base/dtlstransportchannel_unittest.cc b/webrtc/p2p/base/dtlstransportchannel_unittest.cc new file mode 100644 index 0000000000..07e3b87847 --- /dev/null +++ b/webrtc/p2p/base/dtlstransportchannel_unittest.cc @@ -0,0 +1,894 @@ +/* + * Copyright 2011 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <set> + +#include "webrtc/p2p/base/dtlstransport.h" +#include "webrtc/p2p/base/faketransportcontroller.h" +#include "webrtc/base/common.h" +#include "webrtc/base/dscp.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/ssladapter.h" +#include "webrtc/base/sslidentity.h" +#include "webrtc/base/sslstreamadapter.h" +#include "webrtc/base/stringutils.h" + +#define MAYBE_SKIP_TEST(feature) \ + if (!(rtc::SSLStreamAdapter::feature())) { \ + LOG(LS_INFO) << "Feature disabled... skipping"; \ + return; \ + } + +static const char AES_CM_128_HMAC_SHA1_80[] = "AES_CM_128_HMAC_SHA1_80"; +static const char kIceUfrag1[] = "TESTICEUFRAG0001"; +static const char kIcePwd1[] = "TESTICEPWD00000000000001"; +static const size_t kPacketNumOffset = 8; +static const size_t kPacketHeaderLen = 12; +static const int kFakePacketId = 0x1234; + +static bool IsRtpLeadByte(uint8_t b) { + return ((b & 0xC0) == 0x80); +} + +using cricket::ConnectionRole; + +enum Flags { NF_REOFFER = 0x1, NF_EXPECT_FAILURE = 0x2 }; + +class DtlsTestClient : public sigslot::has_slots<> { + public: + DtlsTestClient(const std::string& name) + : name_(name), + packet_size_(0), + use_dtls_srtp_(false), + ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_10), + negotiated_dtls_(false), + received_dtls_client_hello_(false), + received_dtls_server_hello_(false) {} + void CreateCertificate(rtc::KeyType key_type) { + certificate_ = rtc::RTCCertificate::Create( + rtc::scoped_ptr<rtc::SSLIdentity>( + rtc::SSLIdentity::Generate(name_, key_type)).Pass()); + } + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate() { + return certificate_; + } + void SetupSrtp() { + ASSERT(certificate_); + use_dtls_srtp_ = true; + } + void SetupMaxProtocolVersion(rtc::SSLProtocolVersion version) { + ASSERT(!transport_); + ssl_max_version_ = version; + } + void SetupChannels(int count, cricket::IceRole role) { + transport_.reset(new cricket::DtlsTransport<cricket::FakeTransport>( + "dtls content name", nullptr, certificate_)); + transport_->SetAsync(true); + transport_->SetIceRole(role); + transport_->SetIceTiebreaker( + (role == cricket::ICEROLE_CONTROLLING) ? 1 : 2); + + for (int i = 0; i < count; ++i) { + cricket::DtlsTransportChannelWrapper* channel = + static_cast<cricket::DtlsTransportChannelWrapper*>( + transport_->CreateChannel(i)); + ASSERT_TRUE(channel != NULL); + channel->SetSslMaxProtocolVersion(ssl_max_version_); + channel->SignalWritableState.connect(this, + &DtlsTestClient::OnTransportChannelWritableState); + channel->SignalReadPacket.connect(this, + &DtlsTestClient::OnTransportChannelReadPacket); + channel->SignalSentPacket.connect( + this, &DtlsTestClient::OnTransportChannelSentPacket); + channels_.push_back(channel); + + // Hook the raw packets so that we can verify they are encrypted. + channel->channel()->SignalReadPacket.connect( + this, &DtlsTestClient::OnFakeTransportChannelReadPacket); + } + } + + cricket::Transport* transport() { return transport_.get(); } + + cricket::FakeTransportChannel* GetFakeChannel(int component) { + cricket::TransportChannelImpl* ch = transport_->GetChannel(component); + cricket::DtlsTransportChannelWrapper* wrapper = + static_cast<cricket::DtlsTransportChannelWrapper*>(ch); + return (wrapper) ? + static_cast<cricket::FakeTransportChannel*>(wrapper->channel()) : NULL; + } + + // Offer DTLS if we have an identity; pass in a remote fingerprint only if + // both sides support DTLS. + void Negotiate(DtlsTestClient* peer, cricket::ContentAction action, + ConnectionRole local_role, ConnectionRole remote_role, + int flags) { + Negotiate(certificate_, certificate_ ? peer->certificate_ : nullptr, action, + local_role, remote_role, flags); + } + + // Allow any DTLS configuration to be specified (including invalid ones). + void Negotiate(const rtc::scoped_refptr<rtc::RTCCertificate>& local_cert, + const rtc::scoped_refptr<rtc::RTCCertificate>& remote_cert, + cricket::ContentAction action, + ConnectionRole local_role, + ConnectionRole remote_role, + int flags) { + rtc::scoped_ptr<rtc::SSLFingerprint> local_fingerprint; + rtc::scoped_ptr<rtc::SSLFingerprint> remote_fingerprint; + if (local_cert) { + std::string digest_algorithm; + ASSERT_TRUE(local_cert->ssl_certificate().GetSignatureDigestAlgorithm( + &digest_algorithm)); + ASSERT_FALSE(digest_algorithm.empty()); + local_fingerprint.reset(rtc::SSLFingerprint::Create( + digest_algorithm, local_cert->identity())); + ASSERT_TRUE(local_fingerprint.get() != NULL); + EXPECT_EQ(rtc::DIGEST_SHA_256, digest_algorithm); + } + if (remote_cert) { + std::string digest_algorithm; + ASSERT_TRUE(remote_cert->ssl_certificate().GetSignatureDigestAlgorithm( + &digest_algorithm)); + ASSERT_FALSE(digest_algorithm.empty()); + remote_fingerprint.reset(rtc::SSLFingerprint::Create( + digest_algorithm, remote_cert->identity())); + ASSERT_TRUE(remote_fingerprint.get() != NULL); + EXPECT_EQ(rtc::DIGEST_SHA_256, digest_algorithm); + } + + if (use_dtls_srtp_ && !(flags & NF_REOFFER)) { + // SRTP ciphers will be set only in the beginning. + for (std::vector<cricket::DtlsTransportChannelWrapper*>::iterator it = + channels_.begin(); it != channels_.end(); ++it) { + std::vector<std::string> ciphers; + ciphers.push_back(AES_CM_128_HMAC_SHA1_80); + ASSERT_TRUE((*it)->SetSrtpCiphers(ciphers)); + } + } + + cricket::TransportDescription local_desc( + std::vector<std::string>(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_FULL, + local_role, + // If remote if the offerer and has no DTLS support, answer will be + // without any fingerprint. + (action == cricket::CA_ANSWER && !remote_cert) + ? nullptr + : local_fingerprint.get(), + cricket::Candidates()); + + cricket::TransportDescription remote_desc( + std::vector<std::string>(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_FULL, + remote_role, remote_fingerprint.get(), cricket::Candidates()); + + bool expect_success = (flags & NF_EXPECT_FAILURE) ? false : true; + // If |expect_success| is false, expect SRTD or SLTD to fail when + // content action is CA_ANSWER. + if (action == cricket::CA_OFFER) { + ASSERT_TRUE(transport_->SetLocalTransportDescription( + local_desc, cricket::CA_OFFER, NULL)); + ASSERT_EQ(expect_success, transport_->SetRemoteTransportDescription( + remote_desc, cricket::CA_ANSWER, NULL)); + } else { + ASSERT_TRUE(transport_->SetRemoteTransportDescription( + remote_desc, cricket::CA_OFFER, NULL)); + ASSERT_EQ(expect_success, transport_->SetLocalTransportDescription( + local_desc, cricket::CA_ANSWER, NULL)); + } + negotiated_dtls_ = (local_cert && remote_cert); + } + + bool Connect(DtlsTestClient* peer) { + transport_->ConnectChannels(); + transport_->SetDestination(peer->transport_.get()); + return true; + } + + bool all_channels_writable() const { + if (channels_.empty()) { + return false; + } + for (cricket::DtlsTransportChannelWrapper* channel : channels_) { + if (!channel->writable()) { + return false; + } + } + return true; + } + + void CheckRole(rtc::SSLRole role) { + if (role == rtc::SSL_CLIENT) { + ASSERT_FALSE(received_dtls_client_hello_); + ASSERT_TRUE(received_dtls_server_hello_); + } else { + ASSERT_TRUE(received_dtls_client_hello_); + ASSERT_FALSE(received_dtls_server_hello_); + } + } + + void CheckSrtp(const std::string& expected_cipher) { + for (std::vector<cricket::DtlsTransportChannelWrapper*>::iterator it = + channels_.begin(); it != channels_.end(); ++it) { + std::string cipher; + + bool rv = (*it)->GetSrtpCryptoSuite(&cipher); + if (negotiated_dtls_ && !expected_cipher.empty()) { + ASSERT_TRUE(rv); + + ASSERT_EQ(cipher, expected_cipher); + } else { + ASSERT_FALSE(rv); + } + } + } + + void CheckSsl(int expected_cipher) { + for (std::vector<cricket::DtlsTransportChannelWrapper*>::iterator it = + channels_.begin(); it != channels_.end(); ++it) { + int cipher; + + bool rv = (*it)->GetSslCipherSuite(&cipher); + if (negotiated_dtls_ && expected_cipher) { + ASSERT_TRUE(rv); + + ASSERT_EQ(cipher, expected_cipher); + } else { + ASSERT_FALSE(rv); + } + } + } + + void SendPackets(size_t channel, size_t size, size_t count, bool srtp) { + ASSERT(channel < channels_.size()); + rtc::scoped_ptr<char[]> packet(new char[size]); + size_t sent = 0; + do { + // Fill the packet with a known value and a sequence number to check + // against, and make sure that it doesn't look like DTLS. + memset(packet.get(), sent & 0xff, size); + packet[0] = (srtp) ? 0x80 : 0x00; + rtc::SetBE32(packet.get() + kPacketNumOffset, + static_cast<uint32_t>(sent)); + + // Only set the bypass flag if we've activated DTLS. + int flags = (certificate_ && srtp) ? cricket::PF_SRTP_BYPASS : 0; + rtc::PacketOptions packet_options; + packet_options.packet_id = kFakePacketId; + int rv = channels_[channel]->SendPacket( + packet.get(), size, packet_options, flags); + ASSERT_GT(rv, 0); + ASSERT_EQ(size, static_cast<size_t>(rv)); + ++sent; + } while (sent < count); + } + + int SendInvalidSrtpPacket(size_t channel, size_t size) { + ASSERT(channel < channels_.size()); + rtc::scoped_ptr<char[]> packet(new char[size]); + // Fill the packet with 0 to form an invalid SRTP packet. + memset(packet.get(), 0, size); + + rtc::PacketOptions packet_options; + return channels_[channel]->SendPacket( + packet.get(), size, packet_options, cricket::PF_SRTP_BYPASS); + } + + void ExpectPackets(size_t channel, size_t size) { + packet_size_ = size; + received_.clear(); + } + + size_t NumPacketsReceived() { + return received_.size(); + } + + bool VerifyPacket(const char* data, size_t size, uint32_t* out_num) { + if (size != packet_size_ || + (data[0] != 0 && static_cast<uint8_t>(data[0]) != 0x80)) { + return false; + } + uint32_t packet_num = rtc::GetBE32(data + kPacketNumOffset); + for (size_t i = kPacketHeaderLen; i < size; ++i) { + if (static_cast<uint8_t>(data[i]) != (packet_num & 0xff)) { + return false; + } + } + if (out_num) { + *out_num = packet_num; + } + return true; + } + bool VerifyEncryptedPacket(const char* data, size_t size) { + // This is an encrypted data packet; let's make sure it's mostly random; + // less than 10% of the bytes should be equal to the cleartext packet. + if (size <= packet_size_) { + return false; + } + uint32_t packet_num = rtc::GetBE32(data + kPacketNumOffset); + int num_matches = 0; + for (size_t i = kPacketNumOffset; i < size; ++i) { + if (static_cast<uint8_t>(data[i]) == (packet_num & 0xff)) { + ++num_matches; + } + } + return (num_matches < ((static_cast<int>(size) - 5) / 10)); + } + + // Transport channel callbacks + void OnTransportChannelWritableState(cricket::TransportChannel* channel) { + LOG(LS_INFO) << name_ << ": Channel '" << channel->component() + << "' is writable"; + } + + void OnTransportChannelReadPacket(cricket::TransportChannel* channel, + const char* data, size_t size, + const rtc::PacketTime& packet_time, + int flags) { + uint32_t packet_num = 0; + ASSERT_TRUE(VerifyPacket(data, size, &packet_num)); + received_.insert(packet_num); + // Only DTLS-SRTP packets should have the bypass flag set. + int expected_flags = + (certificate_ && IsRtpLeadByte(data[0])) ? cricket::PF_SRTP_BYPASS : 0; + ASSERT_EQ(expected_flags, flags); + } + + void OnTransportChannelSentPacket(cricket::TransportChannel* channel, + const rtc::SentPacket& sent_packet) { + sent_packet_ = sent_packet; + } + + rtc::SentPacket sent_packet() const { return sent_packet_; } + + // Hook into the raw packet stream to make sure DTLS packets are encrypted. + void OnFakeTransportChannelReadPacket(cricket::TransportChannel* channel, + const char* data, size_t size, + const rtc::PacketTime& time, + int flags) { + // Flags shouldn't be set on the underlying TransportChannel packets. + ASSERT_EQ(0, flags); + + // Look at the handshake packets to see what role we played. + // Check that non-handshake packets are DTLS data or SRTP bypass. + if (negotiated_dtls_) { + if (data[0] == 22 && size > 17) { + if (data[13] == 1) { + received_dtls_client_hello_ = true; + } else if (data[13] == 2) { + received_dtls_server_hello_ = true; + } + } else if (!(data[0] >= 20 && data[0] <= 22)) { + ASSERT_TRUE(data[0] == 23 || IsRtpLeadByte(data[0])); + if (data[0] == 23) { + ASSERT_TRUE(VerifyEncryptedPacket(data, size)); + } else if (IsRtpLeadByte(data[0])) { + ASSERT_TRUE(VerifyPacket(data, size, NULL)); + } + } + } + } + + private: + std::string name_; + rtc::scoped_refptr<rtc::RTCCertificate> certificate_; + rtc::scoped_ptr<cricket::FakeTransport> transport_; + std::vector<cricket::DtlsTransportChannelWrapper*> channels_; + size_t packet_size_; + std::set<int> received_; + bool use_dtls_srtp_; + rtc::SSLProtocolVersion ssl_max_version_; + bool negotiated_dtls_; + bool received_dtls_client_hello_; + bool received_dtls_server_hello_; + rtc::SentPacket sent_packet_; +}; + + +class DtlsTransportChannelTest : public testing::Test { + public: + DtlsTransportChannelTest() + : client1_("P1"), + client2_("P2"), + channel_ct_(1), + use_dtls_(false), + use_dtls_srtp_(false), + ssl_expected_version_(rtc::SSL_PROTOCOL_DTLS_10) {} + + void SetChannelCount(size_t channel_ct) { + channel_ct_ = static_cast<int>(channel_ct); + } + void SetMaxProtocolVersions(rtc::SSLProtocolVersion c1, + rtc::SSLProtocolVersion c2) { + client1_.SetupMaxProtocolVersion(c1); + client2_.SetupMaxProtocolVersion(c2); + ssl_expected_version_ = std::min(c1, c2); + } + void PrepareDtls(bool c1, bool c2, rtc::KeyType key_type) { + if (c1) { + client1_.CreateCertificate(key_type); + } + if (c2) { + client2_.CreateCertificate(key_type); + } + if (c1 && c2) + use_dtls_ = true; + } + void PrepareDtlsSrtp(bool c1, bool c2) { + if (!use_dtls_) + return; + + if (c1) + client1_.SetupSrtp(); + if (c2) + client2_.SetupSrtp(); + + if (c1 && c2) + use_dtls_srtp_ = true; + } + + bool Connect(ConnectionRole client1_role, ConnectionRole client2_role) { + Negotiate(client1_role, client2_role); + + bool rv = client1_.Connect(&client2_); + EXPECT_TRUE(rv); + if (!rv) + return false; + + EXPECT_TRUE_WAIT( + client1_.all_channels_writable() && client2_.all_channels_writable(), + 10000); + if (!client1_.all_channels_writable() || !client2_.all_channels_writable()) + return false; + + // Check that we used the right roles. + if (use_dtls_) { + rtc::SSLRole client1_ssl_role = + (client1_role == cricket::CONNECTIONROLE_ACTIVE || + (client2_role == cricket::CONNECTIONROLE_PASSIVE && + client1_role == cricket::CONNECTIONROLE_ACTPASS)) ? + rtc::SSL_CLIENT : rtc::SSL_SERVER; + + rtc::SSLRole client2_ssl_role = + (client2_role == cricket::CONNECTIONROLE_ACTIVE || + (client1_role == cricket::CONNECTIONROLE_PASSIVE && + client2_role == cricket::CONNECTIONROLE_ACTPASS)) ? + rtc::SSL_CLIENT : rtc::SSL_SERVER; + + client1_.CheckRole(client1_ssl_role); + client2_.CheckRole(client2_ssl_role); + } + + // Check that we negotiated the right ciphers. + if (use_dtls_srtp_) { + client1_.CheckSrtp(AES_CM_128_HMAC_SHA1_80); + client2_.CheckSrtp(AES_CM_128_HMAC_SHA1_80); + } else { + client1_.CheckSrtp(""); + client2_.CheckSrtp(""); + } + client1_.CheckSsl(rtc::SSLStreamAdapter::GetDefaultSslCipherForTest( + ssl_expected_version_, rtc::KT_DEFAULT)); + client2_.CheckSsl(rtc::SSLStreamAdapter::GetDefaultSslCipherForTest( + ssl_expected_version_, rtc::KT_DEFAULT)); + + return true; + } + + bool Connect() { + // By default, Client1 will be Server and Client2 will be Client. + return Connect(cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_ACTIVE); + } + + void Negotiate() { + Negotiate(cricket::CONNECTIONROLE_ACTPASS, cricket::CONNECTIONROLE_ACTIVE); + } + + void Negotiate(ConnectionRole client1_role, ConnectionRole client2_role) { + client1_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLING); + client2_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLED); + // Expect success from SLTD and SRTD. + client1_.Negotiate(&client2_, cricket::CA_OFFER, + client1_role, client2_role, 0); + client2_.Negotiate(&client1_, cricket::CA_ANSWER, + client2_role, client1_role, 0); + } + + // Negotiate with legacy client |client2|. Legacy client doesn't use setup + // attributes, except NONE. + void NegotiateWithLegacy() { + client1_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLING); + client2_.SetupChannels(channel_ct_, cricket::ICEROLE_CONTROLLED); + // Expect success from SLTD and SRTD. + client1_.Negotiate(&client2_, cricket::CA_OFFER, + cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_NONE, 0); + client2_.Negotiate(&client1_, cricket::CA_ANSWER, + cricket::CONNECTIONROLE_ACTIVE, + cricket::CONNECTIONROLE_NONE, 0); + } + + void Renegotiate(DtlsTestClient* reoffer_initiator, + ConnectionRole client1_role, ConnectionRole client2_role, + int flags) { + if (reoffer_initiator == &client1_) { + client1_.Negotiate(&client2_, cricket::CA_OFFER, + client1_role, client2_role, flags); + client2_.Negotiate(&client1_, cricket::CA_ANSWER, + client2_role, client1_role, flags); + } else { + client2_.Negotiate(&client1_, cricket::CA_OFFER, + client2_role, client1_role, flags); + client1_.Negotiate(&client2_, cricket::CA_ANSWER, + client1_role, client2_role, flags); + } + } + + void TestTransfer(size_t channel, size_t size, size_t count, bool srtp) { + LOG(LS_INFO) << "Expect packets, size=" << size; + client2_.ExpectPackets(channel, size); + client1_.SendPackets(channel, size, count, srtp); + EXPECT_EQ_WAIT(count, client2_.NumPacketsReceived(), 10000); + } + + protected: + DtlsTestClient client1_; + DtlsTestClient client2_; + int channel_ct_; + bool use_dtls_; + bool use_dtls_srtp_; + rtc::SSLProtocolVersion ssl_expected_version_; +}; + +// Test that transport negotiation of ICE, no DTLS works properly. +TEST_F(DtlsTransportChannelTest, TestChannelSetupIce) { + Negotiate(); + cricket::FakeTransportChannel* channel1 = client1_.GetFakeChannel(0); + cricket::FakeTransportChannel* channel2 = client2_.GetFakeChannel(0); + ASSERT_TRUE(channel1 != NULL); + ASSERT_TRUE(channel2 != NULL); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel1->GetIceRole()); + EXPECT_EQ(1U, channel1->IceTiebreaker()); + EXPECT_EQ(kIceUfrag1, channel1->ice_ufrag()); + EXPECT_EQ(kIcePwd1, channel1->ice_pwd()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, channel2->GetIceRole()); + EXPECT_EQ(2U, channel2->IceTiebreaker()); +} + +// Connect without DTLS, and transfer some data. +TEST_F(DtlsTransportChannelTest, TestTransfer) { + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, false); +} + +// Connect without DTLS, and transfer some data. +TEST_F(DtlsTransportChannelTest, TestOnSentPacket) { + ASSERT_TRUE(Connect()); + EXPECT_EQ(client1_.sent_packet().send_time_ms, -1); + TestTransfer(0, 1000, 100, false); + EXPECT_EQ(kFakePacketId, client1_.sent_packet().packet_id); + EXPECT_GE(client1_.sent_packet().send_time_ms, 0); +} + +// Create two channels without DTLS, and transfer some data. +TEST_F(DtlsTransportChannelTest, TestTransferTwoChannels) { + SetChannelCount(2); + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, false); + TestTransfer(1, 1000, 100, false); +} + +// Connect without DTLS, and transfer SRTP data. +TEST_F(DtlsTransportChannelTest, TestTransferSrtp) { + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, true); +} + +// Create two channels without DTLS, and transfer SRTP data. +TEST_F(DtlsTransportChannelTest, TestTransferSrtpTwoChannels) { + SetChannelCount(2); + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, true); + TestTransfer(1, 1000, 100, true); +} + +// Connect with DTLS, and transfer some data. +TEST_F(DtlsTransportChannelTest, TestTransferDtls) { + MAYBE_SKIP_TEST(HaveDtls); + PrepareDtls(true, true, rtc::KT_DEFAULT); + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, false); +} + +// Create two channels with DTLS, and transfer some data. +TEST_F(DtlsTransportChannelTest, TestTransferDtlsTwoChannels) { + MAYBE_SKIP_TEST(HaveDtls); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, false); + TestTransfer(1, 1000, 100, false); +} + +// Connect with A doing DTLS and B not, and transfer some data. +TEST_F(DtlsTransportChannelTest, TestTransferDtlsRejected) { + PrepareDtls(true, false, rtc::KT_DEFAULT); + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, false); +} + +// Connect with B doing DTLS and A not, and transfer some data. +TEST_F(DtlsTransportChannelTest, TestTransferDtlsNotOffered) { + PrepareDtls(false, true, rtc::KT_DEFAULT); + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, false); +} + +// Create two channels with DTLS 1.0 and check ciphers. +TEST_F(DtlsTransportChannelTest, TestDtls12None) { + MAYBE_SKIP_TEST(HaveDtls); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + SetMaxProtocolVersions(rtc::SSL_PROTOCOL_DTLS_10, rtc::SSL_PROTOCOL_DTLS_10); + ASSERT_TRUE(Connect()); +} + +// Create two channels with DTLS 1.2 and check ciphers. +TEST_F(DtlsTransportChannelTest, TestDtls12Both) { + MAYBE_SKIP_TEST(HaveDtls); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + SetMaxProtocolVersions(rtc::SSL_PROTOCOL_DTLS_12, rtc::SSL_PROTOCOL_DTLS_12); + ASSERT_TRUE(Connect()); +} + +// Create two channels with DTLS 1.0 / DTLS 1.2 and check ciphers. +TEST_F(DtlsTransportChannelTest, TestDtls12Client1) { + MAYBE_SKIP_TEST(HaveDtls); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + SetMaxProtocolVersions(rtc::SSL_PROTOCOL_DTLS_12, rtc::SSL_PROTOCOL_DTLS_10); + ASSERT_TRUE(Connect()); +} + +// Create two channels with DTLS 1.2 / DTLS 1.0 and check ciphers. +TEST_F(DtlsTransportChannelTest, TestDtls12Client2) { + MAYBE_SKIP_TEST(HaveDtls); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + SetMaxProtocolVersions(rtc::SSL_PROTOCOL_DTLS_10, rtc::SSL_PROTOCOL_DTLS_12); + ASSERT_TRUE(Connect()); +} + +// Connect with DTLS, negotiate DTLS-SRTP, and transfer SRTP using bypass. +TEST_F(DtlsTransportChannelTest, TestTransferDtlsSrtp) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, true); + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, true); +} + +// Connect with DTLS-SRTP, transfer an invalid SRTP packet, and expects -1 +// returned. +TEST_F(DtlsTransportChannelTest, TestTransferDtlsInvalidSrtpPacket) { + MAYBE_SKIP_TEST(HaveDtls); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, true); + ASSERT_TRUE(Connect()); + int result = client1_.SendInvalidSrtpPacket(0, 100); + ASSERT_EQ(-1, result); +} + +// Connect with DTLS. A does DTLS-SRTP but B does not. +TEST_F(DtlsTransportChannelTest, TestTransferDtlsSrtpRejected) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, false); + ASSERT_TRUE(Connect()); +} + +// Connect with DTLS. B does DTLS-SRTP but A does not. +TEST_F(DtlsTransportChannelTest, TestTransferDtlsSrtpNotOffered) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(false, true); + ASSERT_TRUE(Connect()); +} + +// Create two channels with DTLS, negotiate DTLS-SRTP, and transfer bypass SRTP. +TEST_F(DtlsTransportChannelTest, TestTransferDtlsSrtpTwoChannels) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, true); + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, true); + TestTransfer(1, 1000, 100, true); +} + +// Create a single channel with DTLS, and send normal data and SRTP data on it. +TEST_F(DtlsTransportChannelTest, TestTransferDtlsSrtpDemux) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, true); + ASSERT_TRUE(Connect()); + TestTransfer(0, 1000, 100, false); + TestTransfer(0, 1000, 100, true); +} + +// Testing when the remote is passive. +TEST_F(DtlsTransportChannelTest, TestTransferDtlsAnswererIsPassive) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, true); + ASSERT_TRUE(Connect(cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_PASSIVE)); + TestTransfer(0, 1000, 100, true); + TestTransfer(1, 1000, 100, true); +} + +// Testing with the legacy DTLS client which doesn't use setup attribute. +// In this case legacy is the answerer. +TEST_F(DtlsTransportChannelTest, TestDtlsSetupWithLegacyAsAnswerer) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + PrepareDtls(true, true, rtc::KT_DEFAULT); + NegotiateWithLegacy(); + rtc::SSLRole channel1_role; + rtc::SSLRole channel2_role; + EXPECT_TRUE(client1_.transport()->GetSslRole(&channel1_role)); + EXPECT_TRUE(client2_.transport()->GetSslRole(&channel2_role)); + EXPECT_EQ(rtc::SSL_SERVER, channel1_role); + EXPECT_EQ(rtc::SSL_CLIENT, channel2_role); +} + +// Testing re offer/answer after the session is estbalished. Roles will be +// kept same as of the previous negotiation. +TEST_F(DtlsTransportChannelTest, TestDtlsReOfferFromOfferer) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, true); + // Initial role for client1 is ACTPASS and client2 is ACTIVE. + ASSERT_TRUE(Connect(cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_ACTIVE)); + TestTransfer(0, 1000, 100, true); + TestTransfer(1, 1000, 100, true); + // Using input roles for the re-offer. + Renegotiate(&client1_, cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_ACTIVE, NF_REOFFER); + TestTransfer(0, 1000, 100, true); + TestTransfer(1, 1000, 100, true); +} + +TEST_F(DtlsTransportChannelTest, TestDtlsReOfferFromAnswerer) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, true); + // Initial role for client1 is ACTPASS and client2 is ACTIVE. + ASSERT_TRUE(Connect(cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_ACTIVE)); + TestTransfer(0, 1000, 100, true); + TestTransfer(1, 1000, 100, true); + // Using input roles for the re-offer. + Renegotiate(&client2_, cricket::CONNECTIONROLE_PASSIVE, + cricket::CONNECTIONROLE_ACTPASS, NF_REOFFER); + TestTransfer(0, 1000, 100, true); + TestTransfer(1, 1000, 100, true); +} + +// Test that any change in role after the intial setup will result in failure. +TEST_F(DtlsTransportChannelTest, TestDtlsRoleReversal) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, true); + ASSERT_TRUE(Connect(cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_PASSIVE)); + + // Renegotiate from client2 with actpass and client1 as active. + Renegotiate(&client2_, cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_ACTIVE, + NF_REOFFER | NF_EXPECT_FAILURE); +} + +// Test that using different setup attributes which results in similar ssl +// role as the initial negotiation will result in success. +TEST_F(DtlsTransportChannelTest, TestDtlsReOfferWithDifferentSetupAttr) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, true); + ASSERT_TRUE(Connect(cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_PASSIVE)); + // Renegotiate from client2 with actpass and client1 as active. + Renegotiate(&client2_, cricket::CONNECTIONROLE_ACTIVE, + cricket::CONNECTIONROLE_ACTPASS, NF_REOFFER); + TestTransfer(0, 1000, 100, true); + TestTransfer(1, 1000, 100, true); +} + +// Test that re-negotiation can be started before the clients become connected +// in the first negotiation. +TEST_F(DtlsTransportChannelTest, TestRenegotiateBeforeConnect) { + MAYBE_SKIP_TEST(HaveDtlsSrtp); + SetChannelCount(2); + PrepareDtls(true, true, rtc::KT_DEFAULT); + PrepareDtlsSrtp(true, true); + Negotiate(); + + Renegotiate(&client1_, cricket::CONNECTIONROLE_ACTPASS, + cricket::CONNECTIONROLE_ACTIVE, NF_REOFFER); + bool rv = client1_.Connect(&client2_); + EXPECT_TRUE(rv); + EXPECT_TRUE_WAIT( + client1_.all_channels_writable() && client2_.all_channels_writable(), + 10000); + + TestTransfer(0, 1000, 100, true); + TestTransfer(1, 1000, 100, true); +} + +// Test Certificates state after negotiation but before connection. +TEST_F(DtlsTransportChannelTest, TestCertificatesBeforeConnect) { + MAYBE_SKIP_TEST(HaveDtls); + PrepareDtls(true, true, rtc::KT_DEFAULT); + Negotiate(); + + rtc::scoped_refptr<rtc::RTCCertificate> certificate1; + rtc::scoped_refptr<rtc::RTCCertificate> certificate2; + rtc::scoped_ptr<rtc::SSLCertificate> remote_cert1; + rtc::scoped_ptr<rtc::SSLCertificate> remote_cert2; + + // After negotiation, each side has a distinct local certificate, but still no + // remote certificate, because connection has not yet occurred. + ASSERT_TRUE(client1_.transport()->GetLocalCertificate(&certificate1)); + ASSERT_TRUE(client2_.transport()->GetLocalCertificate(&certificate2)); + ASSERT_NE(certificate1->ssl_certificate().ToPEMString(), + certificate2->ssl_certificate().ToPEMString()); + ASSERT_FALSE( + client1_.transport()->GetRemoteSSLCertificate(remote_cert1.accept())); + ASSERT_FALSE(remote_cert1 != NULL); + ASSERT_FALSE( + client2_.transport()->GetRemoteSSLCertificate(remote_cert2.accept())); + ASSERT_FALSE(remote_cert2 != NULL); +} + +// Test Certificates state after connection. +TEST_F(DtlsTransportChannelTest, TestCertificatesAfterConnect) { + MAYBE_SKIP_TEST(HaveDtls); + PrepareDtls(true, true, rtc::KT_DEFAULT); + ASSERT_TRUE(Connect()); + + rtc::scoped_refptr<rtc::RTCCertificate> certificate1; + rtc::scoped_refptr<rtc::RTCCertificate> certificate2; + rtc::scoped_ptr<rtc::SSLCertificate> remote_cert1; + rtc::scoped_ptr<rtc::SSLCertificate> remote_cert2; + + // After connection, each side has a distinct local certificate. + ASSERT_TRUE(client1_.transport()->GetLocalCertificate(&certificate1)); + ASSERT_TRUE(client2_.transport()->GetLocalCertificate(&certificate2)); + ASSERT_NE(certificate1->ssl_certificate().ToPEMString(), + certificate2->ssl_certificate().ToPEMString()); + + // Each side's remote certificate is the other side's local certificate. + ASSERT_TRUE( + client1_.transport()->GetRemoteSSLCertificate(remote_cert1.accept())); + ASSERT_EQ(remote_cert1->ToPEMString(), + certificate2->ssl_certificate().ToPEMString()); + ASSERT_TRUE( + client2_.transport()->GetRemoteSSLCertificate(remote_cert2.accept())); + ASSERT_EQ(remote_cert2->ToPEMString(), + certificate1->ssl_certificate().ToPEMString()); +} diff --git a/webrtc/p2p/base/faketransportcontroller.h b/webrtc/p2p/base/faketransportcontroller.h new file mode 100644 index 0000000000..3e656fa4a3 --- /dev/null +++ b/webrtc/p2p/base/faketransportcontroller.h @@ -0,0 +1,544 @@ +/* + * Copyright 2009 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_ +#define WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_ + +#include <map> +#include <string> +#include <vector> + +#include "webrtc/p2p/base/transport.h" +#include "webrtc/p2p/base/transportchannel.h" +#include "webrtc/p2p/base/transportcontroller.h" +#include "webrtc/p2p/base/transportchannelimpl.h" +#include "webrtc/base/bind.h" +#include "webrtc/base/buffer.h" +#include "webrtc/base/fakesslidentity.h" +#include "webrtc/base/messagequeue.h" +#include "webrtc/base/sigslot.h" +#include "webrtc/base/sslfingerprint.h" +#include "webrtc/base/thread.h" + +namespace cricket { + +class FakeTransport; + +namespace { +struct PacketMessageData : public rtc::MessageData { + PacketMessageData(const char* data, size_t len) : packet(data, len) {} + rtc::Buffer packet; +}; +} // namespace + +// Fake transport channel class, which can be passed to anything that needs a +// transport channel. Can be informed of another FakeTransportChannel via +// SetDestination. +// TODO(hbos): Move implementation to .cc file, this and other classes in file. +class FakeTransportChannel : public TransportChannelImpl, + public rtc::MessageHandler { + public: + explicit FakeTransportChannel(Transport* transport, + const std::string& name, + int component) + : TransportChannelImpl(name, component), + transport_(transport), + dtls_fingerprint_("", nullptr, 0) {} + ~FakeTransportChannel() { Reset(); } + + uint64_t IceTiebreaker() const { return tiebreaker_; } + IceMode remote_ice_mode() const { return remote_ice_mode_; } + const std::string& ice_ufrag() const { return ice_ufrag_; } + const std::string& ice_pwd() const { return ice_pwd_; } + const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; } + const std::string& remote_ice_pwd() const { return remote_ice_pwd_; } + const rtc::SSLFingerprint& dtls_fingerprint() const { + return dtls_fingerprint_; + } + + // If async, will send packets by "Post"-ing to message queue instead of + // synchronously "Send"-ing. + void SetAsync(bool async) { async_ = async; } + + Transport* GetTransport() override { return transport_; } + + TransportChannelState GetState() const override { + if (connection_count_ == 0) { + return had_connection_ ? TransportChannelState::STATE_FAILED + : TransportChannelState::STATE_INIT; + } + + if (connection_count_ == 1) { + return TransportChannelState::STATE_COMPLETED; + } + + return TransportChannelState::STATE_CONNECTING; + } + + void SetIceRole(IceRole role) override { role_ = role; } + IceRole GetIceRole() const override { return role_; } + void SetIceTiebreaker(uint64_t tiebreaker) override { + tiebreaker_ = tiebreaker; + } + void SetIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) override { + ice_ufrag_ = ice_ufrag; + ice_pwd_ = ice_pwd; + } + void SetRemoteIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) override { + remote_ice_ufrag_ = ice_ufrag; + remote_ice_pwd_ = ice_pwd; + } + + void SetRemoteIceMode(IceMode mode) override { remote_ice_mode_ = mode; } + bool SetRemoteFingerprint(const std::string& alg, + const uint8_t* digest, + size_t digest_len) override { + dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len); + return true; + } + bool SetSslRole(rtc::SSLRole role) override { + ssl_role_ = role; + return true; + } + bool GetSslRole(rtc::SSLRole* role) const override { + *role = ssl_role_; + return true; + } + + void Connect() override { + if (state_ == STATE_INIT) { + state_ = STATE_CONNECTING; + } + } + + void MaybeStartGathering() override { + if (gathering_state_ == kIceGatheringNew) { + gathering_state_ = kIceGatheringGathering; + SignalGatheringState(this); + } + } + + IceGatheringState gathering_state() const override { + return gathering_state_; + } + + void Reset() { + if (state_ != STATE_INIT) { + state_ = STATE_INIT; + if (dest_) { + dest_->state_ = STATE_INIT; + dest_->dest_ = nullptr; + dest_ = nullptr; + } + } + } + + void SetWritable(bool writable) { set_writable(writable); } + + void SetDestination(FakeTransportChannel* dest) { + if (state_ == STATE_CONNECTING && dest) { + // This simulates the delivery of candidates. + dest_ = dest; + dest_->dest_ = this; + if (local_cert_ && dest_->local_cert_) { + do_dtls_ = true; + dest_->do_dtls_ = true; + NegotiateSrtpCiphers(); + } + state_ = STATE_CONNECTED; + dest_->state_ = STATE_CONNECTED; + set_writable(true); + dest_->set_writable(true); + } else if (state_ == STATE_CONNECTED && !dest) { + // Simulates loss of connectivity, by asymmetrically forgetting dest_. + dest_ = nullptr; + state_ = STATE_CONNECTING; + set_writable(false); + } + } + + void SetConnectionCount(size_t connection_count) { + size_t old_connection_count = connection_count_; + connection_count_ = connection_count; + if (connection_count) + had_connection_ = true; + if (connection_count_ < old_connection_count) + SignalConnectionRemoved(this); + } + + void SetCandidatesGatheringComplete() { + if (gathering_state_ != kIceGatheringComplete) { + gathering_state_ = kIceGatheringComplete; + SignalGatheringState(this); + } + } + + void SetReceiving(bool receiving) { set_receiving(receiving); } + + void SetIceConfig(const IceConfig& config) override { + receiving_timeout_ = config.receiving_timeout_ms; + gather_continually_ = config.gather_continually; + } + + int receiving_timeout() const { return receiving_timeout_; } + bool gather_continually() const { return gather_continually_; } + + int SendPacket(const char* data, + size_t len, + const rtc::PacketOptions& options, + int flags) override { + if (state_ != STATE_CONNECTED) { + return -1; + } + + if (flags != PF_SRTP_BYPASS && flags != 0) { + return -1; + } + + PacketMessageData* packet = new PacketMessageData(data, len); + if (async_) { + rtc::Thread::Current()->Post(this, 0, packet); + } else { + rtc::Thread::Current()->Send(this, 0, packet); + } + rtc::SentPacket sent_packet(options.packet_id, rtc::Time()); + SignalSentPacket(this, sent_packet); + return static_cast<int>(len); + } + int SetOption(rtc::Socket::Option opt, int value) override { return true; } + bool GetOption(rtc::Socket::Option opt, int* value) override { return true; } + int GetError() override { return 0; } + + void AddRemoteCandidate(const Candidate& candidate) override { + remote_candidates_.push_back(candidate); + } + const Candidates& remote_candidates() const { return remote_candidates_; } + + void OnMessage(rtc::Message* msg) override { + PacketMessageData* data = static_cast<PacketMessageData*>(msg->pdata); + dest_->SignalReadPacket(dest_, data->packet.data<char>(), + data->packet.size(), rtc::CreatePacketTime(0), 0); + delete data; + } + + bool SetLocalCertificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) { + local_cert_ = certificate; + return true; + } + + void SetRemoteSSLCertificate(rtc::FakeSSLCertificate* cert) { + remote_cert_ = cert; + } + + bool IsDtlsActive() const override { return do_dtls_; } + + bool SetSrtpCiphers(const std::vector<std::string>& ciphers) override { + srtp_ciphers_ = ciphers; + return true; + } + + bool GetSrtpCryptoSuite(std::string* cipher) override { + if (!chosen_srtp_cipher_.empty()) { + *cipher = chosen_srtp_cipher_; + return true; + } + return false; + } + + bool GetSslCipherSuite(int* cipher) override { return false; } + + rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const { + return local_cert_; + } + + bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override { + if (!remote_cert_) + return false; + + *cert = remote_cert_->GetReference(); + return true; + } + + bool ExportKeyingMaterial(const std::string& label, + const uint8_t* context, + size_t context_len, + bool use_context, + uint8_t* result, + size_t result_len) override { + if (!chosen_srtp_cipher_.empty()) { + memset(result, 0xff, result_len); + return true; + } + + return false; + } + + void NegotiateSrtpCiphers() { + for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin(); + it1 != srtp_ciphers_.end(); ++it1) { + for (std::vector<std::string>::const_iterator it2 = + dest_->srtp_ciphers_.begin(); + it2 != dest_->srtp_ciphers_.end(); ++it2) { + if (*it1 == *it2) { + chosen_srtp_cipher_ = *it1; + dest_->chosen_srtp_cipher_ = *it2; + return; + } + } + } + } + + bool GetStats(ConnectionInfos* infos) override { + ConnectionInfo info; + infos->clear(); + infos->push_back(info); + return true; + } + + void set_ssl_max_protocol_version(rtc::SSLProtocolVersion version) { + ssl_max_version_ = version; + } + rtc::SSLProtocolVersion ssl_max_protocol_version() const { + return ssl_max_version_; + } + + private: + enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED }; + Transport* transport_; + FakeTransportChannel* dest_ = nullptr; + State state_ = STATE_INIT; + bool async_ = false; + Candidates remote_candidates_; + rtc::scoped_refptr<rtc::RTCCertificate> local_cert_; + rtc::FakeSSLCertificate* remote_cert_ = nullptr; + bool do_dtls_ = false; + std::vector<std::string> srtp_ciphers_; + std::string chosen_srtp_cipher_; + int receiving_timeout_ = -1; + bool gather_continually_ = false; + IceRole role_ = ICEROLE_UNKNOWN; + uint64_t tiebreaker_ = 0; + std::string ice_ufrag_; + std::string ice_pwd_; + std::string remote_ice_ufrag_; + std::string remote_ice_pwd_; + IceMode remote_ice_mode_ = ICEMODE_FULL; + rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_10; + rtc::SSLFingerprint dtls_fingerprint_; + rtc::SSLRole ssl_role_ = rtc::SSL_CLIENT; + size_t connection_count_ = 0; + IceGatheringState gathering_state_ = kIceGatheringNew; + bool had_connection_ = false; +}; + +// Fake transport class, which can be passed to anything that needs a Transport. +// Can be informed of another FakeTransport via SetDestination (low-tech way +// of doing candidates) +class FakeTransport : public Transport { + public: + typedef std::map<int, FakeTransportChannel*> ChannelMap; + + explicit FakeTransport(const std::string& name) : Transport(name, nullptr) {} + + // Note that we only have a constructor with the allocator parameter so it can + // be wrapped by a DtlsTransport. + FakeTransport(const std::string& name, PortAllocator* allocator) + : Transport(name, nullptr) {} + + ~FakeTransport() { DestroyAllChannels(); } + + const ChannelMap& channels() const { return channels_; } + + // If async, will send packets by "Post"-ing to message queue instead of + // synchronously "Send"-ing. + void SetAsync(bool async) { async_ = async; } + void SetDestination(FakeTransport* dest) { + dest_ = dest; + for (const auto& kv : channels_) { + kv.second->SetLocalCertificate(certificate_); + SetChannelDestination(kv.first, kv.second); + } + } + + void SetWritable(bool writable) { + for (const auto& kv : channels_) { + kv.second->SetWritable(writable); + } + } + + void SetLocalCertificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override { + certificate_ = certificate; + } + bool GetLocalCertificate( + rtc::scoped_refptr<rtc::RTCCertificate>* certificate) override { + if (!certificate_) + return false; + + *certificate = certificate_; + return true; + } + + bool GetSslRole(rtc::SSLRole* role) const override { + if (channels_.empty()) { + return false; + } + return channels_.begin()->second->GetSslRole(role); + } + + bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) override { + ssl_max_version_ = version; + for (const auto& kv : channels_) { + kv.second->set_ssl_max_protocol_version(ssl_max_version_); + } + return true; + } + rtc::SSLProtocolVersion ssl_max_protocol_version() const { + return ssl_max_version_; + } + + using Transport::local_description; + using Transport::remote_description; + + protected: + TransportChannelImpl* CreateTransportChannel(int component) override { + if (channels_.find(component) != channels_.end()) { + return nullptr; + } + FakeTransportChannel* channel = + new FakeTransportChannel(this, name(), component); + channel->set_ssl_max_protocol_version(ssl_max_version_); + channel->SetAsync(async_); + SetChannelDestination(component, channel); + channels_[component] = channel; + return channel; + } + + void DestroyTransportChannel(TransportChannelImpl* channel) override { + channels_.erase(channel->component()); + delete channel; + } + + private: + FakeTransportChannel* GetFakeChannel(int component) { + auto it = channels_.find(component); + return (it != channels_.end()) ? it->second : nullptr; + } + + void SetChannelDestination(int component, FakeTransportChannel* channel) { + FakeTransportChannel* dest_channel = nullptr; + if (dest_) { + dest_channel = dest_->GetFakeChannel(component); + if (dest_channel) { + dest_channel->SetLocalCertificate(dest_->certificate_); + } + } + channel->SetDestination(dest_channel); + } + + // Note, this is distinct from the Channel map owned by Transport. + // This map just tracks the FakeTransportChannels created by this class. + // It's mainly needed so that we can access a FakeTransportChannel directly, + // even if wrapped by a DtlsTransportChannelWrapper. + ChannelMap channels_; + FakeTransport* dest_ = nullptr; + bool async_ = false; + rtc::scoped_refptr<rtc::RTCCertificate> certificate_; + rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_10; +}; + +// Fake TransportController class, which can be passed into a BaseChannel object +// for test purposes. Can be connected to other FakeTransportControllers via +// Connect(). +// +// This fake is unusual in that for the most part, it's implemented with the +// real TransportController code, but with fake TransportChannels underneath. +class FakeTransportController : public TransportController { + public: + FakeTransportController() + : TransportController(rtc::Thread::Current(), + rtc::Thread::Current(), + nullptr), + fail_create_channel_(false) {} + + explicit FakeTransportController(IceRole role) + : TransportController(rtc::Thread::Current(), + rtc::Thread::Current(), + nullptr), + fail_create_channel_(false) { + SetIceRole(role); + } + + explicit FakeTransportController(rtc::Thread* worker_thread) + : TransportController(rtc::Thread::Current(), worker_thread, nullptr), + fail_create_channel_(false) {} + + FakeTransportController(rtc::Thread* worker_thread, IceRole role) + : TransportController(rtc::Thread::Current(), worker_thread, nullptr), + fail_create_channel_(false) { + SetIceRole(role); + } + + FakeTransport* GetTransport_w(const std::string& transport_name) { + return static_cast<FakeTransport*>( + TransportController::GetTransport_w(transport_name)); + } + + void Connect(FakeTransportController* dest) { + worker_thread()->Invoke<void>( + rtc::Bind(&FakeTransportController::Connect_w, this, dest)); + } + + TransportChannel* CreateTransportChannel_w(const std::string& transport_name, + int component) override { + if (fail_create_channel_) { + return nullptr; + } + return TransportController::CreateTransportChannel_w(transport_name, + component); + } + + void set_fail_channel_creation(bool fail_channel_creation) { + fail_create_channel_ = fail_channel_creation; + } + + protected: + Transport* CreateTransport_w(const std::string& transport_name) override { + return new FakeTransport(transport_name); + } + + void Connect_w(FakeTransportController* dest) { + // Simulate the exchange of candidates. + ConnectChannels_w(); + dest->ConnectChannels_w(); + for (auto& kv : transports()) { + FakeTransport* transport = static_cast<FakeTransport*>(kv.second); + transport->SetDestination(dest->GetTransport_w(kv.first)); + } + } + + void ConnectChannels_w() { + for (auto& kv : transports()) { + FakeTransport* transport = static_cast<FakeTransport*>(kv.second); + transport->ConnectChannels(); + transport->MaybeStartGathering(); + } + } + + private: + bool fail_create_channel_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_ diff --git a/webrtc/p2p/base/p2ptransport.cc b/webrtc/p2p/base/p2ptransport.cc new file mode 100644 index 0000000000..abc4c14504 --- /dev/null +++ b/webrtc/p2p/base/p2ptransport.cc @@ -0,0 +1,38 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/p2ptransport.h" + +#include <string> + +#include "webrtc/base/base64.h" +#include "webrtc/base/common.h" +#include "webrtc/base/stringencode.h" +#include "webrtc/base/stringutils.h" +#include "webrtc/p2p/base/p2ptransportchannel.h" + +namespace cricket { + +P2PTransport::P2PTransport(const std::string& name, PortAllocator* allocator) + : Transport(name, allocator) {} + +P2PTransport::~P2PTransport() { + DestroyAllChannels(); +} + +TransportChannelImpl* P2PTransport::CreateTransportChannel(int component) { + return new P2PTransportChannel(name(), component, this, port_allocator()); +} + +void P2PTransport::DestroyTransportChannel(TransportChannelImpl* channel) { + delete channel; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/p2ptransport.h b/webrtc/p2p/base/p2ptransport.h new file mode 100644 index 0000000000..0f965b4cdc --- /dev/null +++ b/webrtc/p2p/base/p2ptransport.h @@ -0,0 +1,37 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_P2PTRANSPORT_H_ +#define WEBRTC_P2P_BASE_P2PTRANSPORT_H_ + +#include <string> +#include "webrtc/p2p/base/transport.h" + +namespace cricket { + +// Everything in this class should be called on the worker thread. +class P2PTransport : public Transport { + public: + P2PTransport(const std::string& name, PortAllocator* allocator); + virtual ~P2PTransport(); + + protected: + // Creates and destroys P2PTransportChannel. + virtual TransportChannelImpl* CreateTransportChannel(int component); + virtual void DestroyTransportChannel(TransportChannelImpl* channel); + + friend class P2PTransportChannel; + + RTC_DISALLOW_COPY_AND_ASSIGN(P2PTransport); +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_P2PTRANSPORT_H_ diff --git a/webrtc/p2p/base/p2ptransportchannel.cc b/webrtc/p2p/base/p2ptransportchannel.cc new file mode 100644 index 0000000000..623085f9a8 --- /dev/null +++ b/webrtc/p2p/base/p2ptransportchannel.cc @@ -0,0 +1,1384 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/p2ptransportchannel.h" + +#include <algorithm> +#include <set> +#include "webrtc/p2p/base/common.h" +#include "webrtc/p2p/base/relayport.h" // For RELAY_PORT_TYPE. +#include "webrtc/p2p/base/stunport.h" // For STUN_PORT_TYPE. +#include "webrtc/base/common.h" +#include "webrtc/base/crc32.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/stringencode.h" +#include "webrtc/system_wrappers/include/field_trial.h" + +namespace { + +// messages for queuing up work for ourselves +enum { MSG_SORT = 1, MSG_CHECK_AND_PING }; + +// The minimum improvement in RTT that justifies a switch. +static const double kMinImprovement = 10; + +cricket::PortInterface::CandidateOrigin GetOrigin(cricket::PortInterface* port, + cricket::PortInterface* origin_port) { + if (!origin_port) + return cricket::PortInterface::ORIGIN_MESSAGE; + else if (port == origin_port) + return cricket::PortInterface::ORIGIN_THIS_PORT; + else + return cricket::PortInterface::ORIGIN_OTHER_PORT; +} + +// Compares two connections based only on static information about them. +int CompareConnectionCandidates(cricket::Connection* a, + cricket::Connection* b) { + // Compare connection priority. Lower values get sorted last. + if (a->priority() > b->priority()) + return 1; + if (a->priority() < b->priority()) + return -1; + + // If we're still tied at this point, prefer a younger generation. + return (a->remote_candidate().generation() + a->port()->generation()) - + (b->remote_candidate().generation() + b->port()->generation()); +} + +// Compare two connections based on their writing, receiving, and connected +// states. +int CompareConnectionStates(cricket::Connection* a, cricket::Connection* b) { + // Sort based on write-state. Better states have lower values. + if (a->write_state() < b->write_state()) + return 1; + if (a->write_state() > b->write_state()) + return -1; + + // We prefer a receiving connection to a non-receiving, higher-priority + // connection when sorting connections and choosing which connection to + // switch to. + if (a->receiving() && !b->receiving()) + return 1; + if (!a->receiving() && b->receiving()) + return -1; + + // WARNING: Some complexity here about TCP reconnecting. + // When a TCP connection fails because of a TCP socket disconnecting, the + // active side of the connection will attempt to reconnect for 5 seconds while + // pretending to be writable (the connection is not set to the unwritable + // state). On the passive side, the connection also remains writable even + // though it is disconnected, and a new connection is created when the active + // side connects. At that point, there are two TCP connections on the passive + // side: 1. the old, disconnected one that is pretending to be writable, and + // 2. the new, connected one that is maybe not yet writable. For purposes of + // pruning, pinging, and selecting the best connection, we want to treat the + // new connection as "better" than the old one. We could add a method called + // something like Connection::ImReallyBadEvenThoughImWritable, but that is + // equivalent to the existing Connection::connected(), which we already have. + // So, in code throughout this file, we'll check whether the connection is + // connected() or not, and if it is not, treat it as "worse" than a connected + // one, even though it's writable. In the code below, we're doing so to make + // sure we treat a new writable connection as better than an old disconnected + // connection. + + // In the case where we reconnect TCP connections, the original best + // connection is disconnected without changing to WRITE_TIMEOUT. In this case, + // the new connection, when it becomes writable, should have higher priority. + if (a->write_state() == cricket::Connection::STATE_WRITABLE && + b->write_state() == cricket::Connection::STATE_WRITABLE) { + if (a->connected() && !b->connected()) { + return 1; + } + if (!a->connected() && b->connected()) { + return -1; + } + } + return 0; +} + +int CompareConnections(cricket::Connection* a, cricket::Connection* b) { + int state_cmp = CompareConnectionStates(a, b); + if (state_cmp != 0) { + return state_cmp; + } + // Compare the candidate information. + return CompareConnectionCandidates(a, b); +} + +// Wraps the comparison connection into a less than operator that puts higher +// priority writable connections first. +class ConnectionCompare { + public: + bool operator()(const cricket::Connection *ca, + const cricket::Connection *cb) { + cricket::Connection* a = const_cast<cricket::Connection*>(ca); + cricket::Connection* b = const_cast<cricket::Connection*>(cb); + + // Compare first on writability and static preferences. + int cmp = CompareConnections(a, b); + if (cmp > 0) + return true; + if (cmp < 0) + return false; + + // Otherwise, sort based on latency estimate. + return a->rtt() < b->rtt(); + + // Should we bother checking for the last connection that last received + // data? It would help rendezvous on the connection that is also receiving + // packets. + // + // TODO: Yes we should definitely do this. The TCP protocol gains + // efficiency by being used bidirectionally, as opposed to two separate + // unidirectional streams. This test should probably occur before + // comparison of local prefs (assuming combined prefs are the same). We + // need to be careful though, not to bounce back and forth with both sides + // trying to rendevous with the other. + } +}; + +// Determines whether we should switch between two connections, based first on +// connection states, static preferences, and then (if those are equal) on +// latency estimates. +bool ShouldSwitch(cricket::Connection* a_conn, + cricket::Connection* b_conn, + cricket::IceRole ice_role) { + if (a_conn == b_conn) + return false; + + if (!a_conn || !b_conn) // don't think the latter should happen + return true; + + // We prefer to switch to a writable and receiving connection over a + // non-writable or non-receiving connection, even if the latter has + // been nominated by the controlling side. + int state_cmp = CompareConnectionStates(a_conn, b_conn); + if (state_cmp != 0) { + return state_cmp < 0; + } + if (ice_role == cricket::ICEROLE_CONTROLLED && a_conn->nominated()) { + LOG(LS_VERBOSE) << "Controlled side did not switch due to nominated status"; + return false; + } + + int prefs_cmp = CompareConnectionCandidates(a_conn, b_conn); + if (prefs_cmp != 0) { + return prefs_cmp < 0; + } + + return b_conn->rtt() <= a_conn->rtt() + kMinImprovement; +} + +} // unnamed namespace + +namespace cricket { + +// When the socket is unwritable, we will use 10 Kbps (ignoring IP+UDP headers) +// for pinging. When the socket is writable, we will use only 1 Kbps because +// we don't want to degrade the quality on a modem. These numbers should work +// well on a 28.8K modem, which is the slowest connection on which the voice +// quality is reasonable at all. +static const uint32_t PING_PACKET_SIZE = 60 * 8; +// STRONG_PING_DELAY (480ms) is applied when the best connection is both +// writable and receiving. +static const uint32_t STRONG_PING_DELAY = 1000 * PING_PACKET_SIZE / 1000; +// WEAK_PING_DELAY (48ms) is applied when the best connection is either not +// writable or not receiving. +const uint32_t WEAK_PING_DELAY = 1000 * PING_PACKET_SIZE / 10000; + +// If the current best connection is both writable and receiving, then we will +// also try hard to make sure it is pinged at this rate (a little less than +// 2 * STRONG_PING_DELAY). +static const uint32_t MAX_CURRENT_STRONG_DELAY = 900; + +static const int MIN_CHECK_RECEIVING_DELAY = 50; // ms + + +P2PTransportChannel::P2PTransportChannel(const std::string& transport_name, + int component, + P2PTransport* transport, + PortAllocator* allocator) + : TransportChannelImpl(transport_name, component), + transport_(transport), + allocator_(allocator), + worker_thread_(rtc::Thread::Current()), + incoming_only_(false), + error_(0), + best_connection_(NULL), + pending_best_connection_(NULL), + sort_dirty_(false), + was_writable_(false), + remote_ice_mode_(ICEMODE_FULL), + ice_role_(ICEROLE_UNKNOWN), + tiebreaker_(0), + remote_candidate_generation_(0), + gathering_state_(kIceGatheringNew), + check_receiving_delay_(MIN_CHECK_RECEIVING_DELAY * 5), + receiving_timeout_(MIN_CHECK_RECEIVING_DELAY * 50) { + uint32_t weak_ping_delay = ::strtoul( + webrtc::field_trial::FindFullName("WebRTC-StunInterPacketDelay").c_str(), + nullptr, 10); + if (weak_ping_delay) { + weak_ping_delay_ = weak_ping_delay; + } +} + +P2PTransportChannel::~P2PTransportChannel() { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + for (size_t i = 0; i < allocator_sessions_.size(); ++i) + delete allocator_sessions_[i]; +} + +// Add the allocator session to our list so that we know which sessions +// are still active. +void P2PTransportChannel::AddAllocatorSession(PortAllocatorSession* session) { + session->set_generation(static_cast<uint32_t>(allocator_sessions_.size())); + allocator_sessions_.push_back(session); + + // We now only want to apply new candidates that we receive to the ports + // created by this new session because these are replacing those of the + // previous sessions. + ports_.clear(); + + session->SignalPortReady.connect(this, &P2PTransportChannel::OnPortReady); + session->SignalCandidatesReady.connect( + this, &P2PTransportChannel::OnCandidatesReady); + session->SignalCandidatesAllocationDone.connect( + this, &P2PTransportChannel::OnCandidatesAllocationDone); + session->StartGettingPorts(); +} + +void P2PTransportChannel::AddConnection(Connection* connection) { + connections_.push_back(connection); + connection->set_remote_ice_mode(remote_ice_mode_); + connection->set_receiving_timeout(receiving_timeout_); + connection->SignalReadPacket.connect( + this, &P2PTransportChannel::OnReadPacket); + connection->SignalReadyToSend.connect( + this, &P2PTransportChannel::OnReadyToSend); + connection->SignalStateChange.connect( + this, &P2PTransportChannel::OnConnectionStateChange); + connection->SignalDestroyed.connect( + this, &P2PTransportChannel::OnConnectionDestroyed); + connection->SignalNominated.connect(this, &P2PTransportChannel::OnNominated); + had_connection_ = true; +} + +void P2PTransportChannel::SetIceRole(IceRole ice_role) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + if (ice_role_ != ice_role) { + ice_role_ = ice_role; + for (std::vector<PortInterface *>::iterator it = ports_.begin(); + it != ports_.end(); ++it) { + (*it)->SetIceRole(ice_role); + } + } +} + +void P2PTransportChannel::SetIceTiebreaker(uint64_t tiebreaker) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + if (!ports_.empty()) { + LOG(LS_ERROR) + << "Attempt to change tiebreaker after Port has been allocated."; + return; + } + + tiebreaker_ = tiebreaker; +} + +// A channel is considered ICE completed once there is at most one active +// connection per network and at least one active connection. +TransportChannelState P2PTransportChannel::GetState() const { + if (!had_connection_) { + return TransportChannelState::STATE_INIT; + } + + std::vector<Connection*> active_connections; + for (Connection* connection : connections_) { + if (connection->active()) { + active_connections.push_back(connection); + } + } + if (active_connections.empty()) { + return TransportChannelState::STATE_FAILED; + } + + std::set<rtc::Network*> networks; + for (Connection* connection : active_connections) { + rtc::Network* network = connection->port()->Network(); + if (networks.find(network) == networks.end()) { + networks.insert(network); + } else { + LOG_J(LS_VERBOSE, this) << "Ice not completed yet for this channel as " + << network->ToString() + << " has more than 1 connection."; + return TransportChannelState::STATE_CONNECTING; + } + } + + LOG_J(LS_VERBOSE, this) << "Ice is completed for this channel."; + return TransportChannelState::STATE_COMPLETED; +} + +void P2PTransportChannel::SetIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + ice_ufrag_ = ice_ufrag; + ice_pwd_ = ice_pwd; + // Note: Candidate gathering will restart when MaybeStartGathering is next + // called. +} + +void P2PTransportChannel::SetRemoteIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + bool ice_restart = false; + if (!remote_ice_ufrag_.empty() && !remote_ice_pwd_.empty()) { + ice_restart = (remote_ice_ufrag_ != ice_ufrag) || + (remote_ice_pwd_!= ice_pwd); + } + + remote_ice_ufrag_ = ice_ufrag; + remote_ice_pwd_ = ice_pwd; + + // We need to update the credentials for any peer reflexive candidates. + std::vector<Connection*>::iterator it = connections_.begin(); + for (; it != connections_.end(); ++it) { + (*it)->MaybeSetRemoteIceCredentials(ice_ufrag, ice_pwd); + } + + if (ice_restart) { + // We need to keep track of the remote ice restart so newer + // connections are prioritized over the older. + ++remote_candidate_generation_; + } +} + +void P2PTransportChannel::SetRemoteIceMode(IceMode mode) { + remote_ice_mode_ = mode; +} + +void P2PTransportChannel::SetIceConfig(const IceConfig& config) { + gather_continually_ = config.gather_continually; + LOG(LS_INFO) << "Set gather_continually to " << gather_continually_; + + if (config.receiving_timeout_ms < 0) { + return; + } + receiving_timeout_ = config.receiving_timeout_ms; + check_receiving_delay_ = + std::max(MIN_CHECK_RECEIVING_DELAY, receiving_timeout_ / 10); + + for (Connection* connection : connections_) { + connection->set_receiving_timeout(receiving_timeout_); + } + LOG(LS_INFO) << "Set ICE receiving timeout to " << receiving_timeout_ + << " milliseconds"; +} + +// Go into the state of processing candidates, and running in general +void P2PTransportChannel::Connect() { + ASSERT(worker_thread_ == rtc::Thread::Current()); + if (ice_ufrag_.empty() || ice_pwd_.empty()) { + ASSERT(false); + LOG(LS_ERROR) << "P2PTransportChannel::Connect: The ice_ufrag_ and the " + << "ice_pwd_ are not set."; + return; + } + + // Start checking and pinging as the ports come in. + thread()->Post(this, MSG_CHECK_AND_PING); +} + +void P2PTransportChannel::MaybeStartGathering() { + // Start gathering if we never started before, or if an ICE restart occurred. + if (allocator_sessions_.empty() || + IceCredentialsChanged(allocator_sessions_.back()->ice_ufrag(), + allocator_sessions_.back()->ice_pwd(), ice_ufrag_, + ice_pwd_)) { + if (gathering_state_ != kIceGatheringGathering) { + gathering_state_ = kIceGatheringGathering; + SignalGatheringState(this); + } + // Time for a new allocator + AddAllocatorSession(allocator_->CreateSession( + SessionId(), transport_name(), component(), ice_ufrag_, ice_pwd_)); + } +} + +// A new port is available, attempt to make connections for it +void P2PTransportChannel::OnPortReady(PortAllocatorSession *session, + PortInterface* port) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + // Set in-effect options on the new port + for (OptionMap::const_iterator it = options_.begin(); + it != options_.end(); + ++it) { + int val = port->SetOption(it->first, it->second); + if (val < 0) { + LOG_J(LS_WARNING, port) << "SetOption(" << it->first + << ", " << it->second + << ") failed: " << port->GetError(); + } + } + + // Remember the ports and candidates, and signal that candidates are ready. + // The session will handle this, and send an initiate/accept/modify message + // if one is pending. + + port->SetIceRole(ice_role_); + port->SetIceTiebreaker(tiebreaker_); + ports_.push_back(port); + port->SignalUnknownAddress.connect( + this, &P2PTransportChannel::OnUnknownAddress); + port->SignalDestroyed.connect(this, &P2PTransportChannel::OnPortDestroyed); + port->SignalRoleConflict.connect( + this, &P2PTransportChannel::OnRoleConflict); + port->SignalSentPacket.connect(this, &P2PTransportChannel::OnSentPacket); + + // Attempt to create a connection from this new port to all of the remote + // candidates that we were given so far. + + std::vector<RemoteCandidate>::iterator iter; + for (iter = remote_candidates_.begin(); iter != remote_candidates_.end(); + ++iter) { + CreateConnection(port, *iter, iter->origin_port()); + } + + SortConnections(); +} + +// A new candidate is available, let listeners know +void P2PTransportChannel::OnCandidatesReady( + PortAllocatorSession* session, + const std::vector<Candidate>& candidates) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + for (size_t i = 0; i < candidates.size(); ++i) { + SignalCandidateGathered(this, candidates[i]); + } +} + +void P2PTransportChannel::OnCandidatesAllocationDone( + PortAllocatorSession* session) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + gathering_state_ = kIceGatheringComplete; + LOG(LS_INFO) << "P2PTransportChannel: " << transport_name() << ", component " + << component() << " gathering complete"; + SignalGatheringState(this); +} + +// Handle stun packets +void P2PTransportChannel::OnUnknownAddress( + PortInterface* port, + const rtc::SocketAddress& address, ProtocolType proto, + IceMessage* stun_msg, const std::string &remote_username, + bool port_muxed) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + // Port has received a valid stun packet from an address that no Connection + // is currently available for. See if we already have a candidate with the + // address. If it isn't we need to create new candidate for it. + + // Determine if the remote candidates use shared ufrag. + bool ufrag_per_port = false; + std::vector<RemoteCandidate>::iterator it; + if (remote_candidates_.size() > 0) { + it = remote_candidates_.begin(); + std::string username = it->username(); + for (; it != remote_candidates_.end(); ++it) { + if (it->username() != username) { + ufrag_per_port = true; + break; + } + } + } + + const Candidate* candidate = NULL; + std::string remote_password; + for (it = remote_candidates_.begin(); it != remote_candidates_.end(); ++it) { + if (it->username() == remote_username) { + remote_password = it->password(); + if (ufrag_per_port || + (it->address() == address && + it->protocol() == ProtoToString(proto))) { + candidate = &(*it); + break; + } + // We don't want to break here because we may find a match of the address + // later. + } + } + + // The STUN binding request may arrive after setRemoteDescription and before + // adding remote candidate, so we need to set the password to the shared + // password if the user name matches. + if (remote_password.empty() && remote_username == remote_ice_ufrag_) { + remote_password = remote_ice_pwd_; + } + + Candidate remote_candidate; + bool remote_candidate_is_new = (candidate == nullptr); + if (!remote_candidate_is_new) { + remote_candidate = *candidate; + if (ufrag_per_port) { + remote_candidate.set_address(address); + } + } else { + // Create a new candidate with this address. + int remote_candidate_priority; + + // The priority of the candidate is set to the PRIORITY attribute + // from the request. + const StunUInt32Attribute* priority_attr = + stun_msg->GetUInt32(STUN_ATTR_PRIORITY); + if (!priority_attr) { + LOG(LS_WARNING) << "P2PTransportChannel::OnUnknownAddress - " + << "No STUN_ATTR_PRIORITY found in the " + << "stun request message"; + port->SendBindingErrorResponse(stun_msg, address, STUN_ERROR_BAD_REQUEST, + STUN_ERROR_REASON_BAD_REQUEST); + return; + } + remote_candidate_priority = priority_attr->value(); + + // RFC 5245 + // If the source transport address of the request does not match any + // existing remote candidates, it represents a new peer reflexive remote + // candidate. + remote_candidate = + Candidate(component(), ProtoToString(proto), address, 0, + remote_username, remote_password, PRFLX_PORT_TYPE, 0U, ""); + + // From RFC 5245, section-7.2.1.3: + // The foundation of the candidate is set to an arbitrary value, different + // from the foundation for all other remote candidates. + remote_candidate.set_foundation( + rtc::ToString<uint32_t>(rtc::ComputeCrc32(remote_candidate.id()))); + + remote_candidate.set_priority(remote_candidate_priority); + } + + // RFC5245, the agent constructs a pair whose local candidate is equal to + // the transport address on which the STUN request was received, and a + // remote candidate equal to the source transport address where the + // request came from. + + // There shouldn't be an existing connection with this remote address. + // When ports are muxed, this channel might get multiple unknown address + // signals. In that case if the connection is already exists, we should + // simply ignore the signal otherwise send server error. + if (port->GetConnection(remote_candidate.address())) { + if (port_muxed) { + LOG(LS_INFO) << "Connection already exists for peer reflexive " + << "candidate: " << remote_candidate.ToString(); + return; + } else { + ASSERT(false); + port->SendBindingErrorResponse(stun_msg, address, + STUN_ERROR_SERVER_ERROR, + STUN_ERROR_REASON_SERVER_ERROR); + return; + } + } + + Connection* connection = port->CreateConnection( + remote_candidate, cricket::PortInterface::ORIGIN_THIS_PORT); + if (!connection) { + ASSERT(false); + port->SendBindingErrorResponse(stun_msg, address, STUN_ERROR_SERVER_ERROR, + STUN_ERROR_REASON_SERVER_ERROR); + return; + } + + LOG(LS_INFO) << "Adding connection from " + << (remote_candidate_is_new ? "peer reflexive" : "resurrected") + << " candidate: " << remote_candidate.ToString(); + AddConnection(connection); + connection->ReceivedPing(); + + bool received_use_candidate = + stun_msg->GetByteString(STUN_ATTR_USE_CANDIDATE) != nullptr; + if (received_use_candidate && ice_role_ == ICEROLE_CONTROLLED) { + connection->set_nominated(true); + OnNominated(connection); + } + + // Update the list of connections since we just added another. We do this + // after sending the response since it could (in principle) delete the + // connection in question. + SortConnections(); +} + +void P2PTransportChannel::OnRoleConflict(PortInterface* port) { + SignalRoleConflict(this); // STUN ping will be sent when SetRole is called + // from Transport. +} + +void P2PTransportChannel::OnNominated(Connection* conn) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + ASSERT(ice_role_ == ICEROLE_CONTROLLED); + + if (conn->write_state() == Connection::STATE_WRITABLE) { + if (best_connection_ != conn) { + pending_best_connection_ = NULL; + LOG(LS_INFO) << "Switching best connection on controlled side: " + << conn->ToString(); + SwitchBestConnectionTo(conn); + // Now we have selected the best connection, time to prune other existing + // connections and update the read/write state of the channel. + RequestSort(); + } + } else { + LOG(LS_INFO) << "Not switching the best connection on controlled side yet," + << " because it's not writable: " << conn->ToString(); + pending_best_connection_ = conn; + } +} + +void P2PTransportChannel::AddRemoteCandidate(const Candidate& candidate) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + uint32_t generation = candidate.generation(); + // Network may not guarantee the order of the candidate delivery. If a + // remote candidate with an older generation arrives, drop it. + if (generation != 0 && generation < remote_candidate_generation_) { + LOG(LS_WARNING) << "Dropping a remote candidate because its generation " + << generation + << " is lower than the current remote generation " + << remote_candidate_generation_; + return; + } + + // Create connections to this remote candidate. + CreateConnections(candidate, NULL); + + // Resort the connections list, which may have new elements. + SortConnections(); +} + +// Creates connections from all of the ports that we care about to the given +// remote candidate. The return value is true if we created a connection from +// the origin port. +bool P2PTransportChannel::CreateConnections(const Candidate& remote_candidate, + PortInterface* origin_port) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + Candidate new_remote_candidate(remote_candidate); + new_remote_candidate.set_generation( + GetRemoteCandidateGeneration(remote_candidate)); + // ICE candidates don't need to have username and password set, but + // the code below this (specifically, ConnectionRequest::Prepare in + // port.cc) uses the remote candidates's username. So, we set it + // here. + if (remote_candidate.username().empty()) { + new_remote_candidate.set_username(remote_ice_ufrag_); + } + if (remote_candidate.password().empty()) { + new_remote_candidate.set_password(remote_ice_pwd_); + } + + // If we've already seen the new remote candidate (in the current candidate + // generation), then we shouldn't try creating connections for it. + // We either already have a connection for it, or we previously created one + // and then later pruned it. If we don't return, the channel will again + // re-create any connections that were previously pruned, which will then + // immediately be re-pruned, churning the network for no purpose. + // This only applies to candidates received over signaling (i.e. origin_port + // is NULL). + if (!origin_port && IsDuplicateRemoteCandidate(new_remote_candidate)) { + // return true to indicate success, without creating any new connections. + return true; + } + + // Add a new connection for this candidate to every port that allows such a + // connection (i.e., if they have compatible protocols) and that does not + // already have a connection to an equivalent candidate. We must be careful + // to make sure that the origin port is included, even if it was pruned, + // since that may be the only port that can create this connection. + bool created = false; + std::vector<PortInterface *>::reverse_iterator it; + for (it = ports_.rbegin(); it != ports_.rend(); ++it) { + if (CreateConnection(*it, new_remote_candidate, origin_port)) { + if (*it == origin_port) + created = true; + } + } + + if ((origin_port != NULL) && + std::find(ports_.begin(), ports_.end(), origin_port) == ports_.end()) { + if (CreateConnection(origin_port, new_remote_candidate, origin_port)) + created = true; + } + + // Remember this remote candidate so that we can add it to future ports. + RememberRemoteCandidate(new_remote_candidate, origin_port); + + return created; +} + +// Setup a connection object for the local and remote candidate combination. +// And then listen to connection object for changes. +bool P2PTransportChannel::CreateConnection(PortInterface* port, + const Candidate& remote_candidate, + PortInterface* origin_port) { + // Look for an existing connection with this remote address. If one is not + // found, then we can create a new connection for this address. + Connection* connection = port->GetConnection(remote_candidate.address()); + if (connection != NULL) { + connection->MaybeUpdatePeerReflexiveCandidate(remote_candidate); + + // It is not legal to try to change any of the parameters of an existing + // connection; however, the other side can send a duplicate candidate. + if (!remote_candidate.IsEquivalent(connection->remote_candidate())) { + LOG(INFO) << "Attempt to change a remote candidate." + << " Existing remote candidate: " + << connection->remote_candidate().ToString() + << "New remote candidate: " + << remote_candidate.ToString(); + return false; + } + } else { + PortInterface::CandidateOrigin origin = GetOrigin(port, origin_port); + + // Don't create connection if this is a candidate we received in a + // message and we are not allowed to make outgoing connections. + if (origin == cricket::PortInterface::ORIGIN_MESSAGE && incoming_only_) + return false; + + connection = port->CreateConnection(remote_candidate, origin); + if (!connection) + return false; + + AddConnection(connection); + + LOG_J(LS_INFO, this) << "Created connection with origin=" << origin << ", (" + << connections_.size() << " total)"; + } + + return true; +} + +bool P2PTransportChannel::FindConnection( + cricket::Connection* connection) const { + std::vector<Connection*>::const_iterator citer = + std::find(connections_.begin(), connections_.end(), connection); + return citer != connections_.end(); +} + +uint32_t P2PTransportChannel::GetRemoteCandidateGeneration( + const Candidate& candidate) { + // We need to keep track of the remote ice restart so newer + // connections are prioritized over the older. + ASSERT(candidate.generation() == 0 || + candidate.generation() == remote_candidate_generation_); + return remote_candidate_generation_; +} + +// Check if remote candidate is already cached. +bool P2PTransportChannel::IsDuplicateRemoteCandidate( + const Candidate& candidate) { + for (size_t i = 0; i < remote_candidates_.size(); ++i) { + if (remote_candidates_[i].IsEquivalent(candidate)) { + return true; + } + } + return false; +} + +// Maintain our remote candidate list, adding this new remote one. +void P2PTransportChannel::RememberRemoteCandidate( + const Candidate& remote_candidate, PortInterface* origin_port) { + // Remove any candidates whose generation is older than this one. The + // presence of a new generation indicates that the old ones are not useful. + size_t i = 0; + while (i < remote_candidates_.size()) { + if (remote_candidates_[i].generation() < remote_candidate.generation()) { + LOG(INFO) << "Pruning candidate from old generation: " + << remote_candidates_[i].address().ToSensitiveString(); + remote_candidates_.erase(remote_candidates_.begin() + i); + } else { + i += 1; + } + } + + // Make sure this candidate is not a duplicate. + if (IsDuplicateRemoteCandidate(remote_candidate)) { + LOG(INFO) << "Duplicate candidate: " << remote_candidate.ToString(); + return; + } + + // Try this candidate for all future ports. + remote_candidates_.push_back(RemoteCandidate(remote_candidate, origin_port)); +} + +// Set options on ourselves is simply setting options on all of our available +// port objects. +int P2PTransportChannel::SetOption(rtc::Socket::Option opt, int value) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + OptionMap::iterator it = options_.find(opt); + if (it == options_.end()) { + options_.insert(std::make_pair(opt, value)); + } else if (it->second == value) { + return 0; + } else { + it->second = value; + } + + for (size_t i = 0; i < ports_.size(); ++i) { + int val = ports_[i]->SetOption(opt, value); + if (val < 0) { + // Because this also occurs deferred, probably no point in reporting an + // error + LOG(WARNING) << "SetOption(" << opt << ", " << value << ") failed: " + << ports_[i]->GetError(); + } + } + return 0; +} + +bool P2PTransportChannel::GetOption(rtc::Socket::Option opt, int* value) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + const auto& found = options_.find(opt); + if (found == options_.end()) { + return false; + } + *value = found->second; + return true; +} + +// Send data to the other side, using our best connection. +int P2PTransportChannel::SendPacket(const char *data, size_t len, + const rtc::PacketOptions& options, + int flags) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + if (flags != 0) { + error_ = EINVAL; + return -1; + } + if (best_connection_ == NULL) { + error_ = EWOULDBLOCK; + return -1; + } + + int sent = best_connection_->Send(data, len, options); + if (sent <= 0) { + ASSERT(sent < 0); + error_ = best_connection_->GetError(); + } + return sent; +} + +bool P2PTransportChannel::GetStats(ConnectionInfos *infos) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + // Gather connection infos. + infos->clear(); + + std::vector<Connection *>::const_iterator it; + for (Connection* connection : connections_) { + ConnectionInfo info; + info.best_connection = (best_connection_ == connection); + info.receiving = connection->receiving(); + info.writable = + (connection->write_state() == Connection::STATE_WRITABLE); + info.timeout = + (connection->write_state() == Connection::STATE_WRITE_TIMEOUT); + info.new_connection = !connection->reported(); + connection->set_reported(true); + info.rtt = connection->rtt(); + info.sent_total_bytes = connection->sent_total_bytes(); + info.sent_bytes_second = connection->sent_bytes_second(); + info.sent_discarded_packets = connection->sent_discarded_packets(); + info.sent_total_packets = connection->sent_total_packets(); + info.recv_total_bytes = connection->recv_total_bytes(); + info.recv_bytes_second = connection->recv_bytes_second(); + info.local_candidate = connection->local_candidate(); + info.remote_candidate = connection->remote_candidate(); + info.key = connection; + infos->push_back(info); + } + + return true; +} + +rtc::DiffServCodePoint P2PTransportChannel::DefaultDscpValue() const { + OptionMap::const_iterator it = options_.find(rtc::Socket::OPT_DSCP); + if (it == options_.end()) { + return rtc::DSCP_NO_CHANGE; + } + return static_cast<rtc::DiffServCodePoint> (it->second); +} + +// Monitor connection states. +void P2PTransportChannel::UpdateConnectionStates() { + uint32_t now = rtc::Time(); + + // We need to copy the list of connections since some may delete themselves + // when we call UpdateState. + for (size_t i = 0; i < connections_.size(); ++i) + connections_[i]->UpdateState(now); +} + +// Prepare for best candidate sorting. +void P2PTransportChannel::RequestSort() { + if (!sort_dirty_) { + worker_thread_->Post(this, MSG_SORT); + sort_dirty_ = true; + } +} + +// Sort the available connections to find the best one. We also monitor +// the number of available connections and the current state. +void P2PTransportChannel::SortConnections() { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + // Make sure the connection states are up-to-date since this affects how they + // will be sorted. + UpdateConnectionStates(); + + // Any changes after this point will require a re-sort. + sort_dirty_ = false; + + // Find the best alternative connection by sorting. It is important to note + // that amongst equal preference, writable connections, this will choose the + // one whose estimated latency is lowest. So it is the only one that we + // need to consider switching to. + ConnectionCompare cmp; + std::stable_sort(connections_.begin(), connections_.end(), cmp); + LOG(LS_VERBOSE) << "Sorting " << connections_.size() + << " available connections:"; + for (size_t i = 0; i < connections_.size(); ++i) { + LOG(LS_VERBOSE) << connections_[i]->ToString(); + } + + Connection* top_connection = + (connections_.size() > 0) ? connections_[0] : nullptr; + + // If necessary, switch to the new choice. + // Note that |top_connection| doesn't have to be writable to become the best + // connection although it will have higher priority if it is writable. + if (ShouldSwitch(best_connection_, top_connection, ice_role_)) { + LOG(LS_INFO) << "Switching best connection: " << top_connection->ToString(); + SwitchBestConnectionTo(top_connection); + } + + // Controlled side can prune only if the best connection has been nominated. + // because otherwise it may delete the connection that will be selected by + // the controlling side. + if (ice_role_ == ICEROLE_CONTROLLING || best_nominated_connection()) { + PruneConnections(); + } + + // Check if all connections are timedout. + bool all_connections_timedout = true; + for (size_t i = 0; i < connections_.size(); ++i) { + if (connections_[i]->write_state() != Connection::STATE_WRITE_TIMEOUT) { + all_connections_timedout = false; + break; + } + } + + // Now update the writable state of the channel with the information we have + // so far. + if (best_connection_ && best_connection_->writable()) { + HandleWritable(); + } else if (all_connections_timedout) { + HandleAllTimedOut(); + } else { + HandleNotWritable(); + } + + // Update the state of this channel. This method is called whenever the + // state of any connection changes, so this is a good place to do this. + UpdateChannelState(); +} + +Connection* P2PTransportChannel::best_nominated_connection() const { + return (best_connection_ && best_connection_->nominated()) ? best_connection_ + : nullptr; +} + +void P2PTransportChannel::PruneConnections() { + // We can prune any connection for which there is a connected, writable + // connection on the same network with better or equal priority. We leave + // those with better priority just in case they become writable later (at + // which point, we would prune out the current best connection). We leave + // connections on other networks because they may not be using the same + // resources and they may represent very distinct paths over which we can + // switch. If the |premier| connection is not connected, we may be + // reconnecting a TCP connection and temporarily do not prune connections in + // this network. See the big comment in CompareConnections. + + // Get a list of the networks that we are using. + std::set<rtc::Network*> networks; + for (const Connection* conn : connections_) { + networks.insert(conn->port()->Network()); + } + for (rtc::Network* network : networks) { + Connection* premier = GetBestConnectionOnNetwork(network); + // Do not prune connections if the current best connection is weak on this + // network. Otherwise, it may delete connections prematurely. + if (!premier || premier->weak()) { + continue; + } + + for (Connection* conn : connections_) { + if ((conn != premier) && (conn->port()->Network() == network) && + (CompareConnectionCandidates(premier, conn) >= 0)) { + conn->Prune(); + } + } + } +} + +// Track the best connection, and let listeners know +void P2PTransportChannel::SwitchBestConnectionTo(Connection* conn) { + // Note: if conn is NULL, the previous best_connection_ has been destroyed, + // so don't use it. + Connection* old_best_connection = best_connection_; + best_connection_ = conn; + if (best_connection_) { + if (old_best_connection) { + LOG_J(LS_INFO, this) << "Previous best connection: " + << old_best_connection->ToString(); + } + LOG_J(LS_INFO, this) << "New best connection: " + << best_connection_->ToString(); + SignalRouteChange(this, best_connection_->remote_candidate()); + } else { + LOG_J(LS_INFO, this) << "No best connection"; + } +} + +void P2PTransportChannel::UpdateChannelState() { + // The Handle* functions already set the writable state. We'll just double- + // check it here. + bool writable = best_connection_ && best_connection_->writable(); + ASSERT(writable == this->writable()); + if (writable != this->writable()) + LOG(LS_ERROR) << "UpdateChannelState: writable state mismatch"; + + bool receiving = false; + for (const Connection* connection : connections_) { + if (connection->receiving()) { + receiving = true; + break; + } + } + set_receiving(receiving); +} + +// We checked the status of our connections and we had at least one that +// was writable, go into the writable state. +void P2PTransportChannel::HandleWritable() { + ASSERT(worker_thread_ == rtc::Thread::Current()); + if (writable()) { + return; + } + + for (PortAllocatorSession* session : allocator_sessions_) { + if (!session->IsGettingPorts()) { + continue; + } + // If gathering continually, keep the last session running so that it + // will gather candidates if the networks change. + if (gather_continually_ && session == allocator_sessions_.back()) { + session->ClearGettingPorts(); + break; + } + session->StopGettingPorts(); + } + + was_writable_ = true; + set_writable(true); +} + +// Notify upper layer about channel not writable state, if it was before. +void P2PTransportChannel::HandleNotWritable() { + ASSERT(worker_thread_ == rtc::Thread::Current()); + if (was_writable_) { + was_writable_ = false; + set_writable(false); + } +} + +// If all connections timed out, delete them all. +void P2PTransportChannel::HandleAllTimedOut() { + for (Connection* connection : connections_) { + connection->Destroy(); + } +} + +bool P2PTransportChannel::weak() const { + return !best_connection_ || best_connection_->weak(); +} + +// If we have a best connection, return it, otherwise return top one in the +// list (later we will mark it best). +Connection* P2PTransportChannel::GetBestConnectionOnNetwork( + rtc::Network* network) const { + // If the best connection is on this network, then it wins. + if (best_connection_ && (best_connection_->port()->Network() == network)) + return best_connection_; + + // Otherwise, we return the top-most in sorted order. + for (size_t i = 0; i < connections_.size(); ++i) { + if (connections_[i]->port()->Network() == network) + return connections_[i]; + } + + return NULL; +} + +// Handle any queued up requests +void P2PTransportChannel::OnMessage(rtc::Message *pmsg) { + switch (pmsg->message_id) { + case MSG_SORT: + OnSort(); + break; + case MSG_CHECK_AND_PING: + OnCheckAndPing(); + break; + default: + ASSERT(false); + break; + } +} + +// Handle queued up sort request +void P2PTransportChannel::OnSort() { + // Resort the connections based on the new statistics. + SortConnections(); +} + +// Handle queued up check-and-ping request +void P2PTransportChannel::OnCheckAndPing() { + // Make sure the states of the connections are up-to-date (since this affects + // which ones are pingable). + UpdateConnectionStates(); + // When the best connection is either not receiving or not writable, + // switch to weak ping delay. + int ping_delay = weak() ? weak_ping_delay_ : STRONG_PING_DELAY; + if (rtc::Time() >= last_ping_sent_ms_ + ping_delay) { + Connection* conn = FindNextPingableConnection(); + if (conn) { + PingConnection(conn); + } + } + int check_delay = std::min(ping_delay, check_receiving_delay_); + thread()->PostDelayed(check_delay, this, MSG_CHECK_AND_PING); +} + +// Is the connection in a state for us to even consider pinging the other side? +// We consider a connection pingable even if it's not connected because that's +// how a TCP connection is kicked into reconnecting on the active side. +bool P2PTransportChannel::IsPingable(Connection* conn) { + const Candidate& remote = conn->remote_candidate(); + // We should never get this far with an empty remote ufrag. + ASSERT(!remote.username().empty()); + if (remote.username().empty() || remote.password().empty()) { + // If we don't have an ICE ufrag and pwd, there's no way we can ping. + return false; + } + + // An never connected connection cannot be written to at all, so pinging is + // out of the question. However, if it has become WRITABLE, it is in the + // reconnecting state so ping is needed. + if (!conn->connected() && !conn->writable()) { + return false; + } + + // If the channel is weak, ping all candidates. Otherwise, we only + // want to ping connections that have not timed out on writing. + return weak() || conn->write_state() != Connection::STATE_WRITE_TIMEOUT; +} + +// Returns the next pingable connection to ping. This will be the oldest +// pingable connection unless we have a connected, writable connection that is +// past the maximum acceptable ping delay. When reconnecting a TCP connection, +// the best connection is disconnected, although still WRITABLE while +// reconnecting. The newly created connection should be selected as the ping +// target to become writable instead. See the big comment in CompareConnections. +Connection* P2PTransportChannel::FindNextPingableConnection() { + uint32_t now = rtc::Time(); + if (best_connection_ && best_connection_->connected() && + best_connection_->writable() && + (best_connection_->last_ping_sent() + MAX_CURRENT_STRONG_DELAY <= now)) { + return best_connection_; + } + + // First, find "triggered checks". We ping first those connections + // that have received a ping but have not sent a ping since receiving + // it (last_received_ping > last_sent_ping). But we shouldn't do + // triggered checks if the connection is already writable. + Connection* oldest_needing_triggered_check = nullptr; + Connection* oldest = nullptr; + for (Connection* conn : connections_) { + if (!IsPingable(conn)) { + continue; + } + bool needs_triggered_check = + (!conn->writable() && + conn->last_ping_received() > conn->last_ping_sent()); + if (needs_triggered_check && + (!oldest_needing_triggered_check || + (conn->last_ping_received() < + oldest_needing_triggered_check->last_ping_received()))) { + oldest_needing_triggered_check = conn; + } + if (!oldest || (conn->last_ping_sent() < oldest->last_ping_sent())) { + oldest = conn; + } + } + + if (oldest_needing_triggered_check) { + LOG(LS_INFO) << "Selecting connection for triggered check: " << + oldest_needing_triggered_check->ToString(); + return oldest_needing_triggered_check; + } + return oldest; +} + +// Apart from sending ping from |conn| this method also updates +// |use_candidate_attr| flag. The criteria to update this flag is +// explained below. +// Set USE-CANDIDATE if doing ICE AND this channel is in CONTROLLING AND +// a) Channel is in FULL ICE AND +// a.1) |conn| is the best connection OR +// a.2) there is no best connection OR +// a.3) the best connection is unwritable OR +// a.4) |conn| has higher priority than best_connection. +// b) we're doing LITE ICE AND +// b.1) |conn| is the best_connection AND +// b.2) |conn| is writable. +void P2PTransportChannel::PingConnection(Connection* conn) { + bool use_candidate = false; + if (remote_ice_mode_ == ICEMODE_FULL && ice_role_ == ICEROLE_CONTROLLING) { + use_candidate = (conn == best_connection_) || (best_connection_ == NULL) || + (!best_connection_->writable()) || + (conn->priority() > best_connection_->priority()); + } else if (remote_ice_mode_ == ICEMODE_LITE && conn == best_connection_) { + use_candidate = best_connection_->writable(); + } + conn->set_use_candidate_attr(use_candidate); + last_ping_sent_ms_ = rtc::Time(); + conn->Ping(last_ping_sent_ms_); +} + +// When a connection's state changes, we need to figure out who to use as +// the best connection again. It could have become usable, or become unusable. +void P2PTransportChannel::OnConnectionStateChange(Connection* connection) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + // Update the best connection if the state change is from pending best + // connection and role is controlled. + if (ice_role_ == ICEROLE_CONTROLLED) { + if (connection == pending_best_connection_ && connection->writable()) { + pending_best_connection_ = NULL; + LOG(LS_INFO) << "Switching best connection on controlled side" + << " because it's now writable: " << connection->ToString(); + SwitchBestConnectionTo(connection); + } + } + + // We have to unroll the stack before doing this because we may be changing + // the state of connections while sorting. + RequestSort(); +} + +// When a connection is removed, edit it out, and then update our best +// connection. +void P2PTransportChannel::OnConnectionDestroyed(Connection* connection) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + // Note: the previous best_connection_ may be destroyed by now, so don't + // use it. + + // Remove this connection from the list. + std::vector<Connection*>::iterator iter = + std::find(connections_.begin(), connections_.end(), connection); + ASSERT(iter != connections_.end()); + connections_.erase(iter); + + LOG_J(LS_INFO, this) << "Removed connection (" + << static_cast<int>(connections_.size()) << " remaining)"; + + if (pending_best_connection_ == connection) { + pending_best_connection_ = NULL; + } + + // If this is currently the best connection, then we need to pick a new one. + // The call to SortConnections will pick a new one. It looks at the current + // best connection in order to avoid switching between fairly similar ones. + // Since this connection is no longer an option, we can just set best to NULL + // and re-choose a best assuming that there was no best connection. + if (best_connection_ == connection) { + LOG(LS_INFO) << "Best connection destroyed. Will choose a new one."; + SwitchBestConnectionTo(NULL); + RequestSort(); + } + + SignalConnectionRemoved(this); +} + +// When a port is destroyed remove it from our list of ports to use for +// connection attempts. +void P2PTransportChannel::OnPortDestroyed(PortInterface* port) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + // Remove this port from the list (if we didn't drop it already). + std::vector<PortInterface*>::iterator iter = + std::find(ports_.begin(), ports_.end(), port); + if (iter != ports_.end()) + ports_.erase(iter); + + LOG(INFO) << "Removed port from p2p socket: " + << static_cast<int>(ports_.size()) << " remaining"; +} + +// We data is available, let listeners know +void P2PTransportChannel::OnReadPacket(Connection* connection, + const char* data, + size_t len, + const rtc::PacketTime& packet_time) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + // Do not deliver, if packet doesn't belong to the correct transport channel. + if (!FindConnection(connection)) + return; + + // Let the client know of an incoming packet + SignalReadPacket(this, data, len, packet_time, 0); + + // May need to switch the sending connection based on the receiving media path + // if this is the controlled side. + if (ice_role_ == ICEROLE_CONTROLLED && !best_nominated_connection() && + connection->writable() && best_connection_ != connection) { + SwitchBestConnectionTo(connection); + } +} + +void P2PTransportChannel::OnSentPacket(PortInterface* port, + const rtc::SentPacket& sent_packet) { + ASSERT(worker_thread_ == rtc::Thread::Current()); + + SignalSentPacket(this, sent_packet); +} + +void P2PTransportChannel::OnReadyToSend(Connection* connection) { + if (connection == best_connection_ && writable()) { + SignalReadyToSend(this); + } +} + +} // namespace cricket diff --git a/webrtc/p2p/base/p2ptransportchannel.h b/webrtc/p2p/base/p2ptransportchannel.h new file mode 100644 index 0000000000..9efb96c42d --- /dev/null +++ b/webrtc/p2p/base/p2ptransportchannel.h @@ -0,0 +1,267 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// P2PTransportChannel wraps up the state management of the connection between +// two P2P clients. Clients have candidate ports for connecting, and +// connections which are combinations of candidates from each end (Alice and +// Bob each have candidates, one candidate from Alice and one candidate from +// Bob are used to make a connection, repeat to make many connections). +// +// When all of the available connections become invalid (non-writable), we +// kick off a process of determining more candidates and more connections. +// +#ifndef WEBRTC_P2P_BASE_P2PTRANSPORTCHANNEL_H_ +#define WEBRTC_P2P_BASE_P2PTRANSPORTCHANNEL_H_ + +#include <map> +#include <string> +#include <vector> +#include "webrtc/p2p/base/candidate.h" +#include "webrtc/p2p/base/p2ptransport.h" +#include "webrtc/p2p/base/portallocator.h" +#include "webrtc/p2p/base/portinterface.h" +#include "webrtc/p2p/base/transport.h" +#include "webrtc/p2p/base/transportchannelimpl.h" +#include "webrtc/base/asyncpacketsocket.h" +#include "webrtc/base/sigslot.h" + +namespace cricket { + +extern const uint32_t WEAK_PING_DELAY; + +// Adds the port on which the candidate originated. +class RemoteCandidate : public Candidate { + public: + RemoteCandidate(const Candidate& c, PortInterface* origin_port) + : Candidate(c), origin_port_(origin_port) {} + + PortInterface* origin_port() { return origin_port_; } + + private: + PortInterface* origin_port_; +}; + +// P2PTransportChannel manages the candidates and connection process to keep +// two P2P clients connected to each other. +class P2PTransportChannel : public TransportChannelImpl, + public rtc::MessageHandler { + public: + P2PTransportChannel(const std::string& transport_name, + int component, + P2PTransport* transport, + PortAllocator* allocator); + virtual ~P2PTransportChannel(); + + // From TransportChannelImpl: + Transport* GetTransport() override { return transport_; } + TransportChannelState GetState() const override; + void SetIceRole(IceRole role) override; + IceRole GetIceRole() const override { return ice_role_; } + void SetIceTiebreaker(uint64_t tiebreaker) override; + void SetIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) override; + void SetRemoteIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) override; + void SetRemoteIceMode(IceMode mode) override; + void Connect() override; + void MaybeStartGathering() override; + IceGatheringState gathering_state() const override { + return gathering_state_; + } + void AddRemoteCandidate(const Candidate& candidate) override; + // Sets the receiving timeout and gather_continually. + // This also sets the check_receiving_delay proportionally. + void SetIceConfig(const IceConfig& config) override; + + // From TransportChannel: + int SendPacket(const char* data, + size_t len, + const rtc::PacketOptions& options, + int flags) override; + int SetOption(rtc::Socket::Option opt, int value) override; + bool GetOption(rtc::Socket::Option opt, int* value) override; + int GetError() override { return error_; } + bool GetStats(std::vector<ConnectionInfo>* stats) override; + + const Connection* best_connection() const { return best_connection_; } + void set_incoming_only(bool value) { incoming_only_ = value; } + + // Note: This is only for testing purpose. + // |ports_| should not be changed from outside. + const std::vector<PortInterface*>& ports() { return ports_; } + + IceMode remote_ice_mode() const { return remote_ice_mode_; } + + // DTLS methods. + bool IsDtlsActive() const override { return false; } + + // Default implementation. + bool GetSslRole(rtc::SSLRole* role) const override { return false; } + + bool SetSslRole(rtc::SSLRole role) override { return false; } + + // Set up the ciphers to use for DTLS-SRTP. + bool SetSrtpCiphers(const std::vector<std::string>& ciphers) override { + return false; + } + + // Find out which DTLS-SRTP cipher was negotiated. + bool GetSrtpCryptoSuite(std::string* cipher) override { return false; } + + // Find out which DTLS cipher was negotiated. + bool GetSslCipherSuite(int* cipher) override { return false; } + + // Returns null because the channel is not encrypted by default. + rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const override { + return nullptr; + } + + bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override { + return false; + } + + // Allows key material to be extracted for external encryption. + bool ExportKeyingMaterial(const std::string& label, + const uint8_t* context, + size_t context_len, + bool use_context, + uint8_t* result, + size_t result_len) override { + return false; + } + + bool SetLocalCertificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override { + return false; + } + + // Set DTLS Remote fingerprint. Must be after local identity set. + bool SetRemoteFingerprint(const std::string& digest_alg, + const uint8_t* digest, + size_t digest_len) override { + return false; + } + + int receiving_timeout() const { return receiving_timeout_; } + int check_receiving_delay() const { return check_receiving_delay_; } + + // Helper method used only in unittest. + rtc::DiffServCodePoint DefaultDscpValue() const; + + // Public for unit tests. + Connection* FindNextPingableConnection(); + + // Public for unit tests. + const std::vector<Connection*>& connections() const { return connections_; } + + private: + rtc::Thread* thread() { return worker_thread_; } + PortAllocatorSession* allocator_session() { + return allocator_sessions_.back(); + } + + // A transport channel is weak if the current best connection is either + // not receiving or not writable, or if there is no best connection at all. + bool weak() const; + void UpdateConnectionStates(); + void RequestSort(); + void SortConnections(); + void SwitchBestConnectionTo(Connection* conn); + void UpdateChannelState(); + void HandleWritable(); + void HandleNotWritable(); + void HandleAllTimedOut(); + + Connection* GetBestConnectionOnNetwork(rtc::Network* network) const; + bool CreateConnections(const Candidate& remote_candidate, + PortInterface* origin_port); + bool CreateConnection(PortInterface* port, + const Candidate& remote_candidate, + PortInterface* origin_port); + bool FindConnection(cricket::Connection* connection) const; + + uint32_t GetRemoteCandidateGeneration(const Candidate& candidate); + bool IsDuplicateRemoteCandidate(const Candidate& candidate); + void RememberRemoteCandidate(const Candidate& remote_candidate, + PortInterface* origin_port); + bool IsPingable(Connection* conn); + void PingConnection(Connection* conn); + void AddAllocatorSession(PortAllocatorSession* session); + void AddConnection(Connection* connection); + + void OnPortReady(PortAllocatorSession *session, PortInterface* port); + void OnCandidatesReady(PortAllocatorSession *session, + const std::vector<Candidate>& candidates); + void OnCandidatesAllocationDone(PortAllocatorSession* session); + void OnUnknownAddress(PortInterface* port, + const rtc::SocketAddress& addr, + ProtocolType proto, + IceMessage* stun_msg, + const std::string& remote_username, + bool port_muxed); + void OnPortDestroyed(PortInterface* port); + void OnRoleConflict(PortInterface* port); + + void OnConnectionStateChange(Connection* connection); + void OnReadPacket(Connection *connection, const char *data, size_t len, + const rtc::PacketTime& packet_time); + void OnSentPacket(PortInterface* port, const rtc::SentPacket& sent_packet); + void OnReadyToSend(Connection* connection); + void OnConnectionDestroyed(Connection *connection); + + void OnNominated(Connection* conn); + + void OnMessage(rtc::Message* pmsg) override; + void OnSort(); + void OnCheckAndPing(); + + void PruneConnections(); + Connection* best_nominated_connection() const; + + P2PTransport* transport_; + PortAllocator* allocator_; + rtc::Thread* worker_thread_; + bool incoming_only_; + int error_; + std::vector<PortAllocatorSession*> allocator_sessions_; + std::vector<PortInterface *> ports_; + std::vector<Connection *> connections_; + Connection* best_connection_; + // Connection selected by the controlling agent. This should be used only + // at controlled side when protocol type is RFC5245. + Connection* pending_best_connection_; + std::vector<RemoteCandidate> remote_candidates_; + bool sort_dirty_; // indicates whether another sort is needed right now + bool was_writable_; + bool had_connection_ = false; // if connections_ has ever been nonempty + typedef std::map<rtc::Socket::Option, int> OptionMap; + OptionMap options_; + std::string ice_ufrag_; + std::string ice_pwd_; + std::string remote_ice_ufrag_; + std::string remote_ice_pwd_; + IceMode remote_ice_mode_; + IceRole ice_role_; + uint64_t tiebreaker_; + uint32_t remote_candidate_generation_; + IceGatheringState gathering_state_; + + int check_receiving_delay_; + int receiving_timeout_; + uint32_t last_ping_sent_ms_ = 0; + bool gather_continually_ = false; + int weak_ping_delay_ = WEAK_PING_DELAY; + + RTC_DISALLOW_COPY_AND_ASSIGN(P2PTransportChannel); +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_P2PTRANSPORTCHANNEL_H_ diff --git a/webrtc/p2p/base/p2ptransportchannel_unittest.cc b/webrtc/p2p/base/p2ptransportchannel_unittest.cc new file mode 100644 index 0000000000..37cda7c661 --- /dev/null +++ b/webrtc/p2p/base/p2ptransportchannel_unittest.cc @@ -0,0 +1,2192 @@ +/* + * Copyright 2009 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/p2ptransportchannel.h" +#include "webrtc/p2p/base/testrelayserver.h" +#include "webrtc/p2p/base/teststunserver.h" +#include "webrtc/p2p/base/testturnserver.h" +#include "webrtc/p2p/client/basicportallocator.h" +#include "webrtc/p2p/client/fakeportallocator.h" +#include "webrtc/base/dscp.h" +#include "webrtc/base/fakenetwork.h" +#include "webrtc/base/firewallsocketserver.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/natserver.h" +#include "webrtc/base/natsocketfactory.h" +#include "webrtc/base/physicalsocketserver.h" +#include "webrtc/base/proxyserver.h" +#include "webrtc/base/socketaddress.h" +#include "webrtc/base/ssladapter.h" +#include "webrtc/base/thread.h" +#include "webrtc/base/virtualsocketserver.h" + +using cricket::kDefaultPortAllocatorFlags; +using cricket::kMinimumStepDelay; +using cricket::kDefaultStepDelay; +using cricket::PORTALLOCATOR_ENABLE_SHARED_SOCKET; +using cricket::ServerAddresses; +using rtc::SocketAddress; + +static const int kDefaultTimeout = 1000; +static const int kOnlyLocalPorts = cricket::PORTALLOCATOR_DISABLE_STUN | + cricket::PORTALLOCATOR_DISABLE_RELAY | + cricket::PORTALLOCATOR_DISABLE_TCP; +// Addresses on the public internet. +static const SocketAddress kPublicAddrs[2] = + { SocketAddress("11.11.11.11", 0), SocketAddress("22.22.22.22", 0) }; +// IPv6 Addresses on the public internet. +static const SocketAddress kIPv6PublicAddrs[2] = { + SocketAddress("2400:4030:1:2c00:be30:abcd:efab:cdef", 0), + SocketAddress("2620:0:1000:1b03:2e41:38ff:fea6:f2a4", 0) +}; +// For configuring multihomed clients. +static const SocketAddress kAlternateAddrs[2] = + { SocketAddress("11.11.11.101", 0), SocketAddress("22.22.22.202", 0) }; +// Addresses for HTTP proxy servers. +static const SocketAddress kHttpsProxyAddrs[2] = + { SocketAddress("11.11.11.1", 443), SocketAddress("22.22.22.1", 443) }; +// Addresses for SOCKS proxy servers. +static const SocketAddress kSocksProxyAddrs[2] = + { SocketAddress("11.11.11.1", 1080), SocketAddress("22.22.22.1", 1080) }; +// Internal addresses for NAT boxes. +static const SocketAddress kNatAddrs[2] = + { SocketAddress("192.168.1.1", 0), SocketAddress("192.168.2.1", 0) }; +// Private addresses inside the NAT private networks. +static const SocketAddress kPrivateAddrs[2] = + { SocketAddress("192.168.1.11", 0), SocketAddress("192.168.2.22", 0) }; +// For cascaded NATs, the internal addresses of the inner NAT boxes. +static const SocketAddress kCascadedNatAddrs[2] = + { SocketAddress("192.168.10.1", 0), SocketAddress("192.168.20.1", 0) }; +// For cascaded NATs, private addresses inside the inner private networks. +static const SocketAddress kCascadedPrivateAddrs[2] = + { SocketAddress("192.168.10.11", 0), SocketAddress("192.168.20.22", 0) }; +// The address of the public STUN server. +static const SocketAddress kStunAddr("99.99.99.1", cricket::STUN_SERVER_PORT); +// The addresses for the public relay server. +static const SocketAddress kRelayUdpIntAddr("99.99.99.2", 5000); +static const SocketAddress kRelayUdpExtAddr("99.99.99.3", 5001); +static const SocketAddress kRelayTcpIntAddr("99.99.99.2", 5002); +static const SocketAddress kRelayTcpExtAddr("99.99.99.3", 5003); +static const SocketAddress kRelaySslTcpIntAddr("99.99.99.2", 5004); +static const SocketAddress kRelaySslTcpExtAddr("99.99.99.3", 5005); +// The addresses for the public turn server. +static const SocketAddress kTurnUdpIntAddr("99.99.99.4", + cricket::STUN_SERVER_PORT); +static const SocketAddress kTurnUdpExtAddr("99.99.99.5", 0); +static const cricket::RelayCredentials kRelayCredentials("test", "test"); + +// Based on ICE_UFRAG_LENGTH +static const char* kIceUfrag[4] = {"TESTICEUFRAG0000", "TESTICEUFRAG0001", + "TESTICEUFRAG0002", "TESTICEUFRAG0003"}; +// Based on ICE_PWD_LENGTH +static const char* kIcePwd[4] = {"TESTICEPWD00000000000000", + "TESTICEPWD00000000000001", + "TESTICEPWD00000000000002", + "TESTICEPWD00000000000003"}; + +static const uint64_t kTiebreaker1 = 11111; +static const uint64_t kTiebreaker2 = 22222; + +enum { + MSG_CANDIDATE +}; + +static cricket::IceConfig CreateIceConfig(int receiving_timeout_ms, + bool gather_continually) { + cricket::IceConfig config; + config.receiving_timeout_ms = receiving_timeout_ms; + config.gather_continually = gather_continually; + return config; +} + +// This test simulates 2 P2P endpoints that want to establish connectivity +// with each other over various network topologies and conditions, which can be +// specified in each individial test. +// A virtual network (via VirtualSocketServer) along with virtual firewalls and +// NATs (via Firewall/NATSocketServer) are used to simulate the various network +// conditions. We can configure the IP addresses of the endpoints, +// block various types of connectivity, or add arbitrary levels of NAT. +// We also run a STUN server and a relay server on the virtual network to allow +// our typical P2P mechanisms to do their thing. +// For each case, we expect the P2P stack to eventually settle on a specific +// form of connectivity to the other side. The test checks that the P2P +// negotiation successfully establishes connectivity within a certain time, +// and that the result is what we expect. +// Note that this class is a base class for use by other tests, who will provide +// specialized test behavior. +class P2PTransportChannelTestBase : public testing::Test, + public rtc::MessageHandler, + public sigslot::has_slots<> { + public: + P2PTransportChannelTestBase() + : main_(rtc::Thread::Current()), + pss_(new rtc::PhysicalSocketServer), + vss_(new rtc::VirtualSocketServer(pss_.get())), + nss_(new rtc::NATSocketServer(vss_.get())), + ss_(new rtc::FirewallSocketServer(nss_.get())), + ss_scope_(ss_.get()), + stun_server_(cricket::TestStunServer::Create(main_, kStunAddr)), + turn_server_(main_, kTurnUdpIntAddr, kTurnUdpExtAddr), + relay_server_(main_, kRelayUdpIntAddr, kRelayUdpExtAddr, + kRelayTcpIntAddr, kRelayTcpExtAddr, + kRelaySslTcpIntAddr, kRelaySslTcpExtAddr), + socks_server1_(ss_.get(), kSocksProxyAddrs[0], + ss_.get(), kSocksProxyAddrs[0]), + socks_server2_(ss_.get(), kSocksProxyAddrs[1], + ss_.get(), kSocksProxyAddrs[1]), + clear_remote_candidates_ufrag_pwd_(false), + force_relay_(false) { + ep1_.role_ = cricket::ICEROLE_CONTROLLING; + ep2_.role_ = cricket::ICEROLE_CONTROLLED; + + ServerAddresses stun_servers; + stun_servers.insert(kStunAddr); + ep1_.allocator_.reset(new cricket::BasicPortAllocator( + &ep1_.network_manager_, + stun_servers, kRelayUdpIntAddr, kRelayTcpIntAddr, kRelaySslTcpIntAddr)); + ep2_.allocator_.reset(new cricket::BasicPortAllocator( + &ep2_.network_manager_, + stun_servers, kRelayUdpIntAddr, kRelayTcpIntAddr, kRelaySslTcpIntAddr)); + } + + protected: + enum Config { + OPEN, // Open to the Internet + NAT_FULL_CONE, // NAT, no filtering + NAT_ADDR_RESTRICTED, // NAT, must send to an addr to recv + NAT_PORT_RESTRICTED, // NAT, must send to an addr+port to recv + NAT_SYMMETRIC, // NAT, endpoint-dependent bindings + NAT_DOUBLE_CONE, // Double NAT, both cone + NAT_SYMMETRIC_THEN_CONE, // Double NAT, symmetric outer, cone inner + BLOCK_UDP, // Firewall, UDP in/out blocked + BLOCK_UDP_AND_INCOMING_TCP, // Firewall, UDP in/out and TCP in blocked + BLOCK_ALL_BUT_OUTGOING_HTTP, // Firewall, only TCP out on 80/443 + PROXY_HTTPS, // All traffic through HTTPS proxy + PROXY_SOCKS, // All traffic through SOCKS proxy + NUM_CONFIGS + }; + + struct Result { + Result(const std::string& lt, const std::string& lp, + const std::string& rt, const std::string& rp, + const std::string& lt2, const std::string& lp2, + const std::string& rt2, const std::string& rp2, int wait) + : local_type(lt), local_proto(lp), remote_type(rt), remote_proto(rp), + local_type2(lt2), local_proto2(lp2), remote_type2(rt2), + remote_proto2(rp2), connect_wait(wait) { + } + + std::string local_type; + std::string local_proto; + std::string remote_type; + std::string remote_proto; + std::string local_type2; + std::string local_proto2; + std::string remote_type2; + std::string remote_proto2; + int connect_wait; + }; + + struct ChannelData { + bool CheckData(const char* data, int len) { + bool ret = false; + if (!ch_packets_.empty()) { + std::string packet = ch_packets_.front(); + ret = (packet == std::string(data, len)); + ch_packets_.pop_front(); + } + return ret; + } + + std::string name_; // TODO - Currently not used. + std::list<std::string> ch_packets_; + rtc::scoped_ptr<cricket::P2PTransportChannel> ch_; + }; + + struct CandidateData : public rtc::MessageData { + CandidateData(cricket::TransportChannel* ch, const cricket::Candidate& c) + : channel(ch), candidate(c) { + } + cricket::TransportChannel* channel; + cricket::Candidate candidate; + }; + + struct Endpoint { + Endpoint() + : role_(cricket::ICEROLE_UNKNOWN), + tiebreaker_(0), + role_conflict_(false), + save_candidates_(false) {} + bool HasChannel(cricket::TransportChannel* ch) { + return (ch == cd1_.ch_.get() || ch == cd2_.ch_.get()); + } + ChannelData* GetChannelData(cricket::TransportChannel* ch) { + if (!HasChannel(ch)) return NULL; + if (cd1_.ch_.get() == ch) + return &cd1_; + else + return &cd2_; + } + + void SetIceRole(cricket::IceRole role) { role_ = role; } + cricket::IceRole ice_role() { return role_; } + void SetIceTiebreaker(uint64_t tiebreaker) { tiebreaker_ = tiebreaker; } + uint64_t GetIceTiebreaker() { return tiebreaker_; } + void OnRoleConflict(bool role_conflict) { role_conflict_ = role_conflict; } + bool role_conflict() { return role_conflict_; } + void SetAllocationStepDelay(uint32_t delay) { + allocator_->set_step_delay(delay); + } + void SetAllowTcpListen(bool allow_tcp_listen) { + allocator_->set_allow_tcp_listen(allow_tcp_listen); + } + + rtc::FakeNetworkManager network_manager_; + rtc::scoped_ptr<cricket::BasicPortAllocator> allocator_; + ChannelData cd1_; + ChannelData cd2_; + cricket::IceRole role_; + uint64_t tiebreaker_; + bool role_conflict_; + bool save_candidates_; + std::vector<CandidateData*> saved_candidates_; + }; + + ChannelData* GetChannelData(cricket::TransportChannel* channel) { + if (ep1_.HasChannel(channel)) + return ep1_.GetChannelData(channel); + else + return ep2_.GetChannelData(channel); + } + + void CreateChannels(int num) { + std::string ice_ufrag_ep1_cd1_ch = kIceUfrag[0]; + std::string ice_pwd_ep1_cd1_ch = kIcePwd[0]; + std::string ice_ufrag_ep2_cd1_ch = kIceUfrag[1]; + std::string ice_pwd_ep2_cd1_ch = kIcePwd[1]; + ep1_.cd1_.ch_.reset(CreateChannel( + 0, cricket::ICE_CANDIDATE_COMPONENT_DEFAULT, + ice_ufrag_ep1_cd1_ch, ice_pwd_ep1_cd1_ch, + ice_ufrag_ep2_cd1_ch, ice_pwd_ep2_cd1_ch)); + ep2_.cd1_.ch_.reset(CreateChannel( + 1, cricket::ICE_CANDIDATE_COMPONENT_DEFAULT, + ice_ufrag_ep2_cd1_ch, ice_pwd_ep2_cd1_ch, + ice_ufrag_ep1_cd1_ch, ice_pwd_ep1_cd1_ch)); + if (num == 2) { + std::string ice_ufrag_ep1_cd2_ch = kIceUfrag[2]; + std::string ice_pwd_ep1_cd2_ch = kIcePwd[2]; + std::string ice_ufrag_ep2_cd2_ch = kIceUfrag[3]; + std::string ice_pwd_ep2_cd2_ch = kIcePwd[3]; + ep1_.cd2_.ch_.reset(CreateChannel( + 0, cricket::ICE_CANDIDATE_COMPONENT_DEFAULT, + ice_ufrag_ep1_cd2_ch, ice_pwd_ep1_cd2_ch, + ice_ufrag_ep2_cd2_ch, ice_pwd_ep2_cd2_ch)); + ep2_.cd2_.ch_.reset(CreateChannel( + 1, cricket::ICE_CANDIDATE_COMPONENT_DEFAULT, + ice_ufrag_ep2_cd2_ch, ice_pwd_ep2_cd2_ch, + ice_ufrag_ep1_cd2_ch, ice_pwd_ep1_cd2_ch)); + } + } + cricket::P2PTransportChannel* CreateChannel( + int endpoint, + int component, + const std::string& local_ice_ufrag, + const std::string& local_ice_pwd, + const std::string& remote_ice_ufrag, + const std::string& remote_ice_pwd) { + cricket::P2PTransportChannel* channel = new cricket::P2PTransportChannel( + "test content name", component, NULL, GetAllocator(endpoint)); + channel->SignalCandidateGathered.connect( + this, &P2PTransportChannelTestBase::OnCandidate); + channel->SignalReadPacket.connect( + this, &P2PTransportChannelTestBase::OnReadPacket); + channel->SignalRoleConflict.connect( + this, &P2PTransportChannelTestBase::OnRoleConflict); + channel->SetIceCredentials(local_ice_ufrag, local_ice_pwd); + if (clear_remote_candidates_ufrag_pwd_) { + // This only needs to be set if we're clearing them from the + // candidates. Some unit tests rely on this not being set. + channel->SetRemoteIceCredentials(remote_ice_ufrag, remote_ice_pwd); + } + channel->SetIceRole(GetEndpoint(endpoint)->ice_role()); + channel->SetIceTiebreaker(GetEndpoint(endpoint)->GetIceTiebreaker()); + channel->Connect(); + channel->MaybeStartGathering(); + return channel; + } + void DestroyChannels() { + ep1_.cd1_.ch_.reset(); + ep2_.cd1_.ch_.reset(); + ep1_.cd2_.ch_.reset(); + ep2_.cd2_.ch_.reset(); + } + cricket::P2PTransportChannel* ep1_ch1() { return ep1_.cd1_.ch_.get(); } + cricket::P2PTransportChannel* ep1_ch2() { return ep1_.cd2_.ch_.get(); } + cricket::P2PTransportChannel* ep2_ch1() { return ep2_.cd1_.ch_.get(); } + cricket::P2PTransportChannel* ep2_ch2() { return ep2_.cd2_.ch_.get(); } + + // Common results. + static const Result kLocalUdpToLocalUdp; + static const Result kLocalUdpToStunUdp; + static const Result kLocalUdpToPrflxUdp; + static const Result kPrflxUdpToLocalUdp; + static const Result kStunUdpToLocalUdp; + static const Result kStunUdpToStunUdp; + static const Result kPrflxUdpToStunUdp; + static const Result kLocalUdpToRelayUdp; + static const Result kPrflxUdpToRelayUdp; + static const Result kLocalTcpToLocalTcp; + static const Result kLocalTcpToPrflxTcp; + static const Result kPrflxTcpToLocalTcp; + + rtc::NATSocketServer* nat() { return nss_.get(); } + rtc::FirewallSocketServer* fw() { return ss_.get(); } + + Endpoint* GetEndpoint(int endpoint) { + if (endpoint == 0) { + return &ep1_; + } else if (endpoint == 1) { + return &ep2_; + } else { + return NULL; + } + } + cricket::PortAllocator* GetAllocator(int endpoint) { + return GetEndpoint(endpoint)->allocator_.get(); + } + void AddAddress(int endpoint, const SocketAddress& addr) { + GetEndpoint(endpoint)->network_manager_.AddInterface(addr); + } + void RemoveAddress(int endpoint, const SocketAddress& addr) { + GetEndpoint(endpoint)->network_manager_.RemoveInterface(addr); + } + void SetProxy(int endpoint, rtc::ProxyType type) { + rtc::ProxyInfo info; + info.type = type; + info.address = (type == rtc::PROXY_HTTPS) ? + kHttpsProxyAddrs[endpoint] : kSocksProxyAddrs[endpoint]; + GetAllocator(endpoint)->set_proxy("unittest/1.0", info); + } + void SetAllocatorFlags(int endpoint, int flags) { + GetAllocator(endpoint)->set_flags(flags); + } + void SetIceRole(int endpoint, cricket::IceRole role) { + GetEndpoint(endpoint)->SetIceRole(role); + } + void SetIceTiebreaker(int endpoint, uint64_t tiebreaker) { + GetEndpoint(endpoint)->SetIceTiebreaker(tiebreaker); + } + bool GetRoleConflict(int endpoint) { + return GetEndpoint(endpoint)->role_conflict(); + } + void SetAllocationStepDelay(int endpoint, uint32_t delay) { + return GetEndpoint(endpoint)->SetAllocationStepDelay(delay); + } + void SetAllowTcpListen(int endpoint, bool allow_tcp_listen) { + return GetEndpoint(endpoint)->SetAllowTcpListen(allow_tcp_listen); + } + bool IsLocalToPrflxOrTheReverse(const Result& expected) { + return ( + (expected.local_type == "local" && expected.remote_type == "prflx") || + (expected.local_type == "prflx" && expected.remote_type == "local")); + } + + // Return true if the approprite parts of the expected Result, based + // on the local and remote candidate of ep1_ch1, match. This can be + // used in an EXPECT_TRUE_WAIT. + bool CheckCandidate1(const Result& expected) { + const std::string& local_type = LocalCandidate(ep1_ch1())->type(); + const std::string& local_proto = LocalCandidate(ep1_ch1())->protocol(); + const std::string& remote_type = RemoteCandidate(ep1_ch1())->type(); + const std::string& remote_proto = RemoteCandidate(ep1_ch1())->protocol(); + return ((local_proto == expected.local_proto && + remote_proto == expected.remote_proto) && + ((local_type == expected.local_type && + remote_type == expected.remote_type) || + // Sometimes we expect local -> prflx or prflx -> local + // and instead get prflx -> local or local -> prflx, and + // that's OK. + (IsLocalToPrflxOrTheReverse(expected) && + local_type == expected.remote_type && + remote_type == expected.local_type))); + } + + // EXPECT_EQ on the approprite parts of the expected Result, based + // on the local and remote candidate of ep1_ch1. This is like + // CheckCandidate1, except that it will provide more detail about + // what didn't match. + void ExpectCandidate1(const Result& expected) { + if (CheckCandidate1(expected)) { + return; + } + + const std::string& local_type = LocalCandidate(ep1_ch1())->type(); + const std::string& local_proto = LocalCandidate(ep1_ch1())->protocol(); + const std::string& remote_type = RemoteCandidate(ep1_ch1())->type(); + const std::string& remote_proto = RemoteCandidate(ep1_ch1())->protocol(); + EXPECT_EQ(expected.local_type, local_type); + EXPECT_EQ(expected.remote_type, remote_type); + EXPECT_EQ(expected.local_proto, local_proto); + EXPECT_EQ(expected.remote_proto, remote_proto); + } + + // Return true if the approprite parts of the expected Result, based + // on the local and remote candidate of ep2_ch1, match. This can be + // used in an EXPECT_TRUE_WAIT. + bool CheckCandidate2(const Result& expected) { + const std::string& local_type = LocalCandidate(ep2_ch1())->type(); + // const std::string& remote_type = RemoteCandidate(ep2_ch1())->type(); + const std::string& local_proto = LocalCandidate(ep2_ch1())->protocol(); + const std::string& remote_proto = RemoteCandidate(ep2_ch1())->protocol(); + // Removed remote_type comparision aginst best connection remote + // candidate. This is done to handle remote type discrepancy from + // local to stun based on the test type. + // For example in case of Open -> NAT, ep2 channels will have LULU + // and in other cases like NAT -> NAT it will be LUSU. To avoid these + // mismatches and we are doing comparision in different way. + // i.e. when don't match its remote type is either local or stun. + // TODO(ronghuawu): Refine the test criteria. + // https://code.google.com/p/webrtc/issues/detail?id=1953 + return ((local_proto == expected.local_proto2 && + remote_proto == expected.remote_proto2) && + (local_type == expected.local_type2 || + // Sometimes we expect local -> prflx or prflx -> local + // and instead get prflx -> local or local -> prflx, and + // that's OK. + (IsLocalToPrflxOrTheReverse(expected) && + local_type == expected.remote_type2))); + } + + // EXPECT_EQ on the approprite parts of the expected Result, based + // on the local and remote candidate of ep2_ch1. This is like + // CheckCandidate2, except that it will provide more detail about + // what didn't match. + void ExpectCandidate2(const Result& expected) { + if (CheckCandidate2(expected)) { + return; + } + + const std::string& local_type = LocalCandidate(ep2_ch1())->type(); + const std::string& local_proto = LocalCandidate(ep2_ch1())->protocol(); + const std::string& remote_type = RemoteCandidate(ep2_ch1())->type(); + EXPECT_EQ(expected.local_proto2, local_proto); + EXPECT_EQ(expected.remote_proto2, remote_type); + EXPECT_EQ(expected.local_type2, local_type); + if (remote_type != expected.remote_type2) { + EXPECT_TRUE(expected.remote_type2 == cricket::LOCAL_PORT_TYPE || + expected.remote_type2 == cricket::STUN_PORT_TYPE); + EXPECT_TRUE(remote_type == cricket::LOCAL_PORT_TYPE || + remote_type == cricket::STUN_PORT_TYPE || + remote_type == cricket::PRFLX_PORT_TYPE); + } + } + + void Test(const Result& expected) { + int32_t connect_start = rtc::Time(), connect_time; + + // Create the channels and wait for them to connect. + CreateChannels(1); + EXPECT_TRUE_WAIT_MARGIN(ep1_ch1() != NULL && + ep2_ch1() != NULL && + ep1_ch1()->receiving() && + ep1_ch1()->writable() && + ep2_ch1()->receiving() && + ep2_ch1()->writable(), + expected.connect_wait, + 1000); + connect_time = rtc::TimeSince(connect_start); + if (connect_time < expected.connect_wait) { + LOG(LS_INFO) << "Connect time: " << connect_time << " ms"; + } else { + LOG(LS_INFO) << "Connect time: " << "TIMEOUT (" + << expected.connect_wait << " ms)"; + } + + // Allow a few turns of the crank for the best connections to emerge. + // This may take up to 2 seconds. + if (ep1_ch1()->best_connection() && + ep2_ch1()->best_connection()) { + int32_t converge_start = rtc::Time(), converge_time; + int converge_wait = 2000; + EXPECT_TRUE_WAIT_MARGIN(CheckCandidate1(expected), converge_wait, + converge_wait); + // Also do EXPECT_EQ on each part so that failures are more verbose. + ExpectCandidate1(expected); + + // Verifying remote channel best connection information. This is done + // only for the RFC 5245 as controlled agent will use USE-CANDIDATE + // from controlling (ep1) agent. We can easily predict from EP1 result + // matrix. + + // Checking for best connection candidates information at remote. + EXPECT_TRUE_WAIT(CheckCandidate2(expected), kDefaultTimeout); + // For verbose + ExpectCandidate2(expected); + + converge_time = rtc::TimeSince(converge_start); + if (converge_time < converge_wait) { + LOG(LS_INFO) << "Converge time: " << converge_time << " ms"; + } else { + LOG(LS_INFO) << "Converge time: " << "TIMEOUT (" + << converge_wait << " ms)"; + } + } + // Try sending some data to other end. + TestSendRecv(1); + + // Destroy the channels, and wait for them to be fully cleaned up. + DestroyChannels(); + } + + void TestSendRecv(int channels) { + for (int i = 0; i < 10; ++i) { + const char* data = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + int len = static_cast<int>(strlen(data)); + // local_channel1 <==> remote_channel1 + EXPECT_EQ_WAIT(len, SendData(ep1_ch1(), data, len), 1000); + EXPECT_TRUE_WAIT(CheckDataOnChannel(ep2_ch1(), data, len), 1000); + EXPECT_EQ_WAIT(len, SendData(ep2_ch1(), data, len), 1000); + EXPECT_TRUE_WAIT(CheckDataOnChannel(ep1_ch1(), data, len), 1000); + if (channels == 2 && ep1_ch2() && ep2_ch2()) { + // local_channel2 <==> remote_channel2 + EXPECT_EQ_WAIT(len, SendData(ep1_ch2(), data, len), 1000); + EXPECT_TRUE_WAIT(CheckDataOnChannel(ep2_ch2(), data, len), 1000); + EXPECT_EQ_WAIT(len, SendData(ep2_ch2(), data, len), 1000); + EXPECT_TRUE_WAIT(CheckDataOnChannel(ep1_ch2(), data, len), 1000); + } + } + } + + // This test waits for the transport to become receiving and writable on both + // end points. Once they are, the end points set new local ice credentials and + // restart the ice gathering. Finally it waits for the transport to select a + // new connection using the newly generated ice candidates. + // Before calling this function the end points must be configured. + void TestHandleIceUfragPasswordChanged() { + ep1_ch1()->SetRemoteIceCredentials(kIceUfrag[1], kIcePwd[1]); + ep2_ch1()->SetRemoteIceCredentials(kIceUfrag[0], kIcePwd[0]); + EXPECT_TRUE_WAIT_MARGIN(ep1_ch1()->receiving() && ep1_ch1()->writable() && + ep2_ch1()->receiving() && ep2_ch1()->writable(), + 1000, 1000); + + const cricket::Candidate* old_local_candidate1 = LocalCandidate(ep1_ch1()); + const cricket::Candidate* old_local_candidate2 = LocalCandidate(ep2_ch1()); + const cricket::Candidate* old_remote_candidate1 = + RemoteCandidate(ep1_ch1()); + const cricket::Candidate* old_remote_candidate2 = + RemoteCandidate(ep2_ch1()); + + ep1_ch1()->SetIceCredentials(kIceUfrag[2], kIcePwd[2]); + ep1_ch1()->SetRemoteIceCredentials(kIceUfrag[3], kIcePwd[3]); + ep1_ch1()->MaybeStartGathering(); + ep2_ch1()->SetIceCredentials(kIceUfrag[3], kIcePwd[3]); + ep2_ch1()->SetRemoteIceCredentials(kIceUfrag[2], kIcePwd[2]); + ep2_ch1()->MaybeStartGathering(); + + EXPECT_TRUE_WAIT_MARGIN(LocalCandidate(ep1_ch1())->generation() != + old_local_candidate1->generation(), + 1000, 1000); + EXPECT_TRUE_WAIT_MARGIN(LocalCandidate(ep2_ch1())->generation() != + old_local_candidate2->generation(), + 1000, 1000); + EXPECT_TRUE_WAIT_MARGIN(RemoteCandidate(ep1_ch1())->generation() != + old_remote_candidate1->generation(), + 1000, 1000); + EXPECT_TRUE_WAIT_MARGIN(RemoteCandidate(ep2_ch1())->generation() != + old_remote_candidate2->generation(), + 1000, 1000); + EXPECT_EQ(1u, RemoteCandidate(ep2_ch1())->generation()); + EXPECT_EQ(1u, RemoteCandidate(ep1_ch1())->generation()); + } + + void TestSignalRoleConflict() { + SetIceTiebreaker(0, kTiebreaker1); // Default EP1 is in controlling state. + + SetIceRole(1, cricket::ICEROLE_CONTROLLING); + SetIceTiebreaker(1, kTiebreaker2); + + // Creating channels with both channels role set to CONTROLLING. + CreateChannels(1); + // Since both the channels initiated with controlling state and channel2 + // has higher tiebreaker value, channel1 should receive SignalRoleConflict. + EXPECT_TRUE_WAIT(GetRoleConflict(0), 1000); + EXPECT_FALSE(GetRoleConflict(1)); + + EXPECT_TRUE_WAIT(ep1_ch1()->receiving() && + ep1_ch1()->writable() && + ep2_ch1()->receiving() && + ep2_ch1()->writable(), + 1000); + + EXPECT_TRUE(ep1_ch1()->best_connection() && + ep2_ch1()->best_connection()); + + TestSendRecv(1); + } + + // We pass the candidates directly to the other side. + void OnCandidate(cricket::TransportChannelImpl* ch, + const cricket::Candidate& c) { + if (force_relay_ && c.type() != cricket::RELAY_PORT_TYPE) + return; + + if (GetEndpoint(ch)->save_candidates_) { + GetEndpoint(ch)->saved_candidates_.push_back(new CandidateData(ch, c)); + } else { + main_->Post(this, MSG_CANDIDATE, new CandidateData(ch, c)); + } + } + + void PauseCandidates(int endpoint) { + GetEndpoint(endpoint)->save_candidates_ = true; + } + + void ResumeCandidates(int endpoint) { + Endpoint* ed = GetEndpoint(endpoint); + std::vector<CandidateData*>::iterator it = ed->saved_candidates_.begin(); + for (; it != ed->saved_candidates_.end(); ++it) { + main_->Post(this, MSG_CANDIDATE, *it); + } + ed->saved_candidates_.clear(); + ed->save_candidates_ = false; + } + + void OnMessage(rtc::Message* msg) { + switch (msg->message_id) { + case MSG_CANDIDATE: { + rtc::scoped_ptr<CandidateData> data( + static_cast<CandidateData*>(msg->pdata)); + cricket::P2PTransportChannel* rch = GetRemoteChannel(data->channel); + cricket::Candidate c = data->candidate; + if (clear_remote_candidates_ufrag_pwd_) { + c.set_username(""); + c.set_password(""); + } + LOG(LS_INFO) << "Candidate(" << data->channel->component() << "->" + << rch->component() << "): " << c.ToString(); + rch->AddRemoteCandidate(c); + break; + } + } + } + void OnReadPacket(cricket::TransportChannel* channel, const char* data, + size_t len, const rtc::PacketTime& packet_time, + int flags) { + std::list<std::string>& packets = GetPacketList(channel); + packets.push_front(std::string(data, len)); + } + void OnRoleConflict(cricket::TransportChannelImpl* channel) { + GetEndpoint(channel)->OnRoleConflict(true); + cricket::IceRole new_role = + GetEndpoint(channel)->ice_role() == cricket::ICEROLE_CONTROLLING ? + cricket::ICEROLE_CONTROLLED : cricket::ICEROLE_CONTROLLING; + channel->SetIceRole(new_role); + } + int SendData(cricket::TransportChannel* channel, + const char* data, size_t len) { + rtc::PacketOptions options; + return channel->SendPacket(data, len, options, 0); + } + bool CheckDataOnChannel(cricket::TransportChannel* channel, + const char* data, int len) { + return GetChannelData(channel)->CheckData(data, len); + } + static const cricket::Candidate* LocalCandidate( + cricket::P2PTransportChannel* ch) { + return (ch && ch->best_connection()) ? + &ch->best_connection()->local_candidate() : NULL; + } + static const cricket::Candidate* RemoteCandidate( + cricket::P2PTransportChannel* ch) { + return (ch && ch->best_connection()) ? + &ch->best_connection()->remote_candidate() : NULL; + } + Endpoint* GetEndpoint(cricket::TransportChannel* ch) { + if (ep1_.HasChannel(ch)) { + return &ep1_; + } else if (ep2_.HasChannel(ch)) { + return &ep2_; + } else { + return NULL; + } + } + cricket::P2PTransportChannel* GetRemoteChannel( + cricket::TransportChannel* ch) { + if (ch == ep1_ch1()) + return ep2_ch1(); + else if (ch == ep1_ch2()) + return ep2_ch2(); + else if (ch == ep2_ch1()) + return ep1_ch1(); + else if (ch == ep2_ch2()) + return ep1_ch2(); + else + return NULL; + } + std::list<std::string>& GetPacketList(cricket::TransportChannel* ch) { + return GetChannelData(ch)->ch_packets_; + } + + void set_clear_remote_candidates_ufrag_pwd(bool clear) { + clear_remote_candidates_ufrag_pwd_ = clear; + } + + void set_force_relay(bool relay) { + force_relay_ = relay; + } + + private: + rtc::Thread* main_; + rtc::scoped_ptr<rtc::PhysicalSocketServer> pss_; + rtc::scoped_ptr<rtc::VirtualSocketServer> vss_; + rtc::scoped_ptr<rtc::NATSocketServer> nss_; + rtc::scoped_ptr<rtc::FirewallSocketServer> ss_; + rtc::SocketServerScope ss_scope_; + rtc::scoped_ptr<cricket::TestStunServer> stun_server_; + cricket::TestTurnServer turn_server_; + cricket::TestRelayServer relay_server_; + rtc::SocksProxyServer socks_server1_; + rtc::SocksProxyServer socks_server2_; + Endpoint ep1_; + Endpoint ep2_; + bool clear_remote_candidates_ufrag_pwd_; + bool force_relay_; +}; + +// The tests have only a few outcomes, which we predefine. +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kLocalUdpToLocalUdp("local", "udp", "local", "udp", + "local", "udp", "local", "udp", 1000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kLocalUdpToStunUdp("local", "udp", "stun", "udp", + "local", "udp", "stun", "udp", 1000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kLocalUdpToPrflxUdp("local", "udp", "prflx", "udp", + "prflx", "udp", "local", "udp", 1000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kPrflxUdpToLocalUdp("prflx", "udp", "local", "udp", + "local", "udp", "prflx", "udp", 1000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kStunUdpToLocalUdp("stun", "udp", "local", "udp", + "local", "udp", "stun", "udp", 1000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kStunUdpToStunUdp("stun", "udp", "stun", "udp", + "stun", "udp", "stun", "udp", 1000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kPrflxUdpToStunUdp("prflx", "udp", "stun", "udp", + "local", "udp", "prflx", "udp", 1000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kLocalUdpToRelayUdp("local", "udp", "relay", "udp", + "relay", "udp", "local", "udp", 2000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kPrflxUdpToRelayUdp("prflx", "udp", "relay", "udp", + "relay", "udp", "prflx", "udp", 2000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kLocalTcpToLocalTcp("local", "tcp", "local", "tcp", + "local", "tcp", "local", "tcp", 3000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kLocalTcpToPrflxTcp("local", "tcp", "prflx", "tcp", + "prflx", "tcp", "local", "tcp", 3000); +const P2PTransportChannelTestBase::Result P2PTransportChannelTestBase:: + kPrflxTcpToLocalTcp("prflx", "tcp", "local", "tcp", + "local", "tcp", "prflx", "tcp", 3000); + +// Test the matrix of all the connectivity types we expect to see in the wild. +// Just test every combination of the configs in the Config enum. +class P2PTransportChannelTest : public P2PTransportChannelTestBase { + protected: + static const Result* kMatrix[NUM_CONFIGS][NUM_CONFIGS]; + static const Result* kMatrixSharedUfrag[NUM_CONFIGS][NUM_CONFIGS]; + static const Result* kMatrixSharedSocketAsGice[NUM_CONFIGS][NUM_CONFIGS]; + static const Result* kMatrixSharedSocketAsIce[NUM_CONFIGS][NUM_CONFIGS]; + void ConfigureEndpoints(Config config1, + Config config2, + int allocator_flags1, + int allocator_flags2) { + ServerAddresses stun_servers; + stun_servers.insert(kStunAddr); + GetEndpoint(0)->allocator_.reset( + new cricket::BasicPortAllocator(&(GetEndpoint(0)->network_manager_), + stun_servers, + rtc::SocketAddress(), rtc::SocketAddress(), + rtc::SocketAddress())); + GetEndpoint(1)->allocator_.reset( + new cricket::BasicPortAllocator(&(GetEndpoint(1)->network_manager_), + stun_servers, + rtc::SocketAddress(), rtc::SocketAddress(), + rtc::SocketAddress())); + + cricket::RelayServerConfig relay_server(cricket::RELAY_TURN); + relay_server.credentials = kRelayCredentials; + relay_server.ports.push_back( + cricket::ProtocolAddress(kTurnUdpIntAddr, cricket::PROTO_UDP, false)); + GetEndpoint(0)->allocator_->AddRelay(relay_server); + GetEndpoint(1)->allocator_->AddRelay(relay_server); + + int delay = kMinimumStepDelay; + ConfigureEndpoint(0, config1); + SetAllocatorFlags(0, allocator_flags1); + SetAllocationStepDelay(0, delay); + ConfigureEndpoint(1, config2); + SetAllocatorFlags(1, allocator_flags2); + SetAllocationStepDelay(1, delay); + + set_clear_remote_candidates_ufrag_pwd(true); + } + void ConfigureEndpoint(int endpoint, Config config) { + switch (config) { + case OPEN: + AddAddress(endpoint, kPublicAddrs[endpoint]); + break; + case NAT_FULL_CONE: + case NAT_ADDR_RESTRICTED: + case NAT_PORT_RESTRICTED: + case NAT_SYMMETRIC: + AddAddress(endpoint, kPrivateAddrs[endpoint]); + // Add a single NAT of the desired type + nat()->AddTranslator(kPublicAddrs[endpoint], kNatAddrs[endpoint], + static_cast<rtc::NATType>(config - NAT_FULL_CONE))-> + AddClient(kPrivateAddrs[endpoint]); + break; + case NAT_DOUBLE_CONE: + case NAT_SYMMETRIC_THEN_CONE: + AddAddress(endpoint, kCascadedPrivateAddrs[endpoint]); + // Add a two cascaded NATs of the desired types + nat()->AddTranslator(kPublicAddrs[endpoint], kNatAddrs[endpoint], + (config == NAT_DOUBLE_CONE) ? + rtc::NAT_OPEN_CONE : rtc::NAT_SYMMETRIC)-> + AddTranslator(kPrivateAddrs[endpoint], kCascadedNatAddrs[endpoint], + rtc::NAT_OPEN_CONE)-> + AddClient(kCascadedPrivateAddrs[endpoint]); + break; + case BLOCK_UDP: + case BLOCK_UDP_AND_INCOMING_TCP: + case BLOCK_ALL_BUT_OUTGOING_HTTP: + case PROXY_HTTPS: + case PROXY_SOCKS: + AddAddress(endpoint, kPublicAddrs[endpoint]); + // Block all UDP + fw()->AddRule(false, rtc::FP_UDP, rtc::FD_ANY, + kPublicAddrs[endpoint]); + if (config == BLOCK_UDP_AND_INCOMING_TCP) { + // Block TCP inbound to the endpoint + fw()->AddRule(false, rtc::FP_TCP, SocketAddress(), + kPublicAddrs[endpoint]); + } else if (config == BLOCK_ALL_BUT_OUTGOING_HTTP) { + // Block all TCP to/from the endpoint except 80/443 out + fw()->AddRule(true, rtc::FP_TCP, kPublicAddrs[endpoint], + SocketAddress(rtc::IPAddress(INADDR_ANY), 80)); + fw()->AddRule(true, rtc::FP_TCP, kPublicAddrs[endpoint], + SocketAddress(rtc::IPAddress(INADDR_ANY), 443)); + fw()->AddRule(false, rtc::FP_TCP, rtc::FD_ANY, + kPublicAddrs[endpoint]); + } else if (config == PROXY_HTTPS) { + // Block all TCP to/from the endpoint except to the proxy server + fw()->AddRule(true, rtc::FP_TCP, kPublicAddrs[endpoint], + kHttpsProxyAddrs[endpoint]); + fw()->AddRule(false, rtc::FP_TCP, rtc::FD_ANY, + kPublicAddrs[endpoint]); + SetProxy(endpoint, rtc::PROXY_HTTPS); + } else if (config == PROXY_SOCKS) { + // Block all TCP to/from the endpoint except to the proxy server + fw()->AddRule(true, rtc::FP_TCP, kPublicAddrs[endpoint], + kSocksProxyAddrs[endpoint]); + fw()->AddRule(false, rtc::FP_TCP, rtc::FD_ANY, + kPublicAddrs[endpoint]); + SetProxy(endpoint, rtc::PROXY_SOCKS5); + } + break; + default: + break; + } + } +}; + +// Shorthands for use in the test matrix. +#define LULU &kLocalUdpToLocalUdp +#define LUSU &kLocalUdpToStunUdp +#define LUPU &kLocalUdpToPrflxUdp +#define PULU &kPrflxUdpToLocalUdp +#define SULU &kStunUdpToLocalUdp +#define SUSU &kStunUdpToStunUdp +#define PUSU &kPrflxUdpToStunUdp +#define LURU &kLocalUdpToRelayUdp +#define PURU &kPrflxUdpToRelayUdp +#define LTLT &kLocalTcpToLocalTcp +#define LTPT &kLocalTcpToPrflxTcp +#define PTLT &kPrflxTcpToLocalTcp +// TODO: Enable these once TestRelayServer can accept external TCP. +#define LTRT NULL +#define LSRS NULL + +// Test matrix. Originator behavior defined by rows, receiever by columns. + +// Currently the p2ptransportchannel.cc (specifically the +// P2PTransportChannel::OnUnknownAddress) operates in 2 modes depend on the +// remote candidates - ufrag per port or shared ufrag. +// For example, if the remote candidates have the shared ufrag, for the unknown +// address reaches the OnUnknownAddress, we will try to find the matched +// remote candidate based on the address and protocol, if not found, a new +// remote candidate will be created for this address. But if the remote +// candidates have different ufrags, we will try to find the matched remote +// candidate by comparing the ufrag. If not found, an error will be returned. +// Because currently the shared ufrag feature is under the experiment and will +// be rolled out gradually. We want to test the different combinations of peers +// with/without the shared ufrag enabled. And those different combinations have +// different expectation of the best connection. For example in the OpenToCONE +// case, an unknown address will be updated to a "host" remote candidate if the +// remote peer uses different ufrag per port. But in the shared ufrag case, +// a "stun" (should be peer-reflexive eventually) candidate will be created for +// that. So the expected best candidate will be LUSU instead of LULU. +// With all these, we have to keep 2 test matrixes for the tests: +// kMatrix - for the tests that the remote peer uses different ufrag per port. +// kMatrixSharedUfrag - for the tests that remote peer uses shared ufrag. +// The different between the two matrixes are on: +// OPToCONE, OPTo2CON, +// COToCONE, COToADDR, COToPORT, COToSYMM, COTo2CON, COToSCON, +// ADToCONE, ADToADDR, ADTo2CON, +// POToADDR, +// SYToADDR, +// 2CToCONE, 2CToADDR, 2CToPORT, 2CToSYMM, 2CTo2CON, 2CToSCON, +// SCToADDR, + +// TODO: Fix NULLs caused by lack of TCP support in NATSocket. +// TODO: Fix NULLs caused by no HTTP proxy support. +// TODO: Rearrange rows/columns from best to worst. +// TODO(ronghuawu): Keep only one test matrix once the shared ufrag is enabled. +const P2PTransportChannelTest::Result* + P2PTransportChannelTest::kMatrix[NUM_CONFIGS][NUM_CONFIGS] = { +// OPEN CONE ADDR PORT SYMM 2CON SCON !UDP !TCP HTTP PRXH PRXS +/*OP*/ {LULU, LULU, LULU, LULU, LULU, LULU, LULU, LTLT, LTLT, LSRS, NULL, LTLT}, +/*CO*/ {LULU, LULU, LULU, SULU, SULU, LULU, SULU, NULL, NULL, LSRS, NULL, LTRT}, +/*AD*/ {LULU, LULU, LULU, SUSU, SUSU, LULU, SUSU, NULL, NULL, LSRS, NULL, LTRT}, +/*PO*/ {LULU, LUSU, LUSU, SUSU, LURU, LUSU, LURU, NULL, NULL, LSRS, NULL, LTRT}, +/*SY*/ {LULU, LUSU, LUSU, LURU, LURU, LUSU, LURU, NULL, NULL, LSRS, NULL, LTRT}, +/*2C*/ {LULU, LULU, LULU, SULU, SULU, LULU, SULU, NULL, NULL, LSRS, NULL, LTRT}, +/*SC*/ {LULU, LUSU, LUSU, LURU, LURU, LUSU, LURU, NULL, NULL, LSRS, NULL, LTRT}, +/*!U*/ {LTLT, NULL, NULL, NULL, NULL, NULL, NULL, LTLT, LTLT, LSRS, NULL, LTRT}, +/*!T*/ {LTRT, NULL, NULL, NULL, NULL, NULL, NULL, LTLT, LTRT, LSRS, NULL, LTRT}, +/*HT*/ {LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, NULL, LSRS}, +/*PR*/ {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL}, +/*PR*/ {LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LSRS, NULL, LTRT}, +}; +const P2PTransportChannelTest::Result* + P2PTransportChannelTest::kMatrixSharedUfrag[NUM_CONFIGS][NUM_CONFIGS] = { +// OPEN CONE ADDR PORT SYMM 2CON SCON !UDP !TCP HTTP PRXH PRXS +/*OP*/ {LULU, LUSU, LULU, LULU, LULU, LUSU, LULU, LTLT, LTLT, LSRS, NULL, LTLT}, +/*CO*/ {LULU, LUSU, LUSU, SUSU, SUSU, LUSU, SUSU, NULL, NULL, LSRS, NULL, LTRT}, +/*AD*/ {LULU, LUSU, LUSU, SUSU, SUSU, LUSU, SUSU, NULL, NULL, LSRS, NULL, LTRT}, +/*PO*/ {LULU, LUSU, LUSU, SUSU, LURU, LUSU, LURU, NULL, NULL, LSRS, NULL, LTRT}, +/*SY*/ {LULU, LUSU, LUSU, LURU, LURU, LUSU, LURU, NULL, NULL, LSRS, NULL, LTRT}, +/*2C*/ {LULU, LUSU, LUSU, SUSU, SUSU, LUSU, SUSU, NULL, NULL, LSRS, NULL, LTRT}, +/*SC*/ {LULU, LUSU, LUSU, LURU, LURU, LUSU, LURU, NULL, NULL, LSRS, NULL, LTRT}, +/*!U*/ {LTLT, NULL, NULL, NULL, NULL, NULL, NULL, LTLT, LTLT, LSRS, NULL, LTRT}, +/*!T*/ {LTRT, NULL, NULL, NULL, NULL, NULL, NULL, LTLT, LTRT, LSRS, NULL, LTRT}, +/*HT*/ {LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, NULL, LSRS}, +/*PR*/ {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL}, +/*PR*/ {LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LSRS, NULL, LTRT}, +}; +const P2PTransportChannelTest::Result* + P2PTransportChannelTest::kMatrixSharedSocketAsGice + [NUM_CONFIGS][NUM_CONFIGS] = { +// OPEN CONE ADDR PORT SYMM 2CON SCON !UDP !TCP HTTP PRXH PRXS +/*OP*/ {LULU, LUSU, LUSU, LUSU, LUSU, LUSU, LUSU, LTLT, LTLT, LSRS, NULL, LTLT}, +/*CO*/ {LULU, LUSU, LUSU, LUSU, LUSU, LUSU, LUSU, NULL, NULL, LSRS, NULL, LTRT}, +/*AD*/ {LULU, LUSU, LUSU, LUSU, LUSU, LUSU, LUSU, NULL, NULL, LSRS, NULL, LTRT}, +/*PO*/ {LULU, LUSU, LUSU, LUSU, LURU, LUSU, LURU, NULL, NULL, LSRS, NULL, LTRT}, +/*SY*/ {LULU, LUSU, LUSU, LURU, LURU, LUSU, LURU, NULL, NULL, LSRS, NULL, LTRT}, +/*2C*/ {LULU, LUSU, LUSU, LUSU, LUSU, LUSU, LUSU, NULL, NULL, LSRS, NULL, LTRT}, +/*SC*/ {LULU, LUSU, LUSU, LURU, LURU, LUSU, LURU, NULL, NULL, LSRS, NULL, LTRT}, +/*!U*/ {LTLT, NULL, NULL, NULL, NULL, NULL, NULL, LTLT, LTLT, LSRS, NULL, LTRT}, +/*!T*/ {LTRT, NULL, NULL, NULL, NULL, NULL, NULL, LTLT, LTRT, LSRS, NULL, LTRT}, +/*HT*/ {LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, NULL, LSRS}, +/*PR*/ {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL}, +/*PR*/ {LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LSRS, NULL, LTRT}, +}; +const P2PTransportChannelTest::Result* + P2PTransportChannelTest::kMatrixSharedSocketAsIce + [NUM_CONFIGS][NUM_CONFIGS] = { +// OPEN CONE ADDR PORT SYMM 2CON SCON !UDP !TCP HTTP PRXH PRXS +/*OP*/ {LULU, LUSU, LUSU, LUSU, LUPU, LUSU, LUPU, PTLT, LTPT, LSRS, NULL, LTPT}, +/*CO*/ {LULU, LUSU, LUSU, LUSU, LUPU, LUSU, LUPU, NULL, NULL, LSRS, NULL, LTRT}, +/*AD*/ {LULU, LUSU, LUSU, LUSU, LUPU, LUSU, LUPU, NULL, NULL, LSRS, NULL, LTRT}, +/*PO*/ {LULU, LUSU, LUSU, LUSU, LURU, LUSU, LURU, NULL, NULL, LSRS, NULL, LTRT}, +/*SY*/ {PULU, PUSU, PUSU, PURU, PURU, PUSU, PURU, NULL, NULL, LSRS, NULL, LTRT}, +/*2C*/ {LULU, LUSU, LUSU, LUSU, LUPU, LUSU, LUPU, NULL, NULL, LSRS, NULL, LTRT}, +/*SC*/ {PULU, PUSU, PUSU, PURU, PURU, PUSU, PURU, NULL, NULL, LSRS, NULL, LTRT}, +/*!U*/ {PTLT, NULL, NULL, NULL, NULL, NULL, NULL, PTLT, LTPT, LSRS, NULL, LTRT}, +/*!T*/ {LTRT, NULL, NULL, NULL, NULL, NULL, NULL, PTLT, LTRT, LSRS, NULL, LTRT}, +/*HT*/ {LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, LSRS, NULL, LSRS}, +/*PR*/ {NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL}, +/*PR*/ {LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LTRT, LSRS, NULL, LTRT}, +}; + +// The actual tests that exercise all the various configurations. +// Test names are of the form P2PTransportChannelTest_TestOPENToNAT_FULL_CONE +#define P2P_TEST_DECLARATION(x, y, z) \ + TEST_F(P2PTransportChannelTest, z##Test##x##To##y) { \ + ConfigureEndpoints(x, y, PORTALLOCATOR_ENABLE_SHARED_SOCKET, \ + PORTALLOCATOR_ENABLE_SHARED_SOCKET); \ + if (kMatrixSharedSocketAsIce[x][y] != NULL) \ + Test(*kMatrixSharedSocketAsIce[x][y]); \ + else \ + LOG(LS_WARNING) << "Not yet implemented"; \ + } + +#define P2P_TEST(x, y) \ + P2P_TEST_DECLARATION(x, y,) + +#define FLAKY_P2P_TEST(x, y) \ + P2P_TEST_DECLARATION(x, y, DISABLED_) + +// TODO(holmer): Disabled due to randomly failing on webrtc buildbots. +// Issue: webrtc/2383 +#define P2P_TEST_SET(x) \ + P2P_TEST(x, OPEN) \ + FLAKY_P2P_TEST(x, NAT_FULL_CONE) \ + FLAKY_P2P_TEST(x, NAT_ADDR_RESTRICTED) \ + FLAKY_P2P_TEST(x, NAT_PORT_RESTRICTED) \ + P2P_TEST(x, NAT_SYMMETRIC) \ + FLAKY_P2P_TEST(x, NAT_DOUBLE_CONE) \ + P2P_TEST(x, NAT_SYMMETRIC_THEN_CONE) \ + P2P_TEST(x, BLOCK_UDP) \ + P2P_TEST(x, BLOCK_UDP_AND_INCOMING_TCP) \ + P2P_TEST(x, BLOCK_ALL_BUT_OUTGOING_HTTP) \ + P2P_TEST(x, PROXY_HTTPS) \ + P2P_TEST(x, PROXY_SOCKS) + +#define FLAKY_P2P_TEST_SET(x) \ + P2P_TEST(x, OPEN) \ + P2P_TEST(x, NAT_FULL_CONE) \ + P2P_TEST(x, NAT_ADDR_RESTRICTED) \ + P2P_TEST(x, NAT_PORT_RESTRICTED) \ + P2P_TEST(x, NAT_SYMMETRIC) \ + P2P_TEST(x, NAT_DOUBLE_CONE) \ + P2P_TEST(x, NAT_SYMMETRIC_THEN_CONE) \ + P2P_TEST(x, BLOCK_UDP) \ + P2P_TEST(x, BLOCK_UDP_AND_INCOMING_TCP) \ + P2P_TEST(x, BLOCK_ALL_BUT_OUTGOING_HTTP) \ + P2P_TEST(x, PROXY_HTTPS) \ + P2P_TEST(x, PROXY_SOCKS) + +P2P_TEST_SET(OPEN) +P2P_TEST_SET(NAT_FULL_CONE) +P2P_TEST_SET(NAT_ADDR_RESTRICTED) +P2P_TEST_SET(NAT_PORT_RESTRICTED) +P2P_TEST_SET(NAT_SYMMETRIC) +P2P_TEST_SET(NAT_DOUBLE_CONE) +P2P_TEST_SET(NAT_SYMMETRIC_THEN_CONE) +P2P_TEST_SET(BLOCK_UDP) +P2P_TEST_SET(BLOCK_UDP_AND_INCOMING_TCP) +P2P_TEST_SET(BLOCK_ALL_BUT_OUTGOING_HTTP) +P2P_TEST_SET(PROXY_HTTPS) +P2P_TEST_SET(PROXY_SOCKS) + +// Test that we restart candidate allocation when local ufrag&pwd changed. +// Standard Ice protocol is used. +TEST_F(P2PTransportChannelTest, HandleUfragPwdChange) { + ConfigureEndpoints(OPEN, OPEN, kDefaultPortAllocatorFlags, + kDefaultPortAllocatorFlags); + CreateChannels(1); + TestHandleIceUfragPasswordChanged(); + DestroyChannels(); +} + +// Test the operation of GetStats. +TEST_F(P2PTransportChannelTest, GetStats) { + ConfigureEndpoints(OPEN, OPEN, kDefaultPortAllocatorFlags, + kDefaultPortAllocatorFlags); + CreateChannels(1); + EXPECT_TRUE_WAIT_MARGIN(ep1_ch1()->receiving() && ep1_ch1()->writable() && + ep2_ch1()->receiving() && ep2_ch1()->writable(), + 1000, 1000); + TestSendRecv(1); + cricket::ConnectionInfos infos; + ASSERT_TRUE(ep1_ch1()->GetStats(&infos)); + ASSERT_TRUE(infos.size() >= 1); + cricket::ConnectionInfo* best_conn_info = nullptr; + for (cricket::ConnectionInfo& info : infos) { + if (info.best_connection) { + best_conn_info = &info; + break; + } + } + ASSERT_TRUE(best_conn_info != nullptr); + EXPECT_TRUE(best_conn_info->new_connection); + EXPECT_TRUE(best_conn_info->receiving); + EXPECT_TRUE(best_conn_info->writable); + EXPECT_FALSE(best_conn_info->timeout); + EXPECT_EQ(10U, best_conn_info->sent_total_packets); + EXPECT_EQ(0U, best_conn_info->sent_discarded_packets); + EXPECT_EQ(10 * 36U, best_conn_info->sent_total_bytes); + EXPECT_EQ(10 * 36U, best_conn_info->recv_total_bytes); + EXPECT_GT(best_conn_info->rtt, 0U); + DestroyChannels(); +} + +// Test that we properly create a connection on a STUN ping from unknown address +// when the signaling is slow. +TEST_F(P2PTransportChannelTest, PeerReflexiveCandidateBeforeSignaling) { + ConfigureEndpoints(OPEN, OPEN, kDefaultPortAllocatorFlags, + kDefaultPortAllocatorFlags); + // Emulate no remote credentials coming in. + set_clear_remote_candidates_ufrag_pwd(false); + CreateChannels(1); + // Only have remote credentials come in for ep2, not ep1. + ep2_ch1()->SetRemoteIceCredentials(kIceUfrag[3], kIcePwd[3]); + + // Pause sending ep2's candidates to ep1 until ep1 receives the peer reflexive + // candidate. + PauseCandidates(1); + + // The caller should have the best connection connected to the peer reflexive + // candidate. + const cricket::Connection* best_connection = NULL; + WAIT((best_connection = ep1_ch1()->best_connection()) != NULL, 2000); + EXPECT_EQ("prflx", ep1_ch1()->best_connection()->remote_candidate().type()); + + // Because we don't have a remote pwd, we don't ping yet. + EXPECT_EQ(kIceUfrag[1], + ep1_ch1()->best_connection()->remote_candidate().username()); + EXPECT_EQ("", ep1_ch1()->best_connection()->remote_candidate().password()); + EXPECT_TRUE(nullptr == ep1_ch1()->FindNextPingableConnection()); + + ep1_ch1()->SetRemoteIceCredentials(kIceUfrag[1], kIcePwd[1]); + ResumeCandidates(1); + + EXPECT_EQ(kIcePwd[1], + ep1_ch1()->best_connection()->remote_candidate().password()); + EXPECT_TRUE(nullptr != ep1_ch1()->FindNextPingableConnection()); + + WAIT(ep2_ch1()->best_connection() != NULL, 2000); + // Verify ep1's best connection is updated to use the 'local' candidate. + EXPECT_EQ_WAIT( + "local", + ep1_ch1()->best_connection()->remote_candidate().type(), + 2000); + EXPECT_EQ(best_connection, ep1_ch1()->best_connection()); + DestroyChannels(); +} + +// Test that we properly create a connection on a STUN ping from unknown address +// when the signaling is slow and the end points are behind NAT. +TEST_F(P2PTransportChannelTest, PeerReflexiveCandidateBeforeSignalingWithNAT) { + ConfigureEndpoints(OPEN, NAT_SYMMETRIC, kDefaultPortAllocatorFlags, + kDefaultPortAllocatorFlags); + // Emulate no remote credentials coming in. + set_clear_remote_candidates_ufrag_pwd(false); + CreateChannels(1); + // Only have remote credentials come in for ep2, not ep1. + ep2_ch1()->SetRemoteIceCredentials(kIceUfrag[3], kIcePwd[3]); + // Pause sending ep2's candidates to ep1 until ep1 receives the peer reflexive + // candidate. + PauseCandidates(1); + + // The caller should have the best connection connected to the peer reflexive + // candidate. + WAIT(ep1_ch1()->best_connection() != NULL, 2000); + EXPECT_EQ("prflx", ep1_ch1()->best_connection()->remote_candidate().type()); + + // Because we don't have a remote pwd, we don't ping yet. + EXPECT_EQ(kIceUfrag[1], + ep1_ch1()->best_connection()->remote_candidate().username()); + EXPECT_EQ("", ep1_ch1()->best_connection()->remote_candidate().password()); + EXPECT_TRUE(nullptr == ep1_ch1()->FindNextPingableConnection()); + + ep1_ch1()->SetRemoteIceCredentials(kIceUfrag[1], kIcePwd[1]); + ResumeCandidates(1); + + EXPECT_EQ(kIcePwd[1], + ep1_ch1()->best_connection()->remote_candidate().password()); + EXPECT_TRUE(nullptr != ep1_ch1()->FindNextPingableConnection()); + + const cricket::Connection* best_connection = NULL; + WAIT((best_connection = ep2_ch1()->best_connection()) != NULL, 2000); + + // Wait to verify the connection is not culled. + WAIT(ep1_ch1()->writable(), 2000); + EXPECT_EQ(ep2_ch1()->best_connection(), best_connection); + EXPECT_EQ("prflx", ep1_ch1()->best_connection()->remote_candidate().type()); + DestroyChannels(); +} + +// Test that if remote candidates don't have ufrag and pwd, we still work. +TEST_F(P2PTransportChannelTest, RemoteCandidatesWithoutUfragPwd) { + set_clear_remote_candidates_ufrag_pwd(true); + ConfigureEndpoints(OPEN, OPEN, kDefaultPortAllocatorFlags, + kDefaultPortAllocatorFlags); + CreateChannels(1); + const cricket::Connection* best_connection = NULL; + // Wait until the callee's connections are created. + WAIT((best_connection = ep2_ch1()->best_connection()) != NULL, 1000); + // Wait to see if they get culled; they shouldn't. + WAIT(ep2_ch1()->best_connection() != best_connection, 1000); + EXPECT_TRUE(ep2_ch1()->best_connection() == best_connection); + DestroyChannels(); +} + +// Test that a host behind NAT cannot be reached when incoming_only +// is set to true. +TEST_F(P2PTransportChannelTest, IncomingOnlyBlocked) { + ConfigureEndpoints(NAT_FULL_CONE, OPEN, kDefaultPortAllocatorFlags, + kDefaultPortAllocatorFlags); + + SetAllocatorFlags(0, kOnlyLocalPorts); + CreateChannels(1); + ep1_ch1()->set_incoming_only(true); + + // Pump for 1 second and verify that the channels are not connected. + rtc::Thread::Current()->ProcessMessages(1000); + + EXPECT_FALSE(ep1_ch1()->receiving()); + EXPECT_FALSE(ep1_ch1()->writable()); + EXPECT_FALSE(ep2_ch1()->receiving()); + EXPECT_FALSE(ep2_ch1()->writable()); + + DestroyChannels(); +} + +// Test that a peer behind NAT can connect to a peer that has +// incoming_only flag set. +TEST_F(P2PTransportChannelTest, IncomingOnlyOpen) { + ConfigureEndpoints(OPEN, NAT_FULL_CONE, kDefaultPortAllocatorFlags, + kDefaultPortAllocatorFlags); + + SetAllocatorFlags(0, kOnlyLocalPorts); + CreateChannels(1); + ep1_ch1()->set_incoming_only(true); + + EXPECT_TRUE_WAIT_MARGIN(ep1_ch1() != NULL && ep2_ch1() != NULL && + ep1_ch1()->receiving() && ep1_ch1()->writable() && + ep2_ch1()->receiving() && ep2_ch1()->writable(), + 1000, 1000); + + DestroyChannels(); +} + +TEST_F(P2PTransportChannelTest, TestTcpConnectionsFromActiveToPassive) { + AddAddress(0, kPublicAddrs[0]); + AddAddress(1, kPublicAddrs[1]); + + SetAllocationStepDelay(0, kMinimumStepDelay); + SetAllocationStepDelay(1, kMinimumStepDelay); + + int kOnlyLocalTcpPorts = cricket::PORTALLOCATOR_DISABLE_UDP | + cricket::PORTALLOCATOR_DISABLE_STUN | + cricket::PORTALLOCATOR_DISABLE_RELAY; + // Disable all protocols except TCP. + SetAllocatorFlags(0, kOnlyLocalTcpPorts); + SetAllocatorFlags(1, kOnlyLocalTcpPorts); + + SetAllowTcpListen(0, true); // actpass. + SetAllowTcpListen(1, false); // active. + + CreateChannels(1); + + EXPECT_TRUE_WAIT(ep1_ch1()->receiving() && ep1_ch1()->writable() && + ep2_ch1()->receiving() && ep2_ch1()->writable(), + 1000); + EXPECT_TRUE( + ep1_ch1()->best_connection() && ep2_ch1()->best_connection() && + LocalCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[0]) && + RemoteCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[1])); + + std::string kTcpProtocol = "tcp"; + EXPECT_EQ(kTcpProtocol, RemoteCandidate(ep1_ch1())->protocol()); + EXPECT_EQ(kTcpProtocol, LocalCandidate(ep1_ch1())->protocol()); + EXPECT_EQ(kTcpProtocol, RemoteCandidate(ep2_ch1())->protocol()); + EXPECT_EQ(kTcpProtocol, LocalCandidate(ep2_ch1())->protocol()); + + TestSendRecv(1); + DestroyChannels(); +} + +TEST_F(P2PTransportChannelTest, TestIceRoleConflict) { + AddAddress(0, kPublicAddrs[0]); + AddAddress(1, kPublicAddrs[1]); + TestSignalRoleConflict(); +} + +// Tests that the ice configs (protocol, tiebreaker and role) can be passed +// down to ports. +TEST_F(P2PTransportChannelTest, TestIceConfigWillPassDownToPort) { + AddAddress(0, kPublicAddrs[0]); + AddAddress(1, kPublicAddrs[1]); + + SetIceRole(0, cricket::ICEROLE_CONTROLLING); + SetIceTiebreaker(0, kTiebreaker1); + SetIceRole(1, cricket::ICEROLE_CONTROLLING); + SetIceTiebreaker(1, kTiebreaker2); + + CreateChannels(1); + + EXPECT_EQ_WAIT(2u, ep1_ch1()->ports().size(), 1000); + + const std::vector<cricket::PortInterface *> ports_before = ep1_ch1()->ports(); + for (size_t i = 0; i < ports_before.size(); ++i) { + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, ports_before[i]->GetIceRole()); + EXPECT_EQ(kTiebreaker1, ports_before[i]->IceTiebreaker()); + } + + ep1_ch1()->SetIceRole(cricket::ICEROLE_CONTROLLED); + ep1_ch1()->SetIceTiebreaker(kTiebreaker2); + + const std::vector<cricket::PortInterface *> ports_after = ep1_ch1()->ports(); + for (size_t i = 0; i < ports_after.size(); ++i) { + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, ports_before[i]->GetIceRole()); + // SetIceTiebreaker after Connect() has been called will fail. So expect the + // original value. + EXPECT_EQ(kTiebreaker1, ports_before[i]->IceTiebreaker()); + } + + EXPECT_TRUE_WAIT(ep1_ch1()->receiving() && + ep1_ch1()->writable() && + ep2_ch1()->receiving() && + ep2_ch1()->writable(), + 1000); + + EXPECT_TRUE(ep1_ch1()->best_connection() && + ep2_ch1()->best_connection()); + + TestSendRecv(1); + DestroyChannels(); +} + +// Verify that we can set DSCP value and retrieve properly from P2PTC. +TEST_F(P2PTransportChannelTest, TestDefaultDscpValue) { + AddAddress(0, kPublicAddrs[0]); + AddAddress(1, kPublicAddrs[1]); + + CreateChannels(1); + EXPECT_EQ(rtc::DSCP_NO_CHANGE, + GetEndpoint(0)->cd1_.ch_->DefaultDscpValue()); + EXPECT_EQ(rtc::DSCP_NO_CHANGE, + GetEndpoint(1)->cd1_.ch_->DefaultDscpValue()); + GetEndpoint(0)->cd1_.ch_->SetOption( + rtc::Socket::OPT_DSCP, rtc::DSCP_CS6); + GetEndpoint(1)->cd1_.ch_->SetOption( + rtc::Socket::OPT_DSCP, rtc::DSCP_CS6); + EXPECT_EQ(rtc::DSCP_CS6, + GetEndpoint(0)->cd1_.ch_->DefaultDscpValue()); + EXPECT_EQ(rtc::DSCP_CS6, + GetEndpoint(1)->cd1_.ch_->DefaultDscpValue()); + GetEndpoint(0)->cd1_.ch_->SetOption( + rtc::Socket::OPT_DSCP, rtc::DSCP_AF41); + GetEndpoint(1)->cd1_.ch_->SetOption( + rtc::Socket::OPT_DSCP, rtc::DSCP_AF41); + EXPECT_EQ(rtc::DSCP_AF41, + GetEndpoint(0)->cd1_.ch_->DefaultDscpValue()); + EXPECT_EQ(rtc::DSCP_AF41, + GetEndpoint(1)->cd1_.ch_->DefaultDscpValue()); +} + +// Verify IPv6 connection is preferred over IPv4. +TEST_F(P2PTransportChannelTest, TestIPv6Connections) { + AddAddress(0, kIPv6PublicAddrs[0]); + AddAddress(0, kPublicAddrs[0]); + AddAddress(1, kIPv6PublicAddrs[1]); + AddAddress(1, kPublicAddrs[1]); + + SetAllocationStepDelay(0, kMinimumStepDelay); + SetAllocationStepDelay(1, kMinimumStepDelay); + + // Enable IPv6 + SetAllocatorFlags(0, cricket::PORTALLOCATOR_ENABLE_IPV6); + SetAllocatorFlags(1, cricket::PORTALLOCATOR_ENABLE_IPV6); + + CreateChannels(1); + + EXPECT_TRUE_WAIT(ep1_ch1()->receiving() && ep1_ch1()->writable() && + ep2_ch1()->receiving() && ep2_ch1()->writable(), + 1000); + EXPECT_TRUE( + ep1_ch1()->best_connection() && ep2_ch1()->best_connection() && + LocalCandidate(ep1_ch1())->address().EqualIPs(kIPv6PublicAddrs[0]) && + RemoteCandidate(ep1_ch1())->address().EqualIPs(kIPv6PublicAddrs[1])); + + TestSendRecv(1); + DestroyChannels(); +} + +// Testing forceful TURN connections. +TEST_F(P2PTransportChannelTest, TestForceTurn) { + ConfigureEndpoints( + NAT_PORT_RESTRICTED, NAT_SYMMETRIC, + kDefaultPortAllocatorFlags | cricket::PORTALLOCATOR_ENABLE_SHARED_SOCKET, + kDefaultPortAllocatorFlags | cricket::PORTALLOCATOR_ENABLE_SHARED_SOCKET); + set_force_relay(true); + + SetAllocationStepDelay(0, kMinimumStepDelay); + SetAllocationStepDelay(1, kMinimumStepDelay); + + CreateChannels(1); + + EXPECT_TRUE_WAIT(ep1_ch1()->receiving() && ep1_ch1()->writable() && + ep2_ch1()->receiving() && ep2_ch1()->writable(), + 2000); + + EXPECT_TRUE(ep1_ch1()->best_connection() && + ep2_ch1()->best_connection()); + + EXPECT_EQ("relay", RemoteCandidate(ep1_ch1())->type()); + EXPECT_EQ("relay", LocalCandidate(ep1_ch1())->type()); + EXPECT_EQ("relay", RemoteCandidate(ep2_ch1())->type()); + EXPECT_EQ("relay", LocalCandidate(ep2_ch1())->type()); + + TestSendRecv(1); + DestroyChannels(); +} + +// Test that if continual gathering is set to true, ICE gathering state will +// not change to "Complete", and vice versa. +TEST_F(P2PTransportChannelTest, TestContinualGathering) { + ConfigureEndpoints(OPEN, OPEN, kDefaultPortAllocatorFlags, + kDefaultPortAllocatorFlags); + SetAllocationStepDelay(0, kDefaultStepDelay); + SetAllocationStepDelay(1, kDefaultStepDelay); + CreateChannels(1); + cricket::IceConfig config = CreateIceConfig(1000, true); + ep1_ch1()->SetIceConfig(config); + // By default, ep2 does not gather continually. + + EXPECT_TRUE_WAIT_MARGIN(ep1_ch1() != NULL && ep2_ch1() != NULL && + ep1_ch1()->receiving() && ep1_ch1()->writable() && + ep2_ch1()->receiving() && ep2_ch1()->writable(), + 1000, 1000); + WAIT(cricket::IceGatheringState::kIceGatheringComplete == + ep1_ch1()->gathering_state(), + 1000); + EXPECT_EQ(cricket::IceGatheringState::kIceGatheringGathering, + ep1_ch1()->gathering_state()); + // By now, ep2 should have completed gathering. + EXPECT_EQ(cricket::IceGatheringState::kIceGatheringComplete, + ep2_ch1()->gathering_state()); + + DestroyChannels(); +} + +// Test what happens when we have 2 users behind the same NAT. This can lead +// to interesting behavior because the STUN server will only give out the +// address of the outermost NAT. +class P2PTransportChannelSameNatTest : public P2PTransportChannelTestBase { + protected: + void ConfigureEndpoints(Config nat_type, Config config1, Config config2) { + ASSERT(nat_type >= NAT_FULL_CONE && nat_type <= NAT_SYMMETRIC); + rtc::NATSocketServer::Translator* outer_nat = + nat()->AddTranslator(kPublicAddrs[0], kNatAddrs[0], + static_cast<rtc::NATType>(nat_type - NAT_FULL_CONE)); + ConfigureEndpoint(outer_nat, 0, config1); + ConfigureEndpoint(outer_nat, 1, config2); + } + void ConfigureEndpoint(rtc::NATSocketServer::Translator* nat, + int endpoint, Config config) { + ASSERT(config <= NAT_SYMMETRIC); + if (config == OPEN) { + AddAddress(endpoint, kPrivateAddrs[endpoint]); + nat->AddClient(kPrivateAddrs[endpoint]); + } else { + AddAddress(endpoint, kCascadedPrivateAddrs[endpoint]); + nat->AddTranslator(kPrivateAddrs[endpoint], kCascadedNatAddrs[endpoint], + static_cast<rtc::NATType>(config - NAT_FULL_CONE))->AddClient( + kCascadedPrivateAddrs[endpoint]); + } + } +}; + +TEST_F(P2PTransportChannelSameNatTest, TestConesBehindSameCone) { + ConfigureEndpoints(NAT_FULL_CONE, NAT_FULL_CONE, NAT_FULL_CONE); + Test(P2PTransportChannelTestBase::Result( + "prflx", "udp", "stun", "udp", "stun", "udp", "prflx", "udp", 1000)); +} + +// Test what happens when we have multiple available pathways. +// In the future we will try different RTTs and configs for the different +// interfaces, so that we can simulate a user with Ethernet and VPN networks. +class P2PTransportChannelMultihomedTest : public P2PTransportChannelTestBase { +}; + +// Test that we can establish connectivity when both peers are multihomed. +TEST_F(P2PTransportChannelMultihomedTest, DISABLED_TestBasic) { + AddAddress(0, kPublicAddrs[0]); + AddAddress(0, kAlternateAddrs[0]); + AddAddress(1, kPublicAddrs[1]); + AddAddress(1, kAlternateAddrs[1]); + Test(kLocalUdpToLocalUdp); +} + +// Test that we can quickly switch links if an interface goes down. +// The controlled side has two interfaces and one will die. +TEST_F(P2PTransportChannelMultihomedTest, TestFailoverControlledSide) { + AddAddress(0, kPublicAddrs[0]); + // Adding alternate address will make sure |kPublicAddrs| has the higher + // priority than others. This is due to FakeNetwork::AddInterface method. + AddAddress(1, kAlternateAddrs[1]); + AddAddress(1, kPublicAddrs[1]); + + // Use only local ports for simplicity. + SetAllocatorFlags(0, kOnlyLocalPorts); + SetAllocatorFlags(1, kOnlyLocalPorts); + + // Create channels and let them go writable, as usual. + CreateChannels(1); + + // Make the receiving timeout shorter for testing. + cricket::IceConfig config = CreateIceConfig(1000, false); + ep1_ch1()->SetIceConfig(config); + ep2_ch1()->SetIceConfig(config); + + EXPECT_TRUE_WAIT(ep1_ch1()->receiving() && ep1_ch1()->writable() && + ep2_ch1()->receiving() && ep2_ch1()->writable(), + 1000); + EXPECT_TRUE( + ep1_ch1()->best_connection() && ep2_ch1()->best_connection() && + LocalCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[0]) && + RemoteCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[1])); + + // Blackhole any traffic to or from the public addrs. + LOG(LS_INFO) << "Failing over..."; + fw()->AddRule(false, rtc::FP_ANY, rtc::FD_ANY, kPublicAddrs[1]); + // The best connections will switch, so keep references to them. + const cricket::Connection* best_connection1 = ep1_ch1()->best_connection(); + const cricket::Connection* best_connection2 = ep2_ch1()->best_connection(); + // We should detect loss of receiving within 1 second or so. + EXPECT_TRUE_WAIT( + !best_connection1->receiving() && !best_connection2->receiving(), 3000); + + // We should switch over to use the alternate addr immediately on both sides + // when we are not receiving. + EXPECT_TRUE_WAIT( + ep1_ch1()->best_connection()->receiving() && + ep2_ch1()->best_connection()->receiving(), 1000); + EXPECT_TRUE(LocalCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[0])); + EXPECT_TRUE( + RemoteCandidate(ep1_ch1())->address().EqualIPs(kAlternateAddrs[1])); + EXPECT_TRUE( + LocalCandidate(ep2_ch1())->address().EqualIPs(kAlternateAddrs[1])); + + DestroyChannels(); +} + +// Test that we can quickly switch links if an interface goes down. +// The controlling side has two interfaces and one will die. +TEST_F(P2PTransportChannelMultihomedTest, TestFailoverControllingSide) { + // Adding alternate address will make sure |kPublicAddrs| has the higher + // priority than others. This is due to FakeNetwork::AddInterface method. + AddAddress(0, kAlternateAddrs[0]); + AddAddress(0, kPublicAddrs[0]); + AddAddress(1, kPublicAddrs[1]); + + // Use only local ports for simplicity. + SetAllocatorFlags(0, kOnlyLocalPorts); + SetAllocatorFlags(1, kOnlyLocalPorts); + + // Create channels and let them go writable, as usual. + CreateChannels(1); + // Make the receiving timeout shorter for testing. + cricket::IceConfig config = CreateIceConfig(1000, false); + ep1_ch1()->SetIceConfig(config); + ep2_ch1()->SetIceConfig(config); + EXPECT_TRUE_WAIT(ep1_ch1()->receiving() && ep1_ch1()->writable() && + ep2_ch1()->receiving() && ep2_ch1()->writable(), + 1000); + EXPECT_TRUE( + ep1_ch1()->best_connection() && ep2_ch1()->best_connection() && + LocalCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[0]) && + RemoteCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[1])); + + // Blackhole any traffic to or from the public addrs. + LOG(LS_INFO) << "Failing over..."; + fw()->AddRule(false, rtc::FP_ANY, rtc::FD_ANY, kPublicAddrs[0]); + // The best connections will switch, so keep references to them. + const cricket::Connection* best_connection1 = ep1_ch1()->best_connection(); + const cricket::Connection* best_connection2 = ep2_ch1()->best_connection(); + // We should detect loss of receiving within 1 second or so. + EXPECT_TRUE_WAIT( + !best_connection1->receiving() && !best_connection2->receiving(), 3000); + + // We should switch over to use the alternate addr immediately on both sides + // when we are not receiving. + EXPECT_TRUE_WAIT( + ep1_ch1()->best_connection()->receiving() && + ep2_ch1()->best_connection()->receiving(), 1000); + EXPECT_TRUE( + LocalCandidate(ep1_ch1())->address().EqualIPs(kAlternateAddrs[0])); + EXPECT_TRUE(RemoteCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[1])); + EXPECT_TRUE( + RemoteCandidate(ep2_ch1())->address().EqualIPs(kAlternateAddrs[0])); + + DestroyChannels(); +} + +TEST_F(P2PTransportChannelMultihomedTest, TestGetState) { + AddAddress(0, kAlternateAddrs[0]); + AddAddress(0, kPublicAddrs[0]); + AddAddress(1, kPublicAddrs[1]); + // Create channels and let them go writable, as usual. + CreateChannels(1); + + // Both transport channels will reach STATE_COMPLETED quickly. + EXPECT_EQ_WAIT(cricket::TransportChannelState::STATE_COMPLETED, + ep1_ch1()->GetState(), 1000); + EXPECT_EQ_WAIT(cricket::TransportChannelState::STATE_COMPLETED, + ep2_ch1()->GetState(), 1000); +} + +/* + +TODO(pthatcher): Once have a way to handle network interfaces changes +without signalling an ICE restart, put a test like this back. In the +mean time, this test only worked for GICE. With ICE, it's currently +not possible without an ICE restart. + +// Test that we can switch links in a coordinated fashion. +TEST_F(P2PTransportChannelMultihomedTest, TestDrain) { + AddAddress(0, kPublicAddrs[0]); + AddAddress(1, kPublicAddrs[1]); + // Use only local ports for simplicity. + SetAllocatorFlags(0, kOnlyLocalPorts); + SetAllocatorFlags(1, kOnlyLocalPorts); + + // Create channels and let them go writable, as usual. + CreateChannels(1); + EXPECT_TRUE_WAIT(ep1_ch1()->receiving() && ep1_ch1()->writable() && + ep2_ch1()->receiving() && ep2_ch1()->writable(), + 1000); + EXPECT_TRUE( + ep1_ch1()->best_connection() && ep2_ch1()->best_connection() && + LocalCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[0]) && + RemoteCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[1])); + + + // Remove the public interface, add the alternate interface, and allocate + // a new generation of candidates for the new interface (via + // MaybeStartGathering()). + LOG(LS_INFO) << "Draining..."; + AddAddress(1, kAlternateAddrs[1]); + RemoveAddress(1, kPublicAddrs[1]); + ep2_ch1()->MaybeStartGathering(); + + // We should switch over to use the alternate address after + // an exchange of pings. + EXPECT_TRUE_WAIT( + ep1_ch1()->best_connection() && ep2_ch1()->best_connection() && + LocalCandidate(ep1_ch1())->address().EqualIPs(kPublicAddrs[0]) && + RemoteCandidate(ep1_ch1())->address().EqualIPs(kAlternateAddrs[1]), + 3000); + + DestroyChannels(); +} + +*/ + +// A collection of tests which tests a single P2PTransportChannel by sending +// pings. +class P2PTransportChannelPingTest : public testing::Test, + public sigslot::has_slots<> { + public: + P2PTransportChannelPingTest() + : pss_(new rtc::PhysicalSocketServer), + vss_(new rtc::VirtualSocketServer(pss_.get())), + ss_scope_(vss_.get()) {} + + protected: + void PrepareChannel(cricket::P2PTransportChannel* ch) { + ch->SetIceRole(cricket::ICEROLE_CONTROLLING); + ch->SetIceCredentials(kIceUfrag[0], kIcePwd[0]); + ch->SetRemoteIceCredentials(kIceUfrag[1], kIcePwd[1]); + } + + cricket::Candidate CreateCandidate(const std::string& ip, + int port, + int priority) { + cricket::Candidate c; + c.set_address(rtc::SocketAddress(ip, port)); + c.set_component(1); + c.set_protocol(cricket::UDP_PROTOCOL_NAME); + c.set_priority(priority); + return c; + } + + cricket::Connection* WaitForConnectionTo(cricket::P2PTransportChannel* ch, + const std::string& ip, + int port_num) { + EXPECT_TRUE_WAIT(GetConnectionTo(ch, ip, port_num) != nullptr, 3000); + return GetConnectionTo(ch, ip, port_num); + } + + cricket::Port* GetPort(cricket::P2PTransportChannel* ch) { + if (ch->ports().empty()) { + return nullptr; + } + return static_cast<cricket::Port*>(ch->ports()[0]); + } + + cricket::Connection* GetConnectionTo(cricket::P2PTransportChannel* ch, + const std::string& ip, + int port_num) { + cricket::Port* port = GetPort(ch); + if (!port) { + return nullptr; + } + return port->GetConnection(rtc::SocketAddress(ip, port_num)); + } + + private: + rtc::scoped_ptr<rtc::PhysicalSocketServer> pss_; + rtc::scoped_ptr<rtc::VirtualSocketServer> vss_; + rtc::SocketServerScope ss_scope_; +}; + +TEST_F(P2PTransportChannelPingTest, TestTriggeredChecks) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("trigger checks", 1, nullptr, &pa); + PrepareChannel(&ch); + ch.Connect(); + ch.MaybeStartGathering(); + ch.AddRemoteCandidate(CreateCandidate("1.1.1.1", 1, 1)); + ch.AddRemoteCandidate(CreateCandidate("2.2.2.2", 2, 2)); + + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + cricket::Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn1 != nullptr); + ASSERT_TRUE(conn2 != nullptr); + + // Before a triggered check, the first connection to ping is the + // highest priority one. + EXPECT_EQ(conn2, ch.FindNextPingableConnection()); + + // Receiving a ping causes a triggered check which should make conn1 + // be pinged first instead of conn2, even though conn2 has a higher + // priority. + conn1->ReceivedPing(); + EXPECT_EQ(conn1, ch.FindNextPingableConnection()); +} + +TEST_F(P2PTransportChannelPingTest, TestNoTriggeredChecksWhenWritable) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("trigger checks", 1, nullptr, &pa); + PrepareChannel(&ch); + ch.Connect(); + ch.MaybeStartGathering(); + ch.AddRemoteCandidate(CreateCandidate("1.1.1.1", 1, 1)); + ch.AddRemoteCandidate(CreateCandidate("2.2.2.2", 2, 2)); + + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + cricket::Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn1 != nullptr); + ASSERT_TRUE(conn2 != nullptr); + + EXPECT_EQ(conn2, ch.FindNextPingableConnection()); + conn1->ReceivedPingResponse(); + ASSERT_TRUE(conn1->writable()); + conn1->ReceivedPing(); + + // Ping received, but the connection is already writable, so no + // "triggered check" and conn2 is pinged before conn1 because it has + // a higher priority. + EXPECT_EQ(conn2, ch.FindNextPingableConnection()); +} + +TEST_F(P2PTransportChannelPingTest, ConnectionResurrection) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("connection resurrection", 1, nullptr, &pa); + PrepareChannel(&ch); + ch.Connect(); + ch.MaybeStartGathering(); + + // Create conn1 and keep track of original candidate priority. + ch.AddRemoteCandidate(CreateCandidate("1.1.1.1", 1, 1)); + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + ASSERT_TRUE(conn1 != nullptr); + uint32_t remote_priority = conn1->remote_candidate().priority(); + + // Create a higher priority candidate and make the connection + // receiving/writable. This will prune conn1. + ch.AddRemoteCandidate(CreateCandidate("2.2.2.2", 2, 2)); + cricket::Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn2 != nullptr); + conn2->ReceivedPing(); + conn2->ReceivedPingResponse(); + + // Wait for conn1 to be pruned. + EXPECT_TRUE_WAIT(conn1->pruned(), 3000); + // Destroy the connection to test SignalUnknownAddress. + conn1->Destroy(); + EXPECT_TRUE_WAIT(GetConnectionTo(&ch, "1.1.1.1", 1) == nullptr, 1000); + + // Create a minimal STUN message with prflx priority. + cricket::IceMessage request; + request.SetType(cricket::STUN_BINDING_REQUEST); + request.AddAttribute(new cricket::StunByteStringAttribute( + cricket::STUN_ATTR_USERNAME, kIceUfrag[1])); + uint32_t prflx_priority = cricket::ICE_TYPE_PREFERENCE_PRFLX << 24; + request.AddAttribute(new cricket::StunUInt32Attribute( + cricket::STUN_ATTR_PRIORITY, prflx_priority)); + EXPECT_NE(prflx_priority, remote_priority); + + cricket::Port* port = GetPort(&ch); + // conn1 should be resurrected with original priority. + port->SignalUnknownAddress(port, rtc::SocketAddress("1.1.1.1", 1), + cricket::PROTO_UDP, &request, kIceUfrag[1], false); + conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + ASSERT_TRUE(conn1 != nullptr); + EXPECT_EQ(conn1->remote_candidate().priority(), remote_priority); + + // conn3, a real prflx connection, should have prflx priority. + port->SignalUnknownAddress(port, rtc::SocketAddress("3.3.3.3", 1), + cricket::PROTO_UDP, &request, kIceUfrag[1], false); + cricket::Connection* conn3 = WaitForConnectionTo(&ch, "3.3.3.3", 1); + ASSERT_TRUE(conn3 != nullptr); + EXPECT_EQ(conn3->remote_candidate().priority(), prflx_priority); +} + +TEST_F(P2PTransportChannelPingTest, TestReceivingStateChange) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("receiving state change", 1, nullptr, &pa); + PrepareChannel(&ch); + // Default receiving timeout and checking receiving delay should not be too + // small. + EXPECT_LE(1000, ch.receiving_timeout()); + EXPECT_LE(200, ch.check_receiving_delay()); + ch.SetIceConfig(CreateIceConfig(500, false)); + EXPECT_EQ(500, ch.receiving_timeout()); + EXPECT_EQ(50, ch.check_receiving_delay()); + ch.Connect(); + ch.MaybeStartGathering(); + ch.AddRemoteCandidate(CreateCandidate("1.1.1.1", 1, 1)); + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + ASSERT_TRUE(conn1 != nullptr); + + conn1->ReceivedPing(); + conn1->OnReadPacket("ABC", 3, rtc::CreatePacketTime(0)); + EXPECT_TRUE_WAIT(ch.best_connection() != nullptr, 1000) + EXPECT_TRUE_WAIT(ch.receiving(), 1000); + EXPECT_TRUE_WAIT(!ch.receiving(), 1000); +} + +// The controlled side will select a connection as the "best connection" based +// on priority until the controlling side nominates a connection, at which +// point the controlled side will select that connection as the +// "best connection". +TEST_F(P2PTransportChannelPingTest, TestSelectConnectionBeforeNomination) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("receiving state change", 1, nullptr, &pa); + PrepareChannel(&ch); + ch.SetIceRole(cricket::ICEROLE_CONTROLLED); + ch.Connect(); + ch.MaybeStartGathering(); + ch.AddRemoteCandidate(CreateCandidate("1.1.1.1", 1, 1)); + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + ASSERT_TRUE(conn1 != nullptr); + EXPECT_EQ(conn1, ch.best_connection()); + + // When a higher priority candidate comes in, the new connection is chosen + // as the best connection. + ch.AddRemoteCandidate(CreateCandidate("2.2.2.2", 2, 10)); + cricket::Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn2 != nullptr); + EXPECT_EQ(conn2, ch.best_connection()); + + // If a stun request with use-candidate attribute arrives, the receiving + // connection will be set as the best connection, even though + // its priority is lower. + ch.AddRemoteCandidate(CreateCandidate("3.3.3.3", 3, 1)); + cricket::Connection* conn3 = WaitForConnectionTo(&ch, "3.3.3.3", 3); + ASSERT_TRUE(conn3 != nullptr); + // Because it has a lower priority, the best connection is still conn2. + EXPECT_EQ(conn2, ch.best_connection()); + conn3->ReceivedPingResponse(); // Become writable. + // But if it is nominated via use_candidate, it is chosen as the best + // connection. + conn3->set_nominated(true); + conn3->SignalNominated(conn3); + EXPECT_EQ(conn3, ch.best_connection()); + + // Even if another higher priority candidate arrives, + // it will not be set as the best connection because the best connection + // is nominated by the controlling side. + ch.AddRemoteCandidate(CreateCandidate("4.4.4.4", 4, 100)); + cricket::Connection* conn4 = WaitForConnectionTo(&ch, "4.4.4.4", 4); + ASSERT_TRUE(conn4 != nullptr); + EXPECT_EQ(conn3, ch.best_connection()); + // But if it is nominated via use_candidate and writable, it will be set as + // the best connection. + conn4->set_nominated(true); + conn4->SignalNominated(conn4); + // Not switched yet because conn4 is not writable. + EXPECT_EQ(conn3, ch.best_connection()); + // The best connection switches after conn4 becomes writable. + conn4->ReceivedPingResponse(); + EXPECT_EQ(conn4, ch.best_connection()); +} + +// The controlled side will select a connection as the "best connection" based +// on requests from an unknown address before the controlling side nominates +// a connection, and will nominate a connection from an unknown address if the +// request contains the use_candidate attribute. +TEST_F(P2PTransportChannelPingTest, TestSelectConnectionFromUnknownAddress) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("receiving state change", 1, nullptr, &pa); + PrepareChannel(&ch); + ch.SetIceRole(cricket::ICEROLE_CONTROLLED); + ch.Connect(); + ch.MaybeStartGathering(); + // A minimal STUN message with prflx priority. + cricket::IceMessage request; + request.SetType(cricket::STUN_BINDING_REQUEST); + request.AddAttribute(new cricket::StunByteStringAttribute( + cricket::STUN_ATTR_USERNAME, kIceUfrag[1])); + uint32_t prflx_priority = cricket::ICE_TYPE_PREFERENCE_PRFLX << 24; + request.AddAttribute(new cricket::StunUInt32Attribute( + cricket::STUN_ATTR_PRIORITY, prflx_priority)); + cricket::Port* port = GetPort(&ch); + port->SignalUnknownAddress(port, rtc::SocketAddress("1.1.1.1", 1), + cricket::PROTO_UDP, &request, kIceUfrag[1], false); + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + ASSERT_TRUE(conn1 != nullptr); + EXPECT_EQ(conn1, ch.best_connection()); + conn1->ReceivedPingResponse(); + EXPECT_EQ(conn1, ch.best_connection()); + + // Another connection is nominated via use_candidate. + ch.AddRemoteCandidate(CreateCandidate("2.2.2.2", 2, 1)); + cricket::Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn2 != nullptr); + // Because it has a lower priority, the best connection is still conn1. + EXPECT_EQ(conn1, ch.best_connection()); + // When it is nominated via use_candidate and writable, it is chosen as the + // best connection. + conn2->ReceivedPingResponse(); // Become writable. + conn2->set_nominated(true); + conn2->SignalNominated(conn2); + EXPECT_EQ(conn2, ch.best_connection()); + + // Another request with unknown address, it will not be set as the best + // connection because the best connection was nominated by the controlling + // side. + port->SignalUnknownAddress(port, rtc::SocketAddress("3.3.3.3", 3), + cricket::PROTO_UDP, &request, kIceUfrag[1], false); + cricket::Connection* conn3 = WaitForConnectionTo(&ch, "3.3.3.3", 3); + ASSERT_TRUE(conn3 != nullptr); + conn3->ReceivedPingResponse(); // Become writable. + EXPECT_EQ(conn2, ch.best_connection()); + + // However if the request contains use_candidate attribute, it will be + // selected as the best connection. + request.AddAttribute( + new cricket::StunByteStringAttribute(cricket::STUN_ATTR_USE_CANDIDATE)); + port->SignalUnknownAddress(port, rtc::SocketAddress("4.4.4.4", 4), + cricket::PROTO_UDP, &request, kIceUfrag[1], false); + cricket::Connection* conn4 = WaitForConnectionTo(&ch, "4.4.4.4", 4); + ASSERT_TRUE(conn4 != nullptr); + // conn4 is not the best connection yet because it is not writable. + EXPECT_EQ(conn2, ch.best_connection()); + conn4->ReceivedPingResponse(); // Become writable. + EXPECT_EQ(conn4, ch.best_connection()); +} + +// The controlled side will select a connection as the "best connection" +// based on media received until the controlling side nominates a connection, +// at which point the controlled side will select that connection as +// the "best connection". +TEST_F(P2PTransportChannelPingTest, TestSelectConnectionBasedOnMediaReceived) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("receiving state change", 1, nullptr, &pa); + PrepareChannel(&ch); + ch.SetIceRole(cricket::ICEROLE_CONTROLLED); + ch.Connect(); + ch.MaybeStartGathering(); + ch.AddRemoteCandidate(CreateCandidate("1.1.1.1", 1, 10)); + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + ASSERT_TRUE(conn1 != nullptr); + EXPECT_EQ(conn1, ch.best_connection()); + + // If a data packet is received on conn2, the best connection should + // switch to conn2 because the controlled side must mirror the media path + // chosen by the controlling side. + ch.AddRemoteCandidate(CreateCandidate("2.2.2.2", 2, 1)); + cricket::Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn2 != nullptr); + conn2->ReceivedPing(); // Start receiving. + // Do not switch because it is not writable. + conn2->OnReadPacket("ABC", 3, rtc::CreatePacketTime(0)); + EXPECT_EQ(conn1, ch.best_connection()); + + conn2->ReceivedPingResponse(); // Become writable. + // Switch because it is writable. + conn2->OnReadPacket("DEF", 3, rtc::CreatePacketTime(0)); + EXPECT_EQ(conn2, ch.best_connection()); + + // Now another STUN message with an unknown address and use_candidate will + // nominate the best connection. + cricket::IceMessage request; + request.SetType(cricket::STUN_BINDING_REQUEST); + request.AddAttribute(new cricket::StunByteStringAttribute( + cricket::STUN_ATTR_USERNAME, kIceUfrag[1])); + uint32_t prflx_priority = cricket::ICE_TYPE_PREFERENCE_PRFLX << 24; + request.AddAttribute(new cricket::StunUInt32Attribute( + cricket::STUN_ATTR_PRIORITY, prflx_priority)); + request.AddAttribute( + new cricket::StunByteStringAttribute(cricket::STUN_ATTR_USE_CANDIDATE)); + cricket::Port* port = GetPort(&ch); + port->SignalUnknownAddress(port, rtc::SocketAddress("3.3.3.3", 3), + cricket::PROTO_UDP, &request, kIceUfrag[1], false); + cricket::Connection* conn3 = WaitForConnectionTo(&ch, "3.3.3.3", 3); + ASSERT_TRUE(conn3 != nullptr); + EXPECT_EQ(conn2, ch.best_connection()); // Not writable yet. + conn3->ReceivedPingResponse(); // Become writable. + EXPECT_EQ(conn3, ch.best_connection()); + + // Now another data packet will not switch the best connection because the + // best connection was nominated by the controlling side. + conn2->ReceivedPing(); + conn2->ReceivedPingResponse(); + conn2->OnReadPacket("XYZ", 3, rtc::CreatePacketTime(0)); + EXPECT_EQ(conn3, ch.best_connection()); +} + +// When the current best connection is strong, lower-priority connections will +// be pruned. Otherwise, lower-priority connections are kept. +TEST_F(P2PTransportChannelPingTest, TestDontPruneWhenWeak) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("test channel", 1, nullptr, &pa); + PrepareChannel(&ch); + ch.SetIceRole(cricket::ICEROLE_CONTROLLED); + ch.Connect(); + ch.MaybeStartGathering(); + ch.AddRemoteCandidate(CreateCandidate("1.1.1.1", 1, 1)); + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + ASSERT_TRUE(conn1 != nullptr); + EXPECT_EQ(conn1, ch.best_connection()); + conn1->ReceivedPingResponse(); // Becomes writable and receiving + + // When a higher-priority, nominated candidate comes in, the connections with + // lower-priority are pruned. + ch.AddRemoteCandidate(CreateCandidate("2.2.2.2", 2, 10)); + cricket::Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn2 != nullptr); + conn2->ReceivedPingResponse(); // Becomes writable and receiving + conn2->set_nominated(true); + conn2->SignalNominated(conn2); + EXPECT_TRUE_WAIT(conn1->pruned(), 3000); + + ch.SetIceConfig(CreateIceConfig(500, false)); + // Wait until conn2 becomes not receiving. + EXPECT_TRUE_WAIT(!conn2->receiving(), 3000); + + ch.AddRemoteCandidate(CreateCandidate("3.3.3.3", 3, 1)); + cricket::Connection* conn3 = WaitForConnectionTo(&ch, "3.3.3.3", 3); + ASSERT_TRUE(conn3 != nullptr); + // The best connection should still be conn2. Even through conn3 has lower + // priority and is not receiving/writable, it is not pruned because the best + // connection is not receiving. + WAIT(conn3->pruned(), 1000); + EXPECT_FALSE(conn3->pruned()); +} + +// Test that GetState returns the state correctly. +TEST_F(P2PTransportChannelPingTest, TestGetState) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("test channel", 1, nullptr, &pa); + PrepareChannel(&ch); + ch.Connect(); + ch.MaybeStartGathering(); + EXPECT_EQ(cricket::TransportChannelState::STATE_INIT, ch.GetState()); + ch.AddRemoteCandidate(CreateCandidate("1.1.1.1", 1, 100)); + ch.AddRemoteCandidate(CreateCandidate("2.2.2.2", 2, 1)); + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + cricket::Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn1 != nullptr); + ASSERT_TRUE(conn2 != nullptr); + // Now there are two connections, so the transport channel is connecting. + EXPECT_EQ(cricket::TransportChannelState::STATE_CONNECTING, ch.GetState()); + // |conn1| becomes writable and receiving; it then should prune |conn2|. + conn1->ReceivedPingResponse(); + EXPECT_TRUE_WAIT(conn2->pruned(), 1000); + EXPECT_EQ(cricket::TransportChannelState::STATE_COMPLETED, ch.GetState()); + conn1->Prune(); // All connections are pruned. + EXPECT_EQ(cricket::TransportChannelState::STATE_FAILED, ch.GetState()); +} + +// Test that when a low-priority connection is pruned, it is not deleted +// right away, and it can become active and be pruned again. +TEST_F(P2PTransportChannelPingTest, TestConnectionPrunedAgain) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("test channel", 1, nullptr, &pa); + PrepareChannel(&ch); + ch.SetIceConfig(CreateIceConfig(1000, false)); + ch.Connect(); + ch.MaybeStartGathering(); + ch.AddRemoteCandidate(CreateCandidate("1.1.1.1", 1, 100)); + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + ASSERT_TRUE(conn1 != nullptr); + EXPECT_EQ(conn1, ch.best_connection()); + conn1->ReceivedPingResponse(); // Becomes writable and receiving + + // Add a low-priority connection |conn2|, which will be pruned, but it will + // not be deleted right away. Once the current best connection becomes not + // receiving, |conn2| will start to ping and upon receiving the ping response, + // it will become the best connection. + ch.AddRemoteCandidate(CreateCandidate("2.2.2.2", 2, 1)); + cricket::Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn2 != nullptr); + EXPECT_TRUE_WAIT(!conn2->active(), 1000); + // |conn2| should not send a ping yet. + EXPECT_EQ(cricket::Connection::STATE_WAITING, conn2->state()); + EXPECT_EQ(cricket::TransportChannelState::STATE_COMPLETED, ch.GetState()); + // Wait for |conn1| becoming not receiving. + EXPECT_TRUE_WAIT(!conn1->receiving(), 3000); + // Make sure conn2 is not deleted. + conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn2 != nullptr); + EXPECT_EQ_WAIT(cricket::Connection::STATE_INPROGRESS, conn2->state(), 1000); + conn2->ReceivedPingResponse(); + EXPECT_EQ_WAIT(conn2, ch.best_connection(), 1000); + EXPECT_EQ(cricket::TransportChannelState::STATE_CONNECTING, ch.GetState()); + + // When |conn1| comes back again, |conn2| will be pruned again. + conn1->ReceivedPingResponse(); + EXPECT_EQ_WAIT(conn1, ch.best_connection(), 1000); + EXPECT_TRUE_WAIT(!conn2->active(), 1000); + EXPECT_EQ(cricket::TransportChannelState::STATE_COMPLETED, ch.GetState()); +} + +// Test that if all connections in a channel has timed out on writing, they +// will all be deleted. We use Prune to simulate write_time_out. +TEST_F(P2PTransportChannelPingTest, TestDeleteConnectionsIfAllWriteTimedout) { + cricket::FakePortAllocator pa(rtc::Thread::Current(), nullptr); + cricket::P2PTransportChannel ch("test channel", 1, nullptr, &pa); + PrepareChannel(&ch); + ch.Connect(); + ch.MaybeStartGathering(); + // Have one connection only but later becomes write-time-out. + ch.AddRemoteCandidate(CreateCandidate("1.1.1.1", 1, 100)); + cricket::Connection* conn1 = WaitForConnectionTo(&ch, "1.1.1.1", 1); + ASSERT_TRUE(conn1 != nullptr); + conn1->ReceivedPing(); // Becomes receiving + conn1->Prune(); + EXPECT_TRUE_WAIT(ch.connections().empty(), 1000); + + // Have two connections but both become write-time-out later. + ch.AddRemoteCandidate(CreateCandidate("2.2.2.2", 2, 1)); + cricket::Connection* conn2 = WaitForConnectionTo(&ch, "2.2.2.2", 2); + ASSERT_TRUE(conn2 != nullptr); + conn2->ReceivedPing(); // Becomes receiving + ch.AddRemoteCandidate(CreateCandidate("3.3.3.3", 3, 2)); + cricket::Connection* conn3 = WaitForConnectionTo(&ch, "3.3.3.3", 3); + ASSERT_TRUE(conn3 != nullptr); + conn3->ReceivedPing(); // Becomes receiving + // Now prune both conn2 and conn3; they will be deleted soon. + conn2->Prune(); + conn3->Prune(); + EXPECT_TRUE_WAIT(ch.connections().empty(), 1000); +} diff --git a/webrtc/p2p/base/packetsocketfactory.h b/webrtc/p2p/base/packetsocketfactory.h new file mode 100644 index 0000000000..54037241b0 --- /dev/null +++ b/webrtc/p2p/base/packetsocketfactory.h @@ -0,0 +1,58 @@ +/* + * Copyright 2011 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_PACKETSOCKETFACTORY_H_ +#define WEBRTC_P2P_BASE_PACKETSOCKETFACTORY_H_ + +#include "webrtc/base/proxyinfo.h" + +namespace rtc { + +class AsyncPacketSocket; +class AsyncResolverInterface; + +class PacketSocketFactory { + public: + enum Options { + OPT_SSLTCP = 0x01, // Pseudo-TLS. + OPT_TLS = 0x02, + OPT_STUN = 0x04, + }; + + PacketSocketFactory() { } + virtual ~PacketSocketFactory() { } + + virtual AsyncPacketSocket* CreateUdpSocket(const SocketAddress& address, + uint16_t min_port, + uint16_t max_port) = 0; + virtual AsyncPacketSocket* CreateServerTcpSocket( + const SocketAddress& local_address, + uint16_t min_port, + uint16_t max_port, + int opts) = 0; + + // TODO: |proxy_info| and |user_agent| should be set + // per-factory and not when socket is created. + virtual AsyncPacketSocket* CreateClientTcpSocket( + const SocketAddress& local_address, + const SocketAddress& remote_address, + const ProxyInfo& proxy_info, + const std::string& user_agent, + int opts) = 0; + + virtual AsyncResolverInterface* CreateAsyncResolver() = 0; + + private: + RTC_DISALLOW_COPY_AND_ASSIGN(PacketSocketFactory); +}; + +} // namespace rtc + +#endif // WEBRTC_P2P_BASE_PACKETSOCKETFACTORY_H_ diff --git a/webrtc/p2p/base/port.cc b/webrtc/p2p/base/port.cc new file mode 100644 index 0000000000..d34b05f8e9 --- /dev/null +++ b/webrtc/p2p/base/port.cc @@ -0,0 +1,1423 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/port.h" + +#include <algorithm> +#include <vector> + +#include "webrtc/p2p/base/common.h" +#include "webrtc/p2p/base/portallocator.h" +#include "webrtc/base/base64.h" +#include "webrtc/base/crc32.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/messagedigest.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/stringencode.h" +#include "webrtc/base/stringutils.h" + +namespace { + +// Determines whether we have seen at least the given maximum number of +// pings fail to have a response. +inline bool TooManyFailures( + const std::vector<cricket::Connection::SentPing>& pings_since_last_response, + uint32_t maximum_failures, + uint32_t rtt_estimate, + uint32_t now) { + // If we haven't sent that many pings, then we can't have failed that many. + if (pings_since_last_response.size() < maximum_failures) + return false; + + // Check if the window in which we would expect a response to the ping has + // already elapsed. + uint32_t expected_response_time = + pings_since_last_response[maximum_failures - 1].sent_time + rtt_estimate; + return now > expected_response_time; +} + +// Determines whether we have gone too long without seeing any response. +inline bool TooLongWithoutResponse( + const std::vector<cricket::Connection::SentPing>& pings_since_last_response, + uint32_t maximum_time, + uint32_t now) { + if (pings_since_last_response.size() == 0) + return false; + + auto first = pings_since_last_response[0]; + return now > (first.sent_time + maximum_time); +} + +// We will restrict RTT estimates (when used for determining state) to be +// within a reasonable range. +const uint32_t MINIMUM_RTT = 100; // 0.1 seconds +const uint32_t MAXIMUM_RTT = 3000; // 3 seconds + +// When we don't have any RTT data, we have to pick something reasonable. We +// use a large value just in case the connection is really slow. +const uint32_t DEFAULT_RTT = MAXIMUM_RTT; + +// Computes our estimate of the RTT given the current estimate. +inline uint32_t ConservativeRTTEstimate(uint32_t rtt) { + return std::max(MINIMUM_RTT, std::min(MAXIMUM_RTT, 2 * rtt)); +} + +// Weighting of the old rtt value to new data. +const int RTT_RATIO = 3; // 3 : 1 + +// The delay before we begin checking if this port is useless. +const int kPortTimeoutDelay = 30 * 1000; // 30 seconds +} + +namespace cricket { + +// TODO(ronghuawu): Use "host", "srflx", "prflx" and "relay". But this requires +// the signaling part be updated correspondingly as well. +const char LOCAL_PORT_TYPE[] = "local"; +const char STUN_PORT_TYPE[] = "stun"; +const char PRFLX_PORT_TYPE[] = "prflx"; +const char RELAY_PORT_TYPE[] = "relay"; + +const char UDP_PROTOCOL_NAME[] = "udp"; +const char TCP_PROTOCOL_NAME[] = "tcp"; +const char SSLTCP_PROTOCOL_NAME[] = "ssltcp"; + +static const char* const PROTO_NAMES[] = { UDP_PROTOCOL_NAME, + TCP_PROTOCOL_NAME, + SSLTCP_PROTOCOL_NAME }; + +const char* ProtoToString(ProtocolType proto) { + return PROTO_NAMES[proto]; +} + +bool StringToProto(const char* value, ProtocolType* proto) { + for (size_t i = 0; i <= PROTO_LAST; ++i) { + if (_stricmp(PROTO_NAMES[i], value) == 0) { + *proto = static_cast<ProtocolType>(i); + return true; + } + } + return false; +} + +// RFC 6544, TCP candidate encoding rules. +const int DISCARD_PORT = 9; +const char TCPTYPE_ACTIVE_STR[] = "active"; +const char TCPTYPE_PASSIVE_STR[] = "passive"; +const char TCPTYPE_SIMOPEN_STR[] = "so"; + +// Foundation: An arbitrary string that is the same for two candidates +// that have the same type, base IP address, protocol (UDP, TCP, +// etc.), and STUN or TURN server. If any of these are different, +// then the foundation will be different. Two candidate pairs with +// the same foundation pairs are likely to have similar network +// characteristics. Foundations are used in the frozen algorithm. +static std::string ComputeFoundation( + const std::string& type, + const std::string& protocol, + const rtc::SocketAddress& base_address) { + std::ostringstream ost; + ost << type << base_address.ipaddr().ToString() << protocol; + return rtc::ToString<uint32_t>(rtc::ComputeCrc32(ost.str())); +} + +Port::Port(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + const std::string& username_fragment, + const std::string& password) + : thread_(thread), + factory_(factory), + send_retransmit_count_attribute_(false), + network_(network), + ip_(ip), + min_port_(0), + max_port_(0), + component_(ICE_CANDIDATE_COMPONENT_DEFAULT), + generation_(0), + ice_username_fragment_(username_fragment), + password_(password), + timeout_delay_(kPortTimeoutDelay), + enable_port_packets_(false), + ice_role_(ICEROLE_UNKNOWN), + tiebreaker_(0), + shared_socket_(true), + candidate_filter_(CF_ALL) { + Construct(); +} + +Port::Port(rtc::Thread* thread, + const std::string& type, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username_fragment, + const std::string& password) + : thread_(thread), + factory_(factory), + type_(type), + send_retransmit_count_attribute_(false), + network_(network), + ip_(ip), + min_port_(min_port), + max_port_(max_port), + component_(ICE_CANDIDATE_COMPONENT_DEFAULT), + generation_(0), + ice_username_fragment_(username_fragment), + password_(password), + timeout_delay_(kPortTimeoutDelay), + enable_port_packets_(false), + ice_role_(ICEROLE_UNKNOWN), + tiebreaker_(0), + shared_socket_(false), + candidate_filter_(CF_ALL) { + ASSERT(factory_ != NULL); + Construct(); +} + +void Port::Construct() { + // TODO(pthatcher): Remove this old behavior once we're sure no one + // relies on it. If the username_fragment and password are empty, + // we should just create one. + if (ice_username_fragment_.empty()) { + ASSERT(password_.empty()); + ice_username_fragment_ = rtc::CreateRandomString(ICE_UFRAG_LENGTH); + password_ = rtc::CreateRandomString(ICE_PWD_LENGTH); + } + LOG_J(LS_INFO, this) << "Port created"; +} + +Port::~Port() { + // Delete all of the remaining connections. We copy the list up front + // because each deletion will cause it to be modified. + + std::vector<Connection*> list; + + AddressMap::iterator iter = connections_.begin(); + while (iter != connections_.end()) { + list.push_back(iter->second); + ++iter; + } + + for (uint32_t i = 0; i < list.size(); i++) + delete list[i]; +} + +Connection* Port::GetConnection(const rtc::SocketAddress& remote_addr) { + AddressMap::const_iterator iter = connections_.find(remote_addr); + if (iter != connections_.end()) + return iter->second; + else + return NULL; +} + +void Port::AddAddress(const rtc::SocketAddress& address, + const rtc::SocketAddress& base_address, + const rtc::SocketAddress& related_address, + const std::string& protocol, + const std::string& relay_protocol, + const std::string& tcptype, + const std::string& type, + uint32_t type_preference, + uint32_t relay_preference, + bool final) { + if (protocol == TCP_PROTOCOL_NAME && type == LOCAL_PORT_TYPE) { + ASSERT(!tcptype.empty()); + } + + Candidate c; + c.set_id(rtc::CreateRandomString(8)); + c.set_component(component_); + c.set_type(type); + c.set_protocol(protocol); + c.set_relay_protocol(relay_protocol); + c.set_tcptype(tcptype); + c.set_address(address); + c.set_priority(c.GetPriority(type_preference, network_->preference(), + relay_preference)); + c.set_username(username_fragment()); + c.set_password(password_); + c.set_network_name(network_->name()); + c.set_network_type(network_->type()); + c.set_generation(generation_); + c.set_related_address(related_address); + c.set_foundation(ComputeFoundation(type, protocol, base_address)); + candidates_.push_back(c); + SignalCandidateReady(this, c); + + if (final) { + SignalPortComplete(this); + } +} + +void Port::AddConnection(Connection* conn) { + connections_[conn->remote_candidate().address()] = conn; + conn->SignalDestroyed.connect(this, &Port::OnConnectionDestroyed); + SignalConnectionCreated(this, conn); +} + +void Port::OnReadPacket( + const char* data, size_t size, const rtc::SocketAddress& addr, + ProtocolType proto) { + // If the user has enabled port packets, just hand this over. + if (enable_port_packets_) { + SignalReadPacket(this, data, size, addr); + return; + } + + // If this is an authenticated STUN request, then signal unknown address and + // send back a proper binding response. + rtc::scoped_ptr<IceMessage> msg; + std::string remote_username; + if (!GetStunMessage(data, size, addr, msg.accept(), &remote_username)) { + LOG_J(LS_ERROR, this) << "Received non-STUN packet from unknown address (" + << addr.ToSensitiveString() << ")"; + } else if (!msg) { + // STUN message handled already + } else if (msg->type() == STUN_BINDING_REQUEST) { + LOG(LS_INFO) << "Received STUN ping " + << " id=" << rtc::hex_encode(msg->transaction_id()) + << " from unknown address " << addr.ToSensitiveString(); + + // Check for role conflicts. + if (!MaybeIceRoleConflict(addr, msg.get(), remote_username)) { + LOG(LS_INFO) << "Received conflicting role from the peer."; + return; + } + + SignalUnknownAddress(this, addr, proto, msg.get(), remote_username, false); + } else { + // NOTE(tschmelcher): STUN_BINDING_RESPONSE is benign. It occurs if we + // pruned a connection for this port while it had STUN requests in flight, + // because we then get back responses for them, which this code correctly + // does not handle. + if (msg->type() != STUN_BINDING_RESPONSE) { + LOG_J(LS_ERROR, this) << "Received unexpected STUN message type (" + << msg->type() << ") from unknown address (" + << addr.ToSensitiveString() << ")"; + } + } +} + +void Port::OnSentPacket(const rtc::SentPacket& sent_packet) { + PortInterface::SignalSentPacket(this, sent_packet); +} + +void Port::OnReadyToSend() { + AddressMap::iterator iter = connections_.begin(); + for (; iter != connections_.end(); ++iter) { + iter->second->OnReadyToSend(); + } +} + +size_t Port::AddPrflxCandidate(const Candidate& local) { + candidates_.push_back(local); + return (candidates_.size() - 1); +} + +bool Port::GetStunMessage(const char* data, size_t size, + const rtc::SocketAddress& addr, + IceMessage** out_msg, std::string* out_username) { + // NOTE: This could clearly be optimized to avoid allocating any memory. + // However, at the data rates we'll be looking at on the client side, + // this probably isn't worth worrying about. + ASSERT(out_msg != NULL); + ASSERT(out_username != NULL); + *out_msg = NULL; + out_username->clear(); + + // Don't bother parsing the packet if we can tell it's not STUN. + // In ICE mode, all STUN packets will have a valid fingerprint. + if (!StunMessage::ValidateFingerprint(data, size)) { + return false; + } + + // Parse the request message. If the packet is not a complete and correct + // STUN message, then ignore it. + rtc::scoped_ptr<IceMessage> stun_msg(new IceMessage()); + rtc::ByteBuffer buf(data, size); + if (!stun_msg->Read(&buf) || (buf.Length() > 0)) { + return false; + } + + if (stun_msg->type() == STUN_BINDING_REQUEST) { + // Check for the presence of USERNAME and MESSAGE-INTEGRITY (if ICE) first. + // If not present, fail with a 400 Bad Request. + if (!stun_msg->GetByteString(STUN_ATTR_USERNAME) || + !stun_msg->GetByteString(STUN_ATTR_MESSAGE_INTEGRITY)) { + LOG_J(LS_ERROR, this) << "Received STUN request without username/M-I " + << "from " << addr.ToSensitiveString(); + SendBindingErrorResponse(stun_msg.get(), addr, STUN_ERROR_BAD_REQUEST, + STUN_ERROR_REASON_BAD_REQUEST); + return true; + } + + // If the username is bad or unknown, fail with a 401 Unauthorized. + std::string local_ufrag; + std::string remote_ufrag; + if (!ParseStunUsername(stun_msg.get(), &local_ufrag, &remote_ufrag) || + local_ufrag != username_fragment()) { + LOG_J(LS_ERROR, this) << "Received STUN request with bad local username " + << local_ufrag << " from " + << addr.ToSensitiveString(); + SendBindingErrorResponse(stun_msg.get(), addr, STUN_ERROR_UNAUTHORIZED, + STUN_ERROR_REASON_UNAUTHORIZED); + return true; + } + + // If ICE, and the MESSAGE-INTEGRITY is bad, fail with a 401 Unauthorized + if (!stun_msg->ValidateMessageIntegrity(data, size, password_)) { + LOG_J(LS_ERROR, this) << "Received STUN request with bad M-I " + << "from " << addr.ToSensitiveString() + << ", password_=" << password_; + SendBindingErrorResponse(stun_msg.get(), addr, STUN_ERROR_UNAUTHORIZED, + STUN_ERROR_REASON_UNAUTHORIZED); + return true; + } + out_username->assign(remote_ufrag); + } else if ((stun_msg->type() == STUN_BINDING_RESPONSE) || + (stun_msg->type() == STUN_BINDING_ERROR_RESPONSE)) { + if (stun_msg->type() == STUN_BINDING_ERROR_RESPONSE) { + if (const StunErrorCodeAttribute* error_code = stun_msg->GetErrorCode()) { + LOG_J(LS_ERROR, this) << "Received STUN binding error:" + << " class=" << error_code->eclass() + << " number=" << error_code->number() + << " reason='" << error_code->reason() << "'" + << " from " << addr.ToSensitiveString(); + // Return message to allow error-specific processing + } else { + LOG_J(LS_ERROR, this) << "Received STUN binding error without a error " + << "code from " << addr.ToSensitiveString(); + return true; + } + } + // NOTE: Username should not be used in verifying response messages. + out_username->clear(); + } else if (stun_msg->type() == STUN_BINDING_INDICATION) { + LOG_J(LS_VERBOSE, this) << "Received STUN binding indication:" + << " from " << addr.ToSensitiveString(); + out_username->clear(); + // No stun attributes will be verified, if it's stun indication message. + // Returning from end of the this method. + } else { + LOG_J(LS_ERROR, this) << "Received STUN packet with invalid type (" + << stun_msg->type() << ") from " + << addr.ToSensitiveString(); + return true; + } + + // Return the STUN message found. + *out_msg = stun_msg.release(); + return true; +} + +bool Port::IsCompatibleAddress(const rtc::SocketAddress& addr) { + int family = ip().family(); + // We use single-stack sockets, so families must match. + if (addr.family() != family) { + return false; + } + // Link-local IPv6 ports can only connect to other link-local IPv6 ports. + if (family == AF_INET6 && + (IPIsLinkLocal(ip()) != IPIsLinkLocal(addr.ipaddr()))) { + return false; + } + return true; +} + +bool Port::ParseStunUsername(const StunMessage* stun_msg, + std::string* local_ufrag, + std::string* remote_ufrag) const { + // The packet must include a username that either begins or ends with our + // fragment. It should begin with our fragment if it is a request and it + // should end with our fragment if it is a response. + local_ufrag->clear(); + remote_ufrag->clear(); + const StunByteStringAttribute* username_attr = + stun_msg->GetByteString(STUN_ATTR_USERNAME); + if (username_attr == NULL) + return false; + + // RFRAG:LFRAG + const std::string username = username_attr->GetString(); + size_t colon_pos = username.find(":"); + if (colon_pos == std::string::npos) { + return false; + } + + *local_ufrag = username.substr(0, colon_pos); + *remote_ufrag = username.substr(colon_pos + 1, username.size()); + return true; +} + +bool Port::MaybeIceRoleConflict( + const rtc::SocketAddress& addr, IceMessage* stun_msg, + const std::string& remote_ufrag) { + // Validate ICE_CONTROLLING or ICE_CONTROLLED attributes. + bool ret = true; + IceRole remote_ice_role = ICEROLE_UNKNOWN; + uint64_t remote_tiebreaker = 0; + const StunUInt64Attribute* stun_attr = + stun_msg->GetUInt64(STUN_ATTR_ICE_CONTROLLING); + if (stun_attr) { + remote_ice_role = ICEROLE_CONTROLLING; + remote_tiebreaker = stun_attr->value(); + } + + // If |remote_ufrag| is same as port local username fragment and + // tie breaker value received in the ping message matches port + // tiebreaker value this must be a loopback call. + // We will treat this as valid scenario. + if (remote_ice_role == ICEROLE_CONTROLLING && + username_fragment() == remote_ufrag && + remote_tiebreaker == IceTiebreaker()) { + return true; + } + + stun_attr = stun_msg->GetUInt64(STUN_ATTR_ICE_CONTROLLED); + if (stun_attr) { + remote_ice_role = ICEROLE_CONTROLLED; + remote_tiebreaker = stun_attr->value(); + } + + switch (ice_role_) { + case ICEROLE_CONTROLLING: + if (ICEROLE_CONTROLLING == remote_ice_role) { + if (remote_tiebreaker >= tiebreaker_) { + SignalRoleConflict(this); + } else { + // Send Role Conflict (487) error response. + SendBindingErrorResponse(stun_msg, addr, + STUN_ERROR_ROLE_CONFLICT, STUN_ERROR_REASON_ROLE_CONFLICT); + ret = false; + } + } + break; + case ICEROLE_CONTROLLED: + if (ICEROLE_CONTROLLED == remote_ice_role) { + if (remote_tiebreaker < tiebreaker_) { + SignalRoleConflict(this); + } else { + // Send Role Conflict (487) error response. + SendBindingErrorResponse(stun_msg, addr, + STUN_ERROR_ROLE_CONFLICT, STUN_ERROR_REASON_ROLE_CONFLICT); + ret = false; + } + } + break; + default: + ASSERT(false); + } + return ret; +} + +void Port::CreateStunUsername(const std::string& remote_username, + std::string* stun_username_attr_str) const { + stun_username_attr_str->clear(); + *stun_username_attr_str = remote_username; + stun_username_attr_str->append(":"); + stun_username_attr_str->append(username_fragment()); +} + +void Port::SendBindingResponse(StunMessage* request, + const rtc::SocketAddress& addr) { + ASSERT(request->type() == STUN_BINDING_REQUEST); + + // Retrieve the username from the request. + const StunByteStringAttribute* username_attr = + request->GetByteString(STUN_ATTR_USERNAME); + ASSERT(username_attr != NULL); + if (username_attr == NULL) { + // No valid username, skip the response. + return; + } + + // Fill in the response message. + StunMessage response; + response.SetType(STUN_BINDING_RESPONSE); + response.SetTransactionID(request->transaction_id()); + const StunUInt32Attribute* retransmit_attr = + request->GetUInt32(STUN_ATTR_RETRANSMIT_COUNT); + if (retransmit_attr) { + // Inherit the incoming retransmit value in the response so the other side + // can see our view of lost pings. + response.AddAttribute(new StunUInt32Attribute( + STUN_ATTR_RETRANSMIT_COUNT, retransmit_attr->value())); + + if (retransmit_attr->value() > CONNECTION_WRITE_CONNECT_FAILURES) { + LOG_J(LS_INFO, this) + << "Received a remote ping with high retransmit count: " + << retransmit_attr->value(); + } + } + + response.AddAttribute( + new StunXorAddressAttribute(STUN_ATTR_XOR_MAPPED_ADDRESS, addr)); + response.AddMessageIntegrity(password_); + response.AddFingerprint(); + + // The fact that we received a successful request means that this connection + // (if one exists) should now be receiving. + Connection* conn = GetConnection(addr); + + // Send the response message. + rtc::ByteBuffer buf; + response.Write(&buf); + rtc::PacketOptions options(DefaultDscpValue()); + auto err = SendTo(buf.Data(), buf.Length(), addr, options, false); + if (err < 0) { + LOG_J(LS_ERROR, this) + << "Failed to send STUN ping response" + << ", to=" << addr.ToSensitiveString() + << ", err=" << err + << ", id=" << rtc::hex_encode(response.transaction_id()); + } else { + // Log at LS_INFO if we send a stun ping response on an unwritable + // connection. + rtc::LoggingSeverity sev = (conn && !conn->writable()) ? + rtc::LS_INFO : rtc::LS_VERBOSE; + LOG_JV(sev, this) + << "Sent STUN ping response" + << ", to=" << addr.ToSensitiveString() + << ", id=" << rtc::hex_encode(response.transaction_id()); + } + + ASSERT(conn != NULL); + if (conn) + conn->ReceivedPing(); +} + +void Port::SendBindingErrorResponse(StunMessage* request, + const rtc::SocketAddress& addr, + int error_code, const std::string& reason) { + ASSERT(request->type() == STUN_BINDING_REQUEST); + + // Fill in the response message. + StunMessage response; + response.SetType(STUN_BINDING_ERROR_RESPONSE); + response.SetTransactionID(request->transaction_id()); + + // When doing GICE, we need to write out the error code incorrectly to + // maintain backwards compatiblility. + StunErrorCodeAttribute* error_attr = StunAttribute::CreateErrorCode(); + error_attr->SetCode(error_code); + error_attr->SetReason(reason); + response.AddAttribute(error_attr); + + // Per Section 10.1.2, certain error cases don't get a MESSAGE-INTEGRITY, + // because we don't have enough information to determine the shared secret. + if (error_code != STUN_ERROR_BAD_REQUEST && + error_code != STUN_ERROR_UNAUTHORIZED) + response.AddMessageIntegrity(password_); + response.AddFingerprint(); + + // Send the response message. + rtc::ByteBuffer buf; + response.Write(&buf); + rtc::PacketOptions options(DefaultDscpValue()); + SendTo(buf.Data(), buf.Length(), addr, options, false); + LOG_J(LS_INFO, this) << "Sending STUN binding error: reason=" << reason + << " to " << addr.ToSensitiveString(); +} + +void Port::OnMessage(rtc::Message *pmsg) { + ASSERT(pmsg->message_id == MSG_DEAD); + if (dead()) { + Destroy(); + } +} + +std::string Port::ToString() const { + std::stringstream ss; + ss << "Port[" << content_name_ << ":" << component_ + << ":" << generation_ << ":" << type_ + << ":" << network_->ToString() << "]"; + return ss.str(); +} + +void Port::EnablePortPackets() { + enable_port_packets_ = true; +} + +void Port::OnConnectionDestroyed(Connection* conn) { + AddressMap::iterator iter = + connections_.find(conn->remote_candidate().address()); + ASSERT(iter != connections_.end()); + connections_.erase(iter); + + // On the controlled side, ports time out after all connections fail. + // Note: If a new connection is added after this message is posted, but it + // fails and is removed before kPortTimeoutDelay, then this message will + // still cause the Port to be destroyed. + if (dead()) { + thread_->PostDelayed(timeout_delay_, this, MSG_DEAD); + } +} + +void Port::Destroy() { + ASSERT(connections_.empty()); + LOG_J(LS_INFO, this) << "Port deleted"; + SignalDestroyed(this); + delete this; +} + +const std::string Port::username_fragment() const { + return ice_username_fragment_; +} + +// A ConnectionRequest is a simple STUN ping used to determine writability. +class ConnectionRequest : public StunRequest { + public: + explicit ConnectionRequest(Connection* connection) + : StunRequest(new IceMessage()), + connection_(connection) { + } + + virtual ~ConnectionRequest() { + } + + void Prepare(StunMessage* request) override { + request->SetType(STUN_BINDING_REQUEST); + std::string username; + connection_->port()->CreateStunUsername( + connection_->remote_candidate().username(), &username); + request->AddAttribute( + new StunByteStringAttribute(STUN_ATTR_USERNAME, username)); + + // connection_ already holds this ping, so subtract one from count. + if (connection_->port()->send_retransmit_count_attribute()) { + request->AddAttribute(new StunUInt32Attribute( + STUN_ATTR_RETRANSMIT_COUNT, + static_cast<uint32_t>(connection_->pings_since_last_response_.size() - + 1))); + } + + // Adding ICE_CONTROLLED or ICE_CONTROLLING attribute based on the role. + if (connection_->port()->GetIceRole() == ICEROLE_CONTROLLING) { + request->AddAttribute(new StunUInt64Attribute( + STUN_ATTR_ICE_CONTROLLING, connection_->port()->IceTiebreaker())); + // Since we are trying aggressive nomination, sending USE-CANDIDATE + // attribute in every ping. + // If we are dealing with a ice-lite end point, nomination flag + // in Connection will be set to false by default. Once the connection + // becomes "best connection", nomination flag will be turned on. + if (connection_->use_candidate_attr()) { + request->AddAttribute(new StunByteStringAttribute( + STUN_ATTR_USE_CANDIDATE)); + } + } else if (connection_->port()->GetIceRole() == ICEROLE_CONTROLLED) { + request->AddAttribute(new StunUInt64Attribute( + STUN_ATTR_ICE_CONTROLLED, connection_->port()->IceTiebreaker())); + } else { + ASSERT(false); + } + + // Adding PRIORITY Attribute. + // Changing the type preference to Peer Reflexive and local preference + // and component id information is unchanged from the original priority. + // priority = (2^24)*(type preference) + + // (2^8)*(local preference) + + // (2^0)*(256 - component ID) + uint32_t prflx_priority = + ICE_TYPE_PREFERENCE_PRFLX << 24 | + (connection_->local_candidate().priority() & 0x00FFFFFF); + request->AddAttribute( + new StunUInt32Attribute(STUN_ATTR_PRIORITY, prflx_priority)); + + // Adding Message Integrity attribute. + request->AddMessageIntegrity(connection_->remote_candidate().password()); + // Adding Fingerprint. + request->AddFingerprint(); + } + + void OnResponse(StunMessage* response) override { + connection_->OnConnectionRequestResponse(this, response); + } + + void OnErrorResponse(StunMessage* response) override { + connection_->OnConnectionRequestErrorResponse(this, response); + } + + void OnTimeout() override { + connection_->OnConnectionRequestTimeout(this); + } + + void OnSent() override { + connection_->OnConnectionRequestSent(this); + // Each request is sent only once. After a single delay , the request will + // time out. + timeout_ = true; + } + + int resend_delay() override { + return CONNECTION_RESPONSE_TIMEOUT; + } + + private: + Connection* connection_; +}; + +// +// Connection +// + +Connection::Connection(Port* port, + size_t index, + const Candidate& remote_candidate) + : port_(port), + local_candidate_index_(index), + remote_candidate_(remote_candidate), + write_state_(STATE_WRITE_INIT), + receiving_(false), + connected_(true), + pruned_(false), + use_candidate_attr_(false), + nominated_(false), + remote_ice_mode_(ICEMODE_FULL), + requests_(port->thread()), + rtt_(DEFAULT_RTT), + last_ping_sent_(0), + last_ping_received_(0), + last_data_received_(0), + last_ping_response_received_(0), + recv_rate_tracker_(100u, 10u), + send_rate_tracker_(100u, 10u), + sent_packets_discarded_(0), + sent_packets_total_(0), + reported_(false), + state_(STATE_WAITING), + receiving_timeout_(WEAK_CONNECTION_RECEIVE_TIMEOUT), + time_created_ms_(rtc::Time()) { + // All of our connections start in WAITING state. + // TODO(mallinath) - Start connections from STATE_FROZEN. + // Wire up to send stun packets + requests_.SignalSendPacket.connect(this, &Connection::OnSendStunPacket); + LOG_J(LS_INFO, this) << "Connection created"; +} + +Connection::~Connection() { +} + +const Candidate& Connection::local_candidate() const { + ASSERT(local_candidate_index_ < port_->Candidates().size()); + return port_->Candidates()[local_candidate_index_]; +} + +uint64_t Connection::priority() const { + uint64_t priority = 0; + // RFC 5245 - 5.7.2. Computing Pair Priority and Ordering Pairs + // Let G be the priority for the candidate provided by the controlling + // agent. Let D be the priority for the candidate provided by the + // controlled agent. + // pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0) + IceRole role = port_->GetIceRole(); + if (role != ICEROLE_UNKNOWN) { + uint32_t g = 0; + uint32_t d = 0; + if (role == ICEROLE_CONTROLLING) { + g = local_candidate().priority(); + d = remote_candidate_.priority(); + } else { + g = remote_candidate_.priority(); + d = local_candidate().priority(); + } + priority = std::min(g, d); + priority = priority << 32; + priority += 2 * std::max(g, d) + (g > d ? 1 : 0); + } + return priority; +} + +void Connection::set_write_state(WriteState value) { + WriteState old_value = write_state_; + write_state_ = value; + if (value != old_value) { + LOG_J(LS_VERBOSE, this) << "set_write_state from: " << old_value << " to " + << value; + SignalStateChange(this); + } +} + +void Connection::set_receiving(bool value) { + if (value != receiving_) { + LOG_J(LS_VERBOSE, this) << "set_receiving to " << value; + receiving_ = value; + SignalStateChange(this); + } +} + +void Connection::set_state(State state) { + State old_state = state_; + state_ = state; + if (state != old_state) { + LOG_J(LS_VERBOSE, this) << "set_state"; + } +} + +void Connection::set_connected(bool value) { + bool old_value = connected_; + connected_ = value; + if (value != old_value) { + LOG_J(LS_VERBOSE, this) << "set_connected from: " << old_value << " to " + << value; + } +} + +void Connection::set_use_candidate_attr(bool enable) { + use_candidate_attr_ = enable; +} + +void Connection::OnSendStunPacket(const void* data, size_t size, + StunRequest* req) { + rtc::PacketOptions options(port_->DefaultDscpValue()); + auto err = port_->SendTo( + data, size, remote_candidate_.address(), options, false); + if (err < 0) { + LOG_J(LS_WARNING, this) << "Failed to send STUN ping " + << " err=" << err + << " id=" << rtc::hex_encode(req->id()); + } +} + +void Connection::OnReadPacket( + const char* data, size_t size, const rtc::PacketTime& packet_time) { + rtc::scoped_ptr<IceMessage> msg; + std::string remote_ufrag; + const rtc::SocketAddress& addr(remote_candidate_.address()); + if (!port_->GetStunMessage(data, size, addr, msg.accept(), &remote_ufrag)) { + // The packet did not parse as a valid STUN message + // This is a data packet, pass it along. + set_receiving(true); + last_data_received_ = rtc::Time(); + recv_rate_tracker_.AddSamples(size); + SignalReadPacket(this, data, size, packet_time); + + // If timed out sending writability checks, start up again + if (!pruned_ && (write_state_ == STATE_WRITE_TIMEOUT)) { + LOG(LS_WARNING) << "Received a data packet on a timed-out Connection. " + << "Resetting state to STATE_WRITE_INIT."; + set_write_state(STATE_WRITE_INIT); + } + } else if (!msg) { + // The packet was STUN, but failed a check and was handled internally. + } else { + // The packet is STUN and passed the Port checks. + // Perform our own checks to ensure this packet is valid. + // If this is a STUN request, then update the receiving bit and respond. + // If this is a STUN response, then update the writable bit. + // Log at LS_INFO if we receive a ping on an unwritable connection. + rtc::LoggingSeverity sev = (!writable() ? rtc::LS_INFO : rtc::LS_VERBOSE); + switch (msg->type()) { + case STUN_BINDING_REQUEST: + LOG_JV(sev, this) << "Received STUN ping" + << ", id=" << rtc::hex_encode(msg->transaction_id()); + + if (remote_ufrag == remote_candidate_.username()) { + // Check for role conflicts. + if (!port_->MaybeIceRoleConflict(addr, msg.get(), remote_ufrag)) { + // Received conflicting role from the peer. + LOG(LS_INFO) << "Received conflicting role from the peer."; + return; + } + + // Incoming, validated stun request from remote peer. + // This call will also set the connection receiving. + port_->SendBindingResponse(msg.get(), addr); + + // If timed out sending writability checks, start up again + if (!pruned_ && (write_state_ == STATE_WRITE_TIMEOUT)) + set_write_state(STATE_WRITE_INIT); + + if (port_->GetIceRole() == ICEROLE_CONTROLLED) { + const StunByteStringAttribute* use_candidate_attr = + msg->GetByteString(STUN_ATTR_USE_CANDIDATE); + if (use_candidate_attr) { + set_nominated(true); + SignalNominated(this); + } + } + } else { + // The packet had the right local username, but the remote username + // was not the right one for the remote address. + LOG_J(LS_ERROR, this) + << "Received STUN request with bad remote username " + << remote_ufrag; + port_->SendBindingErrorResponse(msg.get(), addr, + STUN_ERROR_UNAUTHORIZED, + STUN_ERROR_REASON_UNAUTHORIZED); + + } + break; + + // Response from remote peer. Does it match request sent? + // This doesn't just check, it makes callbacks if transaction + // id's match. + case STUN_BINDING_RESPONSE: + case STUN_BINDING_ERROR_RESPONSE: + if (msg->ValidateMessageIntegrity( + data, size, remote_candidate().password())) { + requests_.CheckResponse(msg.get()); + } + // Otherwise silently discard the response message. + break; + + // Remote end point sent an STUN indication instead of regular binding + // request. In this case |last_ping_received_| will be updated but no + // response will be sent. + case STUN_BINDING_INDICATION: + ReceivedPing(); + break; + + default: + ASSERT(false); + break; + } + } +} + +void Connection::OnReadyToSend() { + if (write_state_ == STATE_WRITABLE) { + SignalReadyToSend(this); + } +} + +void Connection::Prune() { + if (!pruned_ || active()) { + LOG_J(LS_VERBOSE, this) << "Connection pruned"; + pruned_ = true; + requests_.Clear(); + set_write_state(STATE_WRITE_TIMEOUT); + } +} + +void Connection::Destroy() { + LOG_J(LS_VERBOSE, this) << "Connection destroyed"; + port_->thread()->Post(this, MSG_DELETE); +} + +void Connection::PrintPingsSinceLastResponse(std::string* s, size_t max) { + std::ostringstream oss; + oss << std::boolalpha; + if (pings_since_last_response_.size() > max) { + for (size_t i = 0; i < max; i++) { + const SentPing& ping = pings_since_last_response_[i]; + oss << rtc::hex_encode(ping.id) << " "; + } + oss << "... " << (pings_since_last_response_.size() - max) << " more"; + } else { + for (const SentPing& ping : pings_since_last_response_) { + oss << rtc::hex_encode(ping.id) << " "; + } + } + *s = oss.str(); +} + +void Connection::UpdateState(uint32_t now) { + uint32_t rtt = ConservativeRTTEstimate(rtt_); + + if (LOG_CHECK_LEVEL(LS_VERBOSE)) { + std::string pings; + PrintPingsSinceLastResponse(&pings, 5); + LOG_J(LS_VERBOSE, this) << "UpdateState()" + << ", ms since last received response=" + << now - last_ping_response_received_ + << ", ms since last received data=" + << now - last_data_received_ + << ", rtt=" << rtt + << ", pings_since_last_response=" << pings; + } + + // Check the writable state. (The order of these checks is important.) + // + // Before becoming unwritable, we allow for a fixed number of pings to fail + // (i.e., receive no response). We also have to give the response time to + // get back, so we include a conservative estimate of this. + // + // Before timing out writability, we give a fixed amount of time. This is to + // allow for changes in network conditions. + + if ((write_state_ == STATE_WRITABLE) && + TooManyFailures(pings_since_last_response_, + CONNECTION_WRITE_CONNECT_FAILURES, + rtt, + now) && + TooLongWithoutResponse(pings_since_last_response_, + CONNECTION_WRITE_CONNECT_TIMEOUT, + now)) { + uint32_t max_pings = CONNECTION_WRITE_CONNECT_FAILURES; + LOG_J(LS_INFO, this) << "Unwritable after " << max_pings + << " ping failures and " + << now - pings_since_last_response_[0].sent_time + << " ms without a response," + << " ms since last received ping=" + << now - last_ping_received_ + << " ms since last received data=" + << now - last_data_received_ + << " rtt=" << rtt; + set_write_state(STATE_WRITE_UNRELIABLE); + } + if ((write_state_ == STATE_WRITE_UNRELIABLE || + write_state_ == STATE_WRITE_INIT) && + TooLongWithoutResponse(pings_since_last_response_, + CONNECTION_WRITE_TIMEOUT, + now)) { + LOG_J(LS_INFO, this) << "Timed out after " + << now - pings_since_last_response_[0].sent_time + << " ms without a response" + << ", rtt=" << rtt; + set_write_state(STATE_WRITE_TIMEOUT); + } + + // Check the receiving state. + uint32_t last_recv_time = last_received(); + bool receiving = now <= last_recv_time + receiving_timeout_; + set_receiving(receiving); + if (dead(now)) { + Destroy(); + } +} + +void Connection::Ping(uint32_t now) { + last_ping_sent_ = now; + ConnectionRequest *req = new ConnectionRequest(this); + pings_since_last_response_.push_back(SentPing(req->id(), now)); + LOG_J(LS_VERBOSE, this) << "Sending STUN ping " + << ", id=" << rtc::hex_encode(req->id()); + requests_.Send(req); + state_ = STATE_INPROGRESS; +} + +void Connection::ReceivedPing() { + set_receiving(true); + last_ping_received_ = rtc::Time(); +} + +void Connection::ReceivedPingResponse() { + // We've already validated that this is a STUN binding response with + // the correct local and remote username for this connection. + // So if we're not already, become writable. We may be bringing a pruned + // connection back to life, but if we don't really want it, we can always + // prune it again. + set_receiving(true); + set_write_state(STATE_WRITABLE); + set_state(STATE_SUCCEEDED); + pings_since_last_response_.clear(); + last_ping_response_received_ = rtc::Time(); +} + +bool Connection::dead(uint32_t now) const { + if (now < (time_created_ms_ + MIN_CONNECTION_LIFETIME)) { + // A connection that hasn't passed its minimum lifetime is still alive. + // We do this to prevent connections from being pruned too quickly + // during a network change event when two networks would be up + // simultaneously but only for a brief period. + return false; + } + + if (receiving_) { + // A connection that is receiving is alive. + return false; + } + + // A connection is alive until it is inactive. + return !active(); + + // TODO(honghaiz): Move from using the write state to using the receiving + // state with something like the following: + // return (now > (last_received() + DEAD_CONNECTION_RECEIVE_TIMEOUT)); +} + +std::string Connection::ToDebugId() const { + std::stringstream ss; + ss << std::hex << this; + return ss.str(); +} + +std::string Connection::ToString() const { + const char CONNECT_STATE_ABBREV[2] = { + '-', // not connected (false) + 'C', // connected (true) + }; + const char RECEIVE_STATE_ABBREV[2] = { + '-', // not receiving (false) + 'R', // receiving (true) + }; + const char WRITE_STATE_ABBREV[4] = { + 'W', // STATE_WRITABLE + 'w', // STATE_WRITE_UNRELIABLE + '-', // STATE_WRITE_INIT + 'x', // STATE_WRITE_TIMEOUT + }; + const std::string ICESTATE[4] = { + "W", // STATE_WAITING + "I", // STATE_INPROGRESS + "S", // STATE_SUCCEEDED + "F" // STATE_FAILED + }; + const Candidate& local = local_candidate(); + const Candidate& remote = remote_candidate(); + std::stringstream ss; + ss << "Conn[" << ToDebugId() + << ":" << port_->content_name() + << ":" << local.id() << ":" << local.component() + << ":" << local.generation() + << ":" << local.type() << ":" << local.protocol() + << ":" << local.address().ToSensitiveString() + << "->" << remote.id() << ":" << remote.component() + << ":" << remote.priority() + << ":" << remote.type() << ":" + << remote.protocol() << ":" << remote.address().ToSensitiveString() << "|" + << CONNECT_STATE_ABBREV[connected()] + << RECEIVE_STATE_ABBREV[receiving()] + << WRITE_STATE_ABBREV[write_state()] + << ICESTATE[state()] << "|" + << priority() << "|"; + if (rtt_ < DEFAULT_RTT) { + ss << rtt_ << "]"; + } else { + ss << "-]"; + } + return ss.str(); +} + +std::string Connection::ToSensitiveString() const { + return ToString(); +} + +void Connection::OnConnectionRequestResponse(ConnectionRequest* request, + StunMessage* response) { + // Log at LS_INFO if we receive a ping response on an unwritable + // connection. + rtc::LoggingSeverity sev = !writable() ? rtc::LS_INFO : rtc::LS_VERBOSE; + + uint32_t rtt = request->Elapsed(); + + ReceivedPingResponse(); + + if (LOG_CHECK_LEVEL_V(sev)) { + bool use_candidate = ( + response->GetByteString(STUN_ATTR_USE_CANDIDATE) != nullptr); + std::string pings; + PrintPingsSinceLastResponse(&pings, 5); + LOG_JV(sev, this) << "Received STUN ping response" + << ", id=" << rtc::hex_encode(request->id()) + << ", code=0" // Makes logging easier to parse. + << ", rtt=" << rtt + << ", use_candidate=" << use_candidate + << ", pings_since_last_response=" << pings; + } + + rtt_ = (RTT_RATIO * rtt_ + rtt) / (RTT_RATIO + 1); + + MaybeAddPrflxCandidate(request, response); +} + +void Connection::OnConnectionRequestErrorResponse(ConnectionRequest* request, + StunMessage* response) { + const StunErrorCodeAttribute* error_attr = response->GetErrorCode(); + int error_code = STUN_ERROR_GLOBAL_FAILURE; + if (error_attr) { + error_code = error_attr->code(); + } + + LOG_J(LS_INFO, this) << "Received STUN error response" + << " id=" << rtc::hex_encode(request->id()) + << " code=" << error_code + << " rtt=" << request->Elapsed(); + + if (error_code == STUN_ERROR_UNKNOWN_ATTRIBUTE || + error_code == STUN_ERROR_SERVER_ERROR || + error_code == STUN_ERROR_UNAUTHORIZED) { + // Recoverable error, retry + } else if (error_code == STUN_ERROR_STALE_CREDENTIALS) { + // Race failure, retry + } else if (error_code == STUN_ERROR_ROLE_CONFLICT) { + HandleRoleConflictFromPeer(); + } else { + // This is not a valid connection. + LOG_J(LS_ERROR, this) << "Received STUN error response, code=" + << error_code << "; killing connection"; + set_state(STATE_FAILED); + Destroy(); + } +} + +void Connection::OnConnectionRequestTimeout(ConnectionRequest* request) { + // Log at LS_INFO if we miss a ping on a writable connection. + rtc::LoggingSeverity sev = writable() ? rtc::LS_INFO : rtc::LS_VERBOSE; + LOG_JV(sev, this) << "Timing-out STUN ping " + << rtc::hex_encode(request->id()) + << " after " << request->Elapsed() << " ms"; +} + +void Connection::OnConnectionRequestSent(ConnectionRequest* request) { + // Log at LS_INFO if we send a ping on an unwritable connection. + rtc::LoggingSeverity sev = !writable() ? rtc::LS_INFO : rtc::LS_VERBOSE; + bool use_candidate = use_candidate_attr(); + LOG_JV(sev, this) << "Sent STUN ping" + << ", id=" << rtc::hex_encode(request->id()) + << ", use_candidate=" << use_candidate; +} + +void Connection::HandleRoleConflictFromPeer() { + port_->SignalRoleConflict(port_); +} + +void Connection::MaybeSetRemoteIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) { + if (remote_candidate_.username() == ice_ufrag && + remote_candidate_.password().empty()) { + remote_candidate_.set_password(ice_pwd); + } +} + +void Connection::MaybeUpdatePeerReflexiveCandidate( + const Candidate& new_candidate) { + if (remote_candidate_.type() == PRFLX_PORT_TYPE && + new_candidate.type() != PRFLX_PORT_TYPE && + remote_candidate_.protocol() == new_candidate.protocol() && + remote_candidate_.address() == new_candidate.address() && + remote_candidate_.username() == new_candidate.username() && + remote_candidate_.password() == new_candidate.password() && + remote_candidate_.generation() == new_candidate.generation()) { + remote_candidate_ = new_candidate; + } +} + +void Connection::OnMessage(rtc::Message *pmsg) { + ASSERT(pmsg->message_id == MSG_DELETE); + LOG_J(LS_INFO, this) << "Connection deleted"; + SignalDestroyed(this); + delete this; +} + +uint32_t Connection::last_received() { + return std::max(last_data_received_, + std::max(last_ping_received_, last_ping_response_received_)); +} + +size_t Connection::recv_bytes_second() { + return recv_rate_tracker_.ComputeRate(); +} + +size_t Connection::recv_total_bytes() { + return recv_rate_tracker_.TotalSampleCount(); +} + +size_t Connection::sent_bytes_second() { + return send_rate_tracker_.ComputeRate(); +} + +size_t Connection::sent_total_bytes() { + return send_rate_tracker_.TotalSampleCount(); +} + +size_t Connection::sent_discarded_packets() { + return sent_packets_discarded_; +} + +size_t Connection::sent_total_packets() { + return sent_packets_total_; +} + +void Connection::MaybeAddPrflxCandidate(ConnectionRequest* request, + StunMessage* response) { + // RFC 5245 + // The agent checks the mapped address from the STUN response. If the + // transport address does not match any of the local candidates that the + // agent knows about, the mapped address represents a new candidate -- a + // peer reflexive candidate. + const StunAddressAttribute* addr = + response->GetAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + if (!addr) { + LOG(LS_WARNING) << "Connection::OnConnectionRequestResponse - " + << "No MAPPED-ADDRESS or XOR-MAPPED-ADDRESS found in the " + << "stun response message"; + return; + } + + bool known_addr = false; + for (size_t i = 0; i < port_->Candidates().size(); ++i) { + if (port_->Candidates()[i].address() == addr->GetAddress()) { + known_addr = true; + break; + } + } + if (known_addr) { + return; + } + + // RFC 5245 + // Its priority is set equal to the value of the PRIORITY attribute + // in the Binding request. + const StunUInt32Attribute* priority_attr = + request->msg()->GetUInt32(STUN_ATTR_PRIORITY); + if (!priority_attr) { + LOG(LS_WARNING) << "Connection::OnConnectionRequestResponse - " + << "No STUN_ATTR_PRIORITY found in the " + << "stun response message"; + return; + } + const uint32_t priority = priority_attr->value(); + std::string id = rtc::CreateRandomString(8); + + Candidate new_local_candidate; + new_local_candidate.set_id(id); + new_local_candidate.set_component(local_candidate().component()); + new_local_candidate.set_type(PRFLX_PORT_TYPE); + new_local_candidate.set_protocol(local_candidate().protocol()); + new_local_candidate.set_address(addr->GetAddress()); + new_local_candidate.set_priority(priority); + new_local_candidate.set_username(local_candidate().username()); + new_local_candidate.set_password(local_candidate().password()); + new_local_candidate.set_network_name(local_candidate().network_name()); + new_local_candidate.set_network_type(local_candidate().network_type()); + new_local_candidate.set_related_address(local_candidate().address()); + new_local_candidate.set_foundation( + ComputeFoundation(PRFLX_PORT_TYPE, local_candidate().protocol(), + local_candidate().address())); + + // Change the local candidate of this Connection to the new prflx candidate. + local_candidate_index_ = port_->AddPrflxCandidate(new_local_candidate); + + // SignalStateChange to force a re-sort in P2PTransportChannel as this + // Connection's local candidate has changed. + SignalStateChange(this); +} + +ProxyConnection::ProxyConnection(Port* port, size_t index, + const Candidate& candidate) + : Connection(port, index, candidate), error_(0) { +} + +int ProxyConnection::Send(const void* data, size_t size, + const rtc::PacketOptions& options) { + if (write_state_ == STATE_WRITE_INIT || write_state_ == STATE_WRITE_TIMEOUT) { + error_ = EWOULDBLOCK; + return SOCKET_ERROR; + } + sent_packets_total_++; + int sent = port_->SendTo(data, size, remote_candidate_.address(), + options, true); + if (sent <= 0) { + ASSERT(sent < 0); + error_ = port_->GetError(); + sent_packets_discarded_++; + } else { + send_rate_tracker_.AddSamples(sent); + } + return sent; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/port.h b/webrtc/p2p/base/port.h new file mode 100644 index 0000000000..01c45f26d8 --- /dev/null +++ b/webrtc/p2p/base/port.h @@ -0,0 +1,646 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_PORT_H_ +#define WEBRTC_P2P_BASE_PORT_H_ + +#include <map> +#include <set> +#include <string> +#include <vector> + +#include "webrtc/p2p/base/candidate.h" +#include "webrtc/p2p/base/packetsocketfactory.h" +#include "webrtc/p2p/base/portinterface.h" +#include "webrtc/p2p/base/stun.h" +#include "webrtc/p2p/base/stunrequest.h" +#include "webrtc/p2p/base/transport.h" +#include "webrtc/base/asyncpacketsocket.h" +#include "webrtc/base/network.h" +#include "webrtc/base/proxyinfo.h" +#include "webrtc/base/ratetracker.h" +#include "webrtc/base/sigslot.h" +#include "webrtc/base/socketaddress.h" +#include "webrtc/base/thread.h" + +namespace cricket { + +class Connection; +class ConnectionRequest; + +extern const char LOCAL_PORT_TYPE[]; +extern const char STUN_PORT_TYPE[]; +extern const char PRFLX_PORT_TYPE[]; +extern const char RELAY_PORT_TYPE[]; + +extern const char UDP_PROTOCOL_NAME[]; +extern const char TCP_PROTOCOL_NAME[]; +extern const char SSLTCP_PROTOCOL_NAME[]; + +// RFC 6544, TCP candidate encoding rules. +extern const int DISCARD_PORT; +extern const char TCPTYPE_ACTIVE_STR[]; +extern const char TCPTYPE_PASSIVE_STR[]; +extern const char TCPTYPE_SIMOPEN_STR[]; + +// The minimum time we will wait before destroying a connection after creating +// it. +const uint32_t MIN_CONNECTION_LIFETIME = 10 * 1000; // 10 seconds. + +// The timeout duration when a connection does not receive anything. +const uint32_t WEAK_CONNECTION_RECEIVE_TIMEOUT = 2500; // 2.5 seconds + +// The length of time we wait before timing out writability on a connection. +const uint32_t CONNECTION_WRITE_TIMEOUT = 15 * 1000; // 15 seconds + +// The length of time we wait before we become unwritable. +const uint32_t CONNECTION_WRITE_CONNECT_TIMEOUT = 5 * 1000; // 5 seconds + +// The number of pings that must fail to respond before we become unwritable. +const uint32_t CONNECTION_WRITE_CONNECT_FAILURES = 5; + +// This is the length of time that we wait for a ping response to come back. +const int CONNECTION_RESPONSE_TIMEOUT = 5 * 1000; // 5 seconds + +enum RelayType { + RELAY_GTURN, // Legacy google relay service. + RELAY_TURN // Standard (TURN) relay service. +}; + +enum IcePriorityValue { + // The reason we are choosing Relay preference 2 is because, we can run + // Relay from client to server on UDP/TCP/TLS. To distinguish the transport + // protocol, we prefer UDP over TCP over TLS. + // For UDP ICE_TYPE_PREFERENCE_RELAY will be 2. + // For TCP ICE_TYPE_PREFERENCE_RELAY will be 1. + // For TLS ICE_TYPE_PREFERENCE_RELAY will be 0. + // Check turnport.cc for setting these values. + ICE_TYPE_PREFERENCE_RELAY = 2, + ICE_TYPE_PREFERENCE_HOST_TCP = 90, + ICE_TYPE_PREFERENCE_SRFLX = 100, + ICE_TYPE_PREFERENCE_PRFLX = 110, + ICE_TYPE_PREFERENCE_HOST = 126 +}; + +const char* ProtoToString(ProtocolType proto); +bool StringToProto(const char* value, ProtocolType* proto); + +struct ProtocolAddress { + rtc::SocketAddress address; + ProtocolType proto; + bool secure; + + ProtocolAddress(const rtc::SocketAddress& a, ProtocolType p) + : address(a), proto(p), secure(false) { } + ProtocolAddress(const rtc::SocketAddress& a, ProtocolType p, bool sec) + : address(a), proto(p), secure(sec) { } +}; + +typedef std::set<rtc::SocketAddress> ServerAddresses; + +// Represents a local communication mechanism that can be used to create +// connections to similar mechanisms of the other client. Subclasses of this +// one add support for specific mechanisms like local UDP ports. +class Port : public PortInterface, public rtc::MessageHandler, + public sigslot::has_slots<> { + public: + Port(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + const std::string& username_fragment, + const std::string& password); + Port(rtc::Thread* thread, + const std::string& type, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username_fragment, + const std::string& password); + virtual ~Port(); + + virtual const std::string& Type() const { return type_; } + virtual rtc::Network* Network() const { return network_; } + + // Methods to set/get ICE role and tiebreaker values. + IceRole GetIceRole() const { return ice_role_; } + void SetIceRole(IceRole role) { ice_role_ = role; } + + void SetIceTiebreaker(uint64_t tiebreaker) { tiebreaker_ = tiebreaker; } + uint64_t IceTiebreaker() const { return tiebreaker_; } + + virtual bool SharedSocket() const { return shared_socket_; } + void ResetSharedSocket() { shared_socket_ = false; } + + // The thread on which this port performs its I/O. + rtc::Thread* thread() { return thread_; } + + // The factory used to create the sockets of this port. + rtc::PacketSocketFactory* socket_factory() const { return factory_; } + void set_socket_factory(rtc::PacketSocketFactory* factory) { + factory_ = factory; + } + + // For debugging purposes. + const std::string& content_name() const { return content_name_; } + void set_content_name(const std::string& content_name) { + content_name_ = content_name; + } + + int component() const { return component_; } + void set_component(int component) { component_ = component; } + + bool send_retransmit_count_attribute() const { + return send_retransmit_count_attribute_; + } + void set_send_retransmit_count_attribute(bool enable) { + send_retransmit_count_attribute_ = enable; + } + + // Identifies the generation that this port was created in. + uint32_t generation() { return generation_; } + void set_generation(uint32_t generation) { generation_ = generation; } + + // ICE requires a single username/password per content/media line. So the + // |ice_username_fragment_| of the ports that belongs to the same content will + // be the same. However this causes a small complication with our relay + // server, which expects different username for RTP and RTCP. + // + // To resolve this problem, we implemented the username_fragment(), + // which returns a different username (calculated from + // |ice_username_fragment_|) for RTCP in the case of ICEPROTO_GOOGLE. And the + // username_fragment() simply returns |ice_username_fragment_| when running + // in ICEPROTO_RFC5245. + // + // As a result the ICEPROTO_GOOGLE will use different usernames for RTP and + // RTCP. And the ICEPROTO_RFC5245 will use same username for both RTP and + // RTCP. + const std::string username_fragment() const; + const std::string& password() const { return password_; } + + // Fired when candidates are discovered by the port. When all candidates + // are discovered that belong to port SignalAddressReady is fired. + sigslot::signal2<Port*, const Candidate&> SignalCandidateReady; + + // Provides all of the above information in one handy object. + virtual const std::vector<Candidate>& Candidates() const { + return candidates_; + } + + // SignalPortComplete is sent when port completes the task of candidates + // allocation. + sigslot::signal1<Port*> SignalPortComplete; + // This signal sent when port fails to allocate candidates and this port + // can't be used in establishing the connections. When port is in shared mode + // and port fails to allocate one of the candidates, port shouldn't send + // this signal as other candidates might be usefull in establishing the + // connection. + sigslot::signal1<Port*> SignalPortError; + + // Returns a map containing all of the connections of this port, keyed by the + // remote address. + typedef std::map<rtc::SocketAddress, Connection*> AddressMap; + const AddressMap& connections() { return connections_; } + + // Returns the connection to the given address or NULL if none exists. + virtual Connection* GetConnection( + const rtc::SocketAddress& remote_addr); + + // Called each time a connection is created. + sigslot::signal2<Port*, Connection*> SignalConnectionCreated; + + // In a shared socket mode each port which shares the socket will decide + // to accept the packet based on the |remote_addr|. Currently only UDP + // port implemented this method. + // TODO(mallinath) - Make it pure virtual. + virtual bool HandleIncomingPacket( + rtc::AsyncPacketSocket* socket, const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + ASSERT(false); + return false; + } + + // Sends a response message (normal or error) to the given request. One of + // these methods should be called as a response to SignalUnknownAddress. + // NOTE: You MUST call CreateConnection BEFORE SendBindingResponse. + virtual void SendBindingResponse(StunMessage* request, + const rtc::SocketAddress& addr); + virtual void SendBindingErrorResponse( + StunMessage* request, const rtc::SocketAddress& addr, + int error_code, const std::string& reason); + + void set_proxy(const std::string& user_agent, + const rtc::ProxyInfo& proxy) { + user_agent_ = user_agent; + proxy_ = proxy; + } + const std::string& user_agent() { return user_agent_; } + const rtc::ProxyInfo& proxy() { return proxy_; } + + virtual void EnablePortPackets(); + + // Called if the port has no connections and is no longer useful. + void Destroy(); + + virtual void OnMessage(rtc::Message *pmsg); + + // Debugging description of this port + virtual std::string ToString() const; + const rtc::IPAddress& ip() const { return ip_; } + uint16_t min_port() { return min_port_; } + uint16_t max_port() { return max_port_; } + + // Timeout shortening function to speed up unit tests. + void set_timeout_delay(int delay) { timeout_delay_ = delay; } + + // This method will return local and remote username fragements from the + // stun username attribute if present. + bool ParseStunUsername(const StunMessage* stun_msg, + std::string* local_username, + std::string* remote_username) const; + void CreateStunUsername(const std::string& remote_username, + std::string* stun_username_attr_str) const; + + bool MaybeIceRoleConflict(const rtc::SocketAddress& addr, + IceMessage* stun_msg, + const std::string& remote_ufrag); + + // Called when a packet has been sent to the socket. + void OnSentPacket(const rtc::SentPacket& sent_packet); + + // Called when the socket is currently able to send. + void OnReadyToSend(); + + // Called when the Connection discovers a local peer reflexive candidate. + // Returns the index of the new local candidate. + size_t AddPrflxCandidate(const Candidate& local); + + void set_candidate_filter(uint32_t candidate_filter) { + candidate_filter_ = candidate_filter; + } + + protected: + enum { + MSG_DEAD = 0, + MSG_FIRST_AVAILABLE + }; + + void set_type(const std::string& type) { type_ = type; } + + void AddAddress(const rtc::SocketAddress& address, + const rtc::SocketAddress& base_address, + const rtc::SocketAddress& related_address, + const std::string& protocol, + const std::string& relay_protocol, + const std::string& tcptype, + const std::string& type, + uint32_t type_preference, + uint32_t relay_preference, + bool final); + + // Adds the given connection to the list. (Deleting removes them.) + void AddConnection(Connection* conn); + + // Called when a packet is received from an unknown address that is not + // currently a connection. If this is an authenticated STUN binding request, + // then we will signal the client. + void OnReadPacket(const char* data, size_t size, + const rtc::SocketAddress& addr, + ProtocolType proto); + + // If the given data comprises a complete and correct STUN message then the + // return value is true, otherwise false. If the message username corresponds + // with this port's username fragment, msg will contain the parsed STUN + // message. Otherwise, the function may send a STUN response internally. + // remote_username contains the remote fragment of the STUN username. + bool GetStunMessage(const char* data, size_t size, + const rtc::SocketAddress& addr, + IceMessage** out_msg, std::string* out_username); + + // Checks if the address in addr is compatible with the port's ip. + bool IsCompatibleAddress(const rtc::SocketAddress& addr); + + // Returns default DSCP value. + rtc::DiffServCodePoint DefaultDscpValue() const { + // No change from what MediaChannel set. + return rtc::DSCP_NO_CHANGE; + } + + uint32_t candidate_filter() { return candidate_filter_; } + + private: + void Construct(); + // Called when one of our connections deletes itself. + void OnConnectionDestroyed(Connection* conn); + + // Whether this port is dead, and hence, should be destroyed on the controlled + // side. + bool dead() const { + return ice_role_ == ICEROLE_CONTROLLED && connections_.empty(); + } + + rtc::Thread* thread_; + rtc::PacketSocketFactory* factory_; + std::string type_; + bool send_retransmit_count_attribute_; + rtc::Network* network_; + rtc::IPAddress ip_; + uint16_t min_port_; + uint16_t max_port_; + std::string content_name_; + int component_; + uint32_t generation_; + // In order to establish a connection to this Port (so that real data can be + // sent through), the other side must send us a STUN binding request that is + // authenticated with this username_fragment and password. + // PortAllocatorSession will provide these username_fragment and password. + // + // Note: we should always use username_fragment() instead of using + // |ice_username_fragment_| directly. For the details see the comment on + // username_fragment(). + std::string ice_username_fragment_; + std::string password_; + std::vector<Candidate> candidates_; + AddressMap connections_; + int timeout_delay_; + bool enable_port_packets_; + IceRole ice_role_; + uint64_t tiebreaker_; + bool shared_socket_; + // Information to use when going through a proxy. + std::string user_agent_; + rtc::ProxyInfo proxy_; + + // Candidate filter is pushed down to Port such that each Port could + // make its own decision on how to create candidates. For example, + // when IceTransportsType is set to relay, both RelayPort and + // TurnPort will hide raddr to avoid local address leakage. + uint32_t candidate_filter_; + + friend class Connection; +}; + +// Represents a communication link between a port on the local client and a +// port on the remote client. +class Connection : public rtc::MessageHandler, + public sigslot::has_slots<> { + public: + struct SentPing { + SentPing(const std::string id, uint32_t sent_time) + : id(id), sent_time(sent_time) {} + + std::string id; + uint32_t sent_time; + }; + + // States are from RFC 5245. http://tools.ietf.org/html/rfc5245#section-5.7.4 + enum State { + STATE_WAITING = 0, // Check has not been performed, Waiting pair on CL. + STATE_INPROGRESS, // Check has been sent, transaction is in progress. + STATE_SUCCEEDED, // Check already done, produced a successful result. + STATE_FAILED // Check for this connection failed. + }; + + virtual ~Connection(); + + // The local port where this connection sends and receives packets. + Port* port() { return port_; } + const Port* port() const { return port_; } + + // Returns the description of the local port + virtual const Candidate& local_candidate() const; + + // Returns the description of the remote port to which we communicate. + const Candidate& remote_candidate() const { return remote_candidate_; } + + // Returns the pair priority. + uint64_t priority() const; + + enum WriteState { + STATE_WRITABLE = 0, // we have received ping responses recently + STATE_WRITE_UNRELIABLE = 1, // we have had a few ping failures + STATE_WRITE_INIT = 2, // we have yet to receive a ping response + STATE_WRITE_TIMEOUT = 3, // we have had a large number of ping failures + }; + + WriteState write_state() const { return write_state_; } + bool writable() const { return write_state_ == STATE_WRITABLE; } + bool receiving() const { return receiving_; } + + // Determines whether the connection has finished connecting. This can only + // be false for TCP connections. + bool connected() const { return connected_; } + bool weak() const { return !(writable() && receiving() && connected()); } + bool active() const { + // TODO(honghaiz): Move from using |write_state_| to using |pruned_|. + return write_state_ != STATE_WRITE_TIMEOUT; + } + // A connection is dead if it can be safely deleted. + bool dead(uint32_t now) const; + + // Estimate of the round-trip time over this connection. + uint32_t rtt() const { return rtt_; } + + size_t sent_total_bytes(); + size_t sent_bytes_second(); + // Used to track how many packets are discarded in the application socket due + // to errors. + size_t sent_discarded_packets(); + size_t sent_total_packets(); + size_t recv_total_bytes(); + size_t recv_bytes_second(); + sigslot::signal1<Connection*> SignalStateChange; + + // Sent when the connection has decided that it is no longer of value. It + // will delete itself immediately after this call. + sigslot::signal1<Connection*> SignalDestroyed; + + // The connection can send and receive packets asynchronously. This matches + // the interface of AsyncPacketSocket, which may use UDP or TCP under the + // covers. + virtual int Send(const void* data, size_t size, + const rtc::PacketOptions& options) = 0; + + // Error if Send() returns < 0 + virtual int GetError() = 0; + + sigslot::signal4<Connection*, const char*, size_t, const rtc::PacketTime&> + SignalReadPacket; + + sigslot::signal1<Connection*> SignalReadyToSend; + + // Called when a packet is received on this connection. + void OnReadPacket(const char* data, size_t size, + const rtc::PacketTime& packet_time); + + // Called when the socket is currently able to send. + void OnReadyToSend(); + + // Called when a connection is determined to be no longer useful to us. We + // still keep it around in case the other side wants to use it. But we can + // safely stop pinging on it and we can allow it to time out if the other + // side stops using it as well. + bool pruned() const { return pruned_; } + void Prune(); + + bool use_candidate_attr() const { return use_candidate_attr_; } + void set_use_candidate_attr(bool enable); + + bool nominated() const { return nominated_; } + void set_nominated(bool nominated) { nominated_ = nominated; } + + void set_remote_ice_mode(IceMode mode) { + remote_ice_mode_ = mode; + } + + void set_receiving_timeout(uint32_t receiving_timeout_ms) { + receiving_timeout_ = receiving_timeout_ms; + } + + // Makes the connection go away. + void Destroy(); + + // Checks that the state of this connection is up-to-date. The argument is + // the current time, which is compared against various timeouts. + void UpdateState(uint32_t now); + + // Called when this connection should try checking writability again. + uint32_t last_ping_sent() const { return last_ping_sent_; } + void Ping(uint32_t now); + void ReceivedPingResponse(); + + // Called whenever a valid ping is received on this connection. This is + // public because the connection intercepts the first ping for us. + uint32_t last_ping_received() const { return last_ping_received_; } + void ReceivedPing(); + + // Debugging description of this connection + std::string ToDebugId() const; + std::string ToString() const; + std::string ToSensitiveString() const; + // Prints pings_since_last_response_ into a string. + void PrintPingsSinceLastResponse(std::string* pings, size_t max); + + bool reported() const { return reported_; } + void set_reported(bool reported) { reported_ = reported;} + + // This signal will be fired if this connection is nominated by the + // controlling side. + sigslot::signal1<Connection*> SignalNominated; + + // Invoked when Connection receives STUN error response with 487 code. + void HandleRoleConflictFromPeer(); + + State state() const { return state_; } + + IceMode remote_ice_mode() const { return remote_ice_mode_; } + + // Update the ICE password of the remote candidate if |ice_ufrag| matches + // the candidate's ufrag, and the candidate's passwrod has not been set. + void MaybeSetRemoteIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd); + + // If |remote_candidate_| is peer reflexive and is equivalent to + // |new_candidate| except the type, update |remote_candidate_| to + // |new_candidate|. + void MaybeUpdatePeerReflexiveCandidate(const Candidate& new_candidate); + + // Returns the last received time of any data, stun request, or stun + // response in milliseconds + uint32_t last_received(); + + protected: + enum { MSG_DELETE = 0, MSG_FIRST_AVAILABLE }; + + // Constructs a new connection to the given remote port. + Connection(Port* port, size_t index, const Candidate& candidate); + + // Called back when StunRequestManager has a stun packet to send + void OnSendStunPacket(const void* data, size_t size, StunRequest* req); + + // Callbacks from ConnectionRequest + virtual void OnConnectionRequestResponse(ConnectionRequest* req, + StunMessage* response); + void OnConnectionRequestErrorResponse(ConnectionRequest* req, + StunMessage* response); + void OnConnectionRequestTimeout(ConnectionRequest* req); + void OnConnectionRequestSent(ConnectionRequest* req); + + // Changes the state and signals if necessary. + void set_write_state(WriteState value); + void set_receiving(bool value); + void set_state(State state); + void set_connected(bool value); + + void OnMessage(rtc::Message *pmsg); + + Port* port_; + size_t local_candidate_index_; + Candidate remote_candidate_; + WriteState write_state_; + bool receiving_; + bool connected_; + bool pruned_; + // By default |use_candidate_attr_| flag will be true, + // as we will be using aggressive nomination. + // But when peer is ice-lite, this flag "must" be initialized to false and + // turn on when connection becomes "best connection". + bool use_candidate_attr_; + // Whether this connection has been nominated by the controlling side via + // the use_candidate attribute. + bool nominated_; + IceMode remote_ice_mode_; + StunRequestManager requests_; + uint32_t rtt_; + uint32_t last_ping_sent_; // last time we sent a ping to the other side + uint32_t last_ping_received_; // last time we received a ping from the other + // side + uint32_t last_data_received_; + uint32_t last_ping_response_received_; + std::vector<SentPing> pings_since_last_response_; + + rtc::RateTracker recv_rate_tracker_; + rtc::RateTracker send_rate_tracker_; + uint32_t sent_packets_discarded_; + uint32_t sent_packets_total_; + + private: + void MaybeAddPrflxCandidate(ConnectionRequest* request, + StunMessage* response); + + bool reported_; + State state_; + // Time duration to switch from receiving to not receiving. + uint32_t receiving_timeout_; + uint32_t time_created_ms_; + + friend class Port; + friend class ConnectionRequest; +}; + +// ProxyConnection defers all the interesting work to the port +class ProxyConnection : public Connection { + public: + ProxyConnection(Port* port, size_t index, const Candidate& candidate); + + virtual int Send(const void* data, size_t size, + const rtc::PacketOptions& options); + virtual int GetError() { return error_; } + + private: + int error_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_PORT_H_ diff --git a/webrtc/p2p/base/port_unittest.cc b/webrtc/p2p/base/port_unittest.cc new file mode 100644 index 0000000000..4a4ed32456 --- /dev/null +++ b/webrtc/p2p/base/port_unittest.cc @@ -0,0 +1,2452 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/basicpacketsocketfactory.h" +#include "webrtc/p2p/base/relayport.h" +#include "webrtc/p2p/base/stunport.h" +#include "webrtc/p2p/base/tcpport.h" +#include "webrtc/p2p/base/testrelayserver.h" +#include "webrtc/p2p/base/teststunserver.h" +#include "webrtc/p2p/base/testturnserver.h" +#include "webrtc/p2p/base/transport.h" +#include "webrtc/p2p/base/turnport.h" +#include "webrtc/base/crc32.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/natserver.h" +#include "webrtc/base/natsocketfactory.h" +#include "webrtc/base/physicalsocketserver.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/socketaddress.h" +#include "webrtc/base/ssladapter.h" +#include "webrtc/base/stringutils.h" +#include "webrtc/base/thread.h" +#include "webrtc/base/virtualsocketserver.h" + +using rtc::AsyncPacketSocket; +using rtc::ByteBuffer; +using rtc::NATType; +using rtc::NAT_OPEN_CONE; +using rtc::NAT_ADDR_RESTRICTED; +using rtc::NAT_PORT_RESTRICTED; +using rtc::NAT_SYMMETRIC; +using rtc::PacketSocketFactory; +using rtc::scoped_ptr; +using rtc::Socket; +using rtc::SocketAddress; +using namespace cricket; + +static const int kTimeout = 1000; +static const SocketAddress kLocalAddr1("192.168.1.2", 0); +static const SocketAddress kLocalAddr2("192.168.1.3", 0); +static const SocketAddress kNatAddr1("77.77.77.77", rtc::NAT_SERVER_UDP_PORT); +static const SocketAddress kNatAddr2("88.88.88.88", rtc::NAT_SERVER_UDP_PORT); +static const SocketAddress kStunAddr("99.99.99.1", STUN_SERVER_PORT); +static const SocketAddress kRelayUdpIntAddr("99.99.99.2", 5000); +static const SocketAddress kRelayUdpExtAddr("99.99.99.3", 5001); +static const SocketAddress kRelayTcpIntAddr("99.99.99.2", 5002); +static const SocketAddress kRelayTcpExtAddr("99.99.99.3", 5003); +static const SocketAddress kRelaySslTcpIntAddr("99.99.99.2", 5004); +static const SocketAddress kRelaySslTcpExtAddr("99.99.99.3", 5005); +static const SocketAddress kTurnUdpIntAddr("99.99.99.4", STUN_SERVER_PORT); +static const SocketAddress kTurnUdpExtAddr("99.99.99.5", 0); +static const RelayCredentials kRelayCredentials("test", "test"); + +// TODO: Update these when RFC5245 is completely supported. +// Magic value of 30 is from RFC3484, for IPv4 addresses. +static const uint32_t kDefaultPrflxPriority = + ICE_TYPE_PREFERENCE_PRFLX << 24 | 30 << 8 | + (256 - ICE_CANDIDATE_COMPONENT_DEFAULT); + +static const int kTiebreaker1 = 11111; +static const int kTiebreaker2 = 22222; + +static const char* data = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + +static Candidate GetCandidate(Port* port) { + assert(port->Candidates().size() >= 1); + return port->Candidates()[0]; +} + +static SocketAddress GetAddress(Port* port) { + return GetCandidate(port).address(); +} + +static IceMessage* CopyStunMessage(const IceMessage* src) { + IceMessage* dst = new IceMessage(); + ByteBuffer buf; + src->Write(&buf); + dst->Read(&buf); + return dst; +} + +static bool WriteStunMessage(const StunMessage* msg, ByteBuffer* buf) { + buf->Resize(0); // clear out any existing buffer contents + return msg->Write(buf); +} + +// Stub port class for testing STUN generation and processing. +class TestPort : public Port { + public: + TestPort(rtc::Thread* thread, + const std::string& type, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username_fragment, + const std::string& password) + : Port(thread, + type, + factory, + network, + ip, + min_port, + max_port, + username_fragment, + password) {} + ~TestPort() {} + + // Expose GetStunMessage so that we can test it. + using cricket::Port::GetStunMessage; + + // The last StunMessage that was sent on this Port. + // TODO: Make these const; requires changes to SendXXXXResponse. + ByteBuffer* last_stun_buf() { return last_stun_buf_.get(); } + IceMessage* last_stun_msg() { return last_stun_msg_.get(); } + int last_stun_error_code() { + int code = 0; + if (last_stun_msg_) { + const StunErrorCodeAttribute* error_attr = last_stun_msg_->GetErrorCode(); + if (error_attr) { + code = error_attr->code(); + } + } + return code; + } + + virtual void PrepareAddress() { + rtc::SocketAddress addr(ip(), min_port()); + AddAddress(addr, addr, rtc::SocketAddress(), "udp", "", "", Type(), + ICE_TYPE_PREFERENCE_HOST, 0, true); + } + + // Exposed for testing candidate building. + void AddCandidateAddress(const rtc::SocketAddress& addr) { + AddAddress(addr, addr, rtc::SocketAddress(), "udp", "", "", Type(), + type_preference_, 0, false); + } + void AddCandidateAddress(const rtc::SocketAddress& addr, + const rtc::SocketAddress& base_address, + const std::string& type, + int type_preference, + bool final) { + AddAddress(addr, base_address, rtc::SocketAddress(), "udp", "", "", type, + type_preference, 0, final); + } + + virtual Connection* CreateConnection(const Candidate& remote_candidate, + CandidateOrigin origin) { + Connection* conn = new ProxyConnection(this, 0, remote_candidate); + AddConnection(conn); + // Set use-candidate attribute flag as this will add USE-CANDIDATE attribute + // in STUN binding requests. + conn->set_use_candidate_attr(true); + return conn; + } + virtual int SendTo( + const void* data, size_t size, const rtc::SocketAddress& addr, + const rtc::PacketOptions& options, bool payload) { + if (!payload) { + IceMessage* msg = new IceMessage; + ByteBuffer* buf = new ByteBuffer(static_cast<const char*>(data), size); + ByteBuffer::ReadPosition pos(buf->GetReadPosition()); + if (!msg->Read(buf)) { + delete msg; + delete buf; + return -1; + } + buf->SetReadPosition(pos); + last_stun_buf_.reset(buf); + last_stun_msg_.reset(msg); + } + return static_cast<int>(size); + } + virtual int SetOption(rtc::Socket::Option opt, int value) { + return 0; + } + virtual int GetOption(rtc::Socket::Option opt, int* value) { + return -1; + } + virtual int GetError() { + return 0; + } + void Reset() { + last_stun_buf_.reset(); + last_stun_msg_.reset(); + } + void set_type_preference(int type_preference) { + type_preference_ = type_preference; + } + + private: + rtc::scoped_ptr<ByteBuffer> last_stun_buf_; + rtc::scoped_ptr<IceMessage> last_stun_msg_; + int type_preference_; +}; + +class TestChannel : public sigslot::has_slots<> { + public: + // Takes ownership of |p1| (but not |p2|). + TestChannel(Port* p1) + : ice_mode_(ICEMODE_FULL), + port_(p1), + complete_count_(0), + conn_(NULL), + remote_request_(), + nominated_(false) { + port_->SignalPortComplete.connect(this, &TestChannel::OnPortComplete); + port_->SignalUnknownAddress.connect(this, &TestChannel::OnUnknownAddress); + port_->SignalDestroyed.connect(this, &TestChannel::OnSrcPortDestroyed); + } + + int complete_count() { return complete_count_; } + Connection* conn() { return conn_; } + const SocketAddress& remote_address() { return remote_address_; } + const std::string remote_fragment() { return remote_frag_; } + + void Start() { port_->PrepareAddress(); } + void CreateConnection(const Candidate& remote_candidate) { + conn_ = port_->CreateConnection(remote_candidate, Port::ORIGIN_MESSAGE); + IceMode remote_ice_mode = + (ice_mode_ == ICEMODE_FULL) ? ICEMODE_LITE : ICEMODE_FULL; + conn_->set_remote_ice_mode(remote_ice_mode); + conn_->set_use_candidate_attr(remote_ice_mode == ICEMODE_FULL); + conn_->SignalStateChange.connect( + this, &TestChannel::OnConnectionStateChange); + conn_->SignalDestroyed.connect(this, &TestChannel::OnDestroyed); + conn_->SignalReadyToSend.connect(this, + &TestChannel::OnConnectionReadyToSend); + connection_ready_to_send_ = false; + } + void OnConnectionStateChange(Connection* conn) { + if (conn->write_state() == Connection::STATE_WRITABLE) { + conn->set_use_candidate_attr(true); + nominated_ = true; + } + } + void AcceptConnection(const Candidate& remote_candidate) { + ASSERT_TRUE(remote_request_.get() != NULL); + Candidate c = remote_candidate; + c.set_address(remote_address_); + conn_ = port_->CreateConnection(c, Port::ORIGIN_MESSAGE); + conn_->SignalDestroyed.connect(this, &TestChannel::OnDestroyed); + port_->SendBindingResponse(remote_request_.get(), remote_address_); + remote_request_.reset(); + } + void Ping() { + Ping(0); + } + void Ping(uint32_t now) { conn_->Ping(now); } + void Stop() { + if (conn_) { + conn_->Destroy(); + } + } + + void OnPortComplete(Port* port) { + complete_count_++; + } + void SetIceMode(IceMode ice_mode) { + ice_mode_ = ice_mode; + } + + int SendData(const char* data, size_t len) { + rtc::PacketOptions options; + return conn_->Send(data, len, options); + } + + void OnUnknownAddress(PortInterface* port, const SocketAddress& addr, + ProtocolType proto, + IceMessage* msg, const std::string& rf, + bool /*port_muxed*/) { + ASSERT_EQ(port_.get(), port); + if (!remote_address_.IsNil()) { + ASSERT_EQ(remote_address_, addr); + } + const cricket::StunUInt32Attribute* priority_attr = + msg->GetUInt32(STUN_ATTR_PRIORITY); + const cricket::StunByteStringAttribute* mi_attr = + msg->GetByteString(STUN_ATTR_MESSAGE_INTEGRITY); + const cricket::StunUInt32Attribute* fingerprint_attr = + msg->GetUInt32(STUN_ATTR_FINGERPRINT); + EXPECT_TRUE(priority_attr != NULL); + EXPECT_TRUE(mi_attr != NULL); + EXPECT_TRUE(fingerprint_attr != NULL); + remote_address_ = addr; + remote_request_.reset(CopyStunMessage(msg)); + remote_frag_ = rf; + } + + void OnDestroyed(Connection* conn) { + ASSERT_EQ(conn_, conn); + LOG(INFO) << "OnDestroy connection " << conn << " deleted"; + conn_ = NULL; + // When the connection is destroyed, also clear these fields so future + // connections are possible. + remote_request_.reset(); + remote_address_.Clear(); + } + + void OnSrcPortDestroyed(PortInterface* port) { + Port* destroyed_src = port_.release(); + ASSERT_EQ(destroyed_src, port); + } + + Port* port() { return port_.get(); } + + bool nominated() const { return nominated_; } + + void set_connection_ready_to_send(bool ready) { + connection_ready_to_send_ = ready; + } + bool connection_ready_to_send() const { + return connection_ready_to_send_; + } + + private: + // ReadyToSend will only issue after a Connection recovers from EWOULDBLOCK. + void OnConnectionReadyToSend(Connection* conn) { + ASSERT_EQ(conn, conn_); + connection_ready_to_send_ = true; + } + + IceMode ice_mode_; + rtc::scoped_ptr<Port> port_; + + int complete_count_; + Connection* conn_; + SocketAddress remote_address_; + rtc::scoped_ptr<StunMessage> remote_request_; + std::string remote_frag_; + bool nominated_; + bool connection_ready_to_send_ = false; +}; + +class PortTest : public testing::Test, public sigslot::has_slots<> { + public: + PortTest() + : main_(rtc::Thread::Current()), + pss_(new rtc::PhysicalSocketServer), + ss_(new rtc::VirtualSocketServer(pss_.get())), + ss_scope_(ss_.get()), + network_("unittest", "unittest", rtc::IPAddress(INADDR_ANY), 32), + socket_factory_(rtc::Thread::Current()), + nat_factory1_(ss_.get(), kNatAddr1, SocketAddress()), + nat_factory2_(ss_.get(), kNatAddr2, SocketAddress()), + nat_socket_factory1_(&nat_factory1_), + nat_socket_factory2_(&nat_factory2_), + stun_server_(TestStunServer::Create(main_, kStunAddr)), + turn_server_(main_, kTurnUdpIntAddr, kTurnUdpExtAddr), + relay_server_(main_, + kRelayUdpIntAddr, + kRelayUdpExtAddr, + kRelayTcpIntAddr, + kRelayTcpExtAddr, + kRelaySslTcpIntAddr, + kRelaySslTcpExtAddr), + username_(rtc::CreateRandomString(ICE_UFRAG_LENGTH)), + password_(rtc::CreateRandomString(ICE_PWD_LENGTH)), + role_conflict_(false), + destroyed_(false) { + network_.AddIP(rtc::IPAddress(INADDR_ANY)); + } + + protected: + void TestLocalToLocal() { + Port* port1 = CreateUdpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + Port* port2 = CreateUdpPort(kLocalAddr2); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + TestConnectivity("udp", port1, "udp", port2, true, true, true, true); + } + void TestLocalToStun(NATType ntype) { + Port* port1 = CreateUdpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + nat_server2_.reset(CreateNatServer(kNatAddr2, ntype)); + Port* port2 = CreateStunPort(kLocalAddr2, &nat_socket_factory2_); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + TestConnectivity("udp", port1, StunName(ntype), port2, + ntype == NAT_OPEN_CONE, true, + ntype != NAT_SYMMETRIC, true); + } + void TestLocalToRelay(RelayType rtype, ProtocolType proto) { + Port* port1 = CreateUdpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + Port* port2 = CreateRelayPort(kLocalAddr2, rtype, proto, PROTO_UDP); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + TestConnectivity("udp", port1, RelayName(rtype, proto), port2, + rtype == RELAY_GTURN, true, true, true); + } + void TestStunToLocal(NATType ntype) { + nat_server1_.reset(CreateNatServer(kNatAddr1, ntype)); + Port* port1 = CreateStunPort(kLocalAddr1, &nat_socket_factory1_); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + Port* port2 = CreateUdpPort(kLocalAddr2); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + TestConnectivity(StunName(ntype), port1, "udp", port2, + true, ntype != NAT_SYMMETRIC, true, true); + } + void TestStunToStun(NATType ntype1, NATType ntype2) { + nat_server1_.reset(CreateNatServer(kNatAddr1, ntype1)); + Port* port1 = CreateStunPort(kLocalAddr1, &nat_socket_factory1_); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + nat_server2_.reset(CreateNatServer(kNatAddr2, ntype2)); + Port* port2 = CreateStunPort(kLocalAddr2, &nat_socket_factory2_); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + TestConnectivity(StunName(ntype1), port1, StunName(ntype2), port2, + ntype2 == NAT_OPEN_CONE, + ntype1 != NAT_SYMMETRIC, ntype2 != NAT_SYMMETRIC, + ntype1 + ntype2 < (NAT_PORT_RESTRICTED + NAT_SYMMETRIC)); + } + void TestStunToRelay(NATType ntype, RelayType rtype, ProtocolType proto) { + nat_server1_.reset(CreateNatServer(kNatAddr1, ntype)); + Port* port1 = CreateStunPort(kLocalAddr1, &nat_socket_factory1_); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + Port* port2 = CreateRelayPort(kLocalAddr2, rtype, proto, PROTO_UDP); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + TestConnectivity(StunName(ntype), port1, RelayName(rtype, proto), port2, + rtype == RELAY_GTURN, ntype != NAT_SYMMETRIC, true, true); + } + void TestTcpToTcp() { + Port* port1 = CreateTcpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + Port* port2 = CreateTcpPort(kLocalAddr2); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + TestConnectivity("tcp", port1, "tcp", port2, true, false, true, true); + } + void TestTcpToRelay(RelayType rtype, ProtocolType proto) { + Port* port1 = CreateTcpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + Port* port2 = CreateRelayPort(kLocalAddr2, rtype, proto, PROTO_TCP); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + TestConnectivity("tcp", port1, RelayName(rtype, proto), port2, + rtype == RELAY_GTURN, false, true, true); + } + void TestSslTcpToRelay(RelayType rtype, ProtocolType proto) { + Port* port1 = CreateTcpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + Port* port2 = CreateRelayPort(kLocalAddr2, rtype, proto, PROTO_SSLTCP); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + TestConnectivity("ssltcp", port1, RelayName(rtype, proto), port2, + rtype == RELAY_GTURN, false, true, true); + } + // helpers for above functions + UDPPort* CreateUdpPort(const SocketAddress& addr) { + return CreateUdpPort(addr, &socket_factory_); + } + UDPPort* CreateUdpPort(const SocketAddress& addr, + PacketSocketFactory* socket_factory) { + return UDPPort::Create(main_, socket_factory, &network_, + addr.ipaddr(), 0, 0, username_, password_, + std::string(), false); + } + TCPPort* CreateTcpPort(const SocketAddress& addr) { + return CreateTcpPort(addr, &socket_factory_); + } + TCPPort* CreateTcpPort(const SocketAddress& addr, + PacketSocketFactory* socket_factory) { + return TCPPort::Create(main_, socket_factory, &network_, + addr.ipaddr(), 0, 0, username_, password_, + true); + } + StunPort* CreateStunPort(const SocketAddress& addr, + rtc::PacketSocketFactory* factory) { + ServerAddresses stun_servers; + stun_servers.insert(kStunAddr); + return StunPort::Create(main_, factory, &network_, + addr.ipaddr(), 0, 0, + username_, password_, stun_servers, + std::string()); + } + Port* CreateRelayPort(const SocketAddress& addr, RelayType rtype, + ProtocolType int_proto, ProtocolType ext_proto) { + if (rtype == RELAY_TURN) { + return CreateTurnPort(addr, &socket_factory_, int_proto, ext_proto); + } else { + return CreateGturnPort(addr, int_proto, ext_proto); + } + } + TurnPort* CreateTurnPort(const SocketAddress& addr, + PacketSocketFactory* socket_factory, + ProtocolType int_proto, ProtocolType ext_proto) { + return CreateTurnPort(addr, socket_factory, + int_proto, ext_proto, kTurnUdpIntAddr); + } + TurnPort* CreateTurnPort(const SocketAddress& addr, + PacketSocketFactory* socket_factory, + ProtocolType int_proto, ProtocolType ext_proto, + const rtc::SocketAddress& server_addr) { + return TurnPort::Create(main_, socket_factory, &network_, + addr.ipaddr(), 0, 0, + username_, password_, ProtocolAddress( + server_addr, PROTO_UDP), + kRelayCredentials, 0, + std::string()); + } + RelayPort* CreateGturnPort(const SocketAddress& addr, + ProtocolType int_proto, ProtocolType ext_proto) { + RelayPort* port = CreateGturnPort(addr); + SocketAddress addrs[] = + { kRelayUdpIntAddr, kRelayTcpIntAddr, kRelaySslTcpIntAddr }; + port->AddServerAddress(ProtocolAddress(addrs[int_proto], int_proto)); + return port; + } + RelayPort* CreateGturnPort(const SocketAddress& addr) { + // TODO(pthatcher): Remove GTURN. + return RelayPort::Create(main_, &socket_factory_, &network_, + addr.ipaddr(), 0, 0, + username_, password_); + // TODO: Add an external address for ext_proto, so that the + // other side can connect to this port using a non-UDP protocol. + } + rtc::NATServer* CreateNatServer(const SocketAddress& addr, + rtc::NATType type) { + return new rtc::NATServer(type, ss_.get(), addr, addr, ss_.get(), addr); + } + static const char* StunName(NATType type) { + switch (type) { + case NAT_OPEN_CONE: return "stun(open cone)"; + case NAT_ADDR_RESTRICTED: return "stun(addr restricted)"; + case NAT_PORT_RESTRICTED: return "stun(port restricted)"; + case NAT_SYMMETRIC: return "stun(symmetric)"; + default: return "stun(?)"; + } + } + static const char* RelayName(RelayType type, ProtocolType proto) { + if (type == RELAY_TURN) { + switch (proto) { + case PROTO_UDP: return "turn(udp)"; + case PROTO_TCP: return "turn(tcp)"; + case PROTO_SSLTCP: return "turn(ssltcp)"; + default: return "turn(?)"; + } + } else { + switch (proto) { + case PROTO_UDP: return "gturn(udp)"; + case PROTO_TCP: return "gturn(tcp)"; + case PROTO_SSLTCP: return "gturn(ssltcp)"; + default: return "gturn(?)"; + } + } + } + + void TestCrossFamilyPorts(int type); + + void ExpectPortsCanConnect(bool can_connect, Port* p1, Port* p2); + + // This does all the work and then deletes |port1| and |port2|. + void TestConnectivity(const char* name1, Port* port1, + const char* name2, Port* port2, + bool accept, bool same_addr1, + bool same_addr2, bool possible); + + // This connects the provided channels which have already started. |ch1| + // should have its Connection created (either through CreateConnection() or + // TCP reconnecting mechanism before entering this function. + void ConnectStartedChannels(TestChannel* ch1, TestChannel* ch2) { + ASSERT_TRUE(ch1->conn()); + EXPECT_TRUE_WAIT(ch1->conn()->connected(), kTimeout); // for TCP connect + ch1->Ping(); + WAIT(!ch2->remote_address().IsNil(), kTimeout); + + // Send a ping from dst to src. + ch2->AcceptConnection(GetCandidate(ch1->port())); + ch2->Ping(); + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, ch2->conn()->write_state(), + kTimeout); + } + + // This connects and disconnects the provided channels in the same sequence as + // TestConnectivity with all options set to |true|. It does not delete either + // channel. + void StartConnectAndStopChannels(TestChannel* ch1, TestChannel* ch2) { + // Acquire addresses. + ch1->Start(); + ch2->Start(); + + ch1->CreateConnection(GetCandidate(ch2->port())); + ConnectStartedChannels(ch1, ch2); + + // Destroy the connections. + ch1->Stop(); + ch2->Stop(); + } + + // This disconnects both end's Connection and make sure ch2 ready for new + // connection. + void DisconnectTcpTestChannels(TestChannel* ch1, TestChannel* ch2) { + TCPConnection* tcp_conn1 = static_cast<TCPConnection*>(ch1->conn()); + TCPConnection* tcp_conn2 = static_cast<TCPConnection*>(ch2->conn()); + ASSERT_TRUE( + ss_->CloseTcpConnections(tcp_conn1->socket()->GetLocalAddress(), + tcp_conn2->socket()->GetLocalAddress())); + + // Wait for both OnClose are delivered. + EXPECT_TRUE_WAIT(!ch1->conn()->connected(), kTimeout); + EXPECT_TRUE_WAIT(!ch2->conn()->connected(), kTimeout); + + // Ensure redundant SignalClose events on TcpConnection won't break tcp + // reconnection. Chromium will fire SignalClose for all outstanding IPC + // packets during reconnection. + tcp_conn1->socket()->SignalClose(tcp_conn1->socket(), 0); + tcp_conn2->socket()->SignalClose(tcp_conn2->socket(), 0); + + // Speed up destroying ch2's connection such that the test is ready to + // accept a new connection from ch1 before ch1's connection destroys itself. + ch2->conn()->Destroy(); + EXPECT_TRUE_WAIT(ch2->conn() == NULL, kTimeout); + } + + void TestTcpReconnect(bool ping_after_disconnected, + bool send_after_disconnected) { + Port* port1 = CreateTcpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + Port* port2 = CreateTcpPort(kLocalAddr2); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + + port1->set_component(cricket::ICE_CANDIDATE_COMPONENT_DEFAULT); + port2->set_component(cricket::ICE_CANDIDATE_COMPONENT_DEFAULT); + + // Set up channels and ensure both ports will be deleted. + TestChannel ch1(port1); + TestChannel ch2(port2); + EXPECT_EQ(0, ch1.complete_count()); + EXPECT_EQ(0, ch2.complete_count()); + + ch1.Start(); + ch2.Start(); + ASSERT_EQ_WAIT(1, ch1.complete_count(), kTimeout); + ASSERT_EQ_WAIT(1, ch2.complete_count(), kTimeout); + + // Initial connecting the channel, create connection on channel1. + ch1.CreateConnection(GetCandidate(port2)); + ConnectStartedChannels(&ch1, &ch2); + + // Shorten the timeout period. + const int kTcpReconnectTimeout = kTimeout; + static_cast<TCPConnection*>(ch1.conn()) + ->set_reconnection_timeout(kTcpReconnectTimeout); + static_cast<TCPConnection*>(ch2.conn()) + ->set_reconnection_timeout(kTcpReconnectTimeout); + + EXPECT_FALSE(ch1.connection_ready_to_send()); + EXPECT_FALSE(ch2.connection_ready_to_send()); + + // Once connected, disconnect them. + DisconnectTcpTestChannels(&ch1, &ch2); + + if (send_after_disconnected || ping_after_disconnected) { + if (send_after_disconnected) { + // First SendData after disconnect should fail but will trigger + // reconnect. + EXPECT_EQ(-1, ch1.SendData(data, static_cast<int>(strlen(data)))); + } + + if (ping_after_disconnected) { + // Ping should trigger reconnect. + ch1.Ping(); + } + + // Wait for channel's outgoing TCPConnection connected. + EXPECT_TRUE_WAIT(ch1.conn()->connected(), kTimeout); + + // Verify that we could still connect channels. + ConnectStartedChannels(&ch1, &ch2); + EXPECT_TRUE_WAIT(ch1.connection_ready_to_send(), + kTcpReconnectTimeout); + // Channel2 is the passive one so a new connection is created during + // reconnect. This new connection should never have issued EWOULDBLOCK + // hence the connection_ready_to_send() should be false. + EXPECT_FALSE(ch2.connection_ready_to_send()); + } else { + EXPECT_EQ(ch1.conn()->write_state(), Connection::STATE_WRITABLE); + // Since the reconnection never happens, the connections should have been + // destroyed after the timeout. + EXPECT_TRUE_WAIT(!ch1.conn(), kTcpReconnectTimeout + kTimeout); + EXPECT_TRUE(!ch2.conn()); + } + + // Tear down and ensure that goes smoothly. + ch1.Stop(); + ch2.Stop(); + EXPECT_TRUE_WAIT(ch1.conn() == NULL, kTimeout); + EXPECT_TRUE_WAIT(ch2.conn() == NULL, kTimeout); + } + + IceMessage* CreateStunMessage(int type) { + IceMessage* msg = new IceMessage(); + msg->SetType(type); + msg->SetTransactionID("TESTTESTTEST"); + return msg; + } + IceMessage* CreateStunMessageWithUsername(int type, + const std::string& username) { + IceMessage* msg = CreateStunMessage(type); + msg->AddAttribute( + new StunByteStringAttribute(STUN_ATTR_USERNAME, username)); + return msg; + } + TestPort* CreateTestPort(const rtc::SocketAddress& addr, + const std::string& username, + const std::string& password) { + TestPort* port = new TestPort(main_, "test", &socket_factory_, &network_, + addr.ipaddr(), 0, 0, username, password); + port->SignalRoleConflict.connect(this, &PortTest::OnRoleConflict); + return port; + } + TestPort* CreateTestPort(const rtc::SocketAddress& addr, + const std::string& username, + const std::string& password, + cricket::IceRole role, + int tiebreaker) { + TestPort* port = CreateTestPort(addr, username, password); + port->SetIceRole(role); + port->SetIceTiebreaker(tiebreaker); + return port; + } + + void OnRoleConflict(PortInterface* port) { + role_conflict_ = true; + } + bool role_conflict() const { return role_conflict_; } + + void ConnectToSignalDestroyed(PortInterface* port) { + port->SignalDestroyed.connect(this, &PortTest::OnDestroyed); + } + + void OnDestroyed(PortInterface* port) { + destroyed_ = true; + } + bool destroyed() const { return destroyed_; } + + rtc::BasicPacketSocketFactory* nat_socket_factory1() { + return &nat_socket_factory1_; + } + + protected: + rtc::VirtualSocketServer* vss() { return ss_.get(); } + + private: + rtc::Thread* main_; + rtc::scoped_ptr<rtc::PhysicalSocketServer> pss_; + rtc::scoped_ptr<rtc::VirtualSocketServer> ss_; + rtc::SocketServerScope ss_scope_; + rtc::Network network_; + rtc::BasicPacketSocketFactory socket_factory_; + rtc::scoped_ptr<rtc::NATServer> nat_server1_; + rtc::scoped_ptr<rtc::NATServer> nat_server2_; + rtc::NATSocketFactory nat_factory1_; + rtc::NATSocketFactory nat_factory2_; + rtc::BasicPacketSocketFactory nat_socket_factory1_; + rtc::BasicPacketSocketFactory nat_socket_factory2_; + scoped_ptr<TestStunServer> stun_server_; + TestTurnServer turn_server_; + TestRelayServer relay_server_; + std::string username_; + std::string password_; + bool role_conflict_; + bool destroyed_; +}; + +void PortTest::TestConnectivity(const char* name1, Port* port1, + const char* name2, Port* port2, + bool accept, bool same_addr1, + bool same_addr2, bool possible) { + LOG(LS_INFO) << "Test: " << name1 << " to " << name2 << ": "; + port1->set_component(cricket::ICE_CANDIDATE_COMPONENT_DEFAULT); + port2->set_component(cricket::ICE_CANDIDATE_COMPONENT_DEFAULT); + + // Set up channels and ensure both ports will be deleted. + TestChannel ch1(port1); + TestChannel ch2(port2); + EXPECT_EQ(0, ch1.complete_count()); + EXPECT_EQ(0, ch2.complete_count()); + + // Acquire addresses. + ch1.Start(); + ch2.Start(); + ASSERT_EQ_WAIT(1, ch1.complete_count(), kTimeout); + ASSERT_EQ_WAIT(1, ch2.complete_count(), kTimeout); + + // Send a ping from src to dst. This may or may not make it. + ch1.CreateConnection(GetCandidate(port2)); + ASSERT_TRUE(ch1.conn() != NULL); + EXPECT_TRUE_WAIT(ch1.conn()->connected(), kTimeout); // for TCP connect + ch1.Ping(); + WAIT(!ch2.remote_address().IsNil(), kTimeout); + + if (accept) { + // We are able to send a ping from src to dst. This is the case when + // sending to UDP ports and cone NATs. + EXPECT_TRUE(ch1.remote_address().IsNil()); + EXPECT_EQ(ch2.remote_fragment(), port1->username_fragment()); + + // Ensure the ping came from the same address used for src. + // This is the case unless the source NAT was symmetric. + if (same_addr1) EXPECT_EQ(ch2.remote_address(), GetAddress(port1)); + EXPECT_TRUE(same_addr2); + + // Send a ping from dst to src. + ch2.AcceptConnection(GetCandidate(port1)); + ASSERT_TRUE(ch2.conn() != NULL); + ch2.Ping(); + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, ch2.conn()->write_state(), + kTimeout); + } else { + // We can't send a ping from src to dst, so flip it around. This will happen + // when the destination NAT is addr/port restricted or symmetric. + EXPECT_TRUE(ch1.remote_address().IsNil()); + EXPECT_TRUE(ch2.remote_address().IsNil()); + + // Send a ping from dst to src. Again, this may or may not make it. + ch2.CreateConnection(GetCandidate(port1)); + ASSERT_TRUE(ch2.conn() != NULL); + ch2.Ping(); + WAIT(ch2.conn()->write_state() == Connection::STATE_WRITABLE, kTimeout); + + if (same_addr1 && same_addr2) { + // The new ping got back to the source. + EXPECT_TRUE(ch1.conn()->receiving()); + EXPECT_EQ(Connection::STATE_WRITABLE, ch2.conn()->write_state()); + + // First connection may not be writable if the first ping did not get + // through. So we will have to do another. + if (ch1.conn()->write_state() == Connection::STATE_WRITE_INIT) { + ch1.Ping(); + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, ch1.conn()->write_state(), + kTimeout); + } + } else if (!same_addr1 && possible) { + // The new ping went to the candidate address, but that address was bad. + // This will happen when the source NAT is symmetric. + EXPECT_TRUE(ch1.remote_address().IsNil()); + EXPECT_TRUE(ch2.remote_address().IsNil()); + + // However, since we have now sent a ping to the source IP, we should be + // able to get a ping from it. This gives us the real source address. + ch1.Ping(); + EXPECT_TRUE_WAIT(!ch2.remote_address().IsNil(), kTimeout); + EXPECT_FALSE(ch2.conn()->receiving()); + EXPECT_TRUE(ch1.remote_address().IsNil()); + + // Pick up the actual address and establish the connection. + ch2.AcceptConnection(GetCandidate(port1)); + ASSERT_TRUE(ch2.conn() != NULL); + ch2.Ping(); + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, ch2.conn()->write_state(), + kTimeout); + } else if (!same_addr2 && possible) { + // The new ping came in, but from an unexpected address. This will happen + // when the destination NAT is symmetric. + EXPECT_FALSE(ch1.remote_address().IsNil()); + EXPECT_FALSE(ch1.conn()->receiving()); + + // Update our address and complete the connection. + ch1.AcceptConnection(GetCandidate(port2)); + ch1.Ping(); + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, ch1.conn()->write_state(), + kTimeout); + } else { // (!possible) + // There should be s no way for the pings to reach each other. Check it. + EXPECT_TRUE(ch1.remote_address().IsNil()); + EXPECT_TRUE(ch2.remote_address().IsNil()); + ch1.Ping(); + WAIT(!ch2.remote_address().IsNil(), kTimeout); + EXPECT_TRUE(ch1.remote_address().IsNil()); + EXPECT_TRUE(ch2.remote_address().IsNil()); + } + } + + // Everything should be good, unless we know the situation is impossible. + ASSERT_TRUE(ch1.conn() != NULL); + ASSERT_TRUE(ch2.conn() != NULL); + if (possible) { + EXPECT_TRUE(ch1.conn()->receiving()); + EXPECT_EQ(Connection::STATE_WRITABLE, ch1.conn()->write_state()); + EXPECT_TRUE(ch2.conn()->receiving()); + EXPECT_EQ(Connection::STATE_WRITABLE, ch2.conn()->write_state()); + } else { + EXPECT_FALSE(ch1.conn()->receiving()); + EXPECT_NE(Connection::STATE_WRITABLE, ch1.conn()->write_state()); + EXPECT_FALSE(ch2.conn()->receiving()); + EXPECT_NE(Connection::STATE_WRITABLE, ch2.conn()->write_state()); + } + + // Tear down and ensure that goes smoothly. + ch1.Stop(); + ch2.Stop(); + EXPECT_TRUE_WAIT(ch1.conn() == NULL, kTimeout); + EXPECT_TRUE_WAIT(ch2.conn() == NULL, kTimeout); +} + +class FakePacketSocketFactory : public rtc::PacketSocketFactory { + public: + FakePacketSocketFactory() + : next_udp_socket_(NULL), + next_server_tcp_socket_(NULL), + next_client_tcp_socket_(NULL) { + } + ~FakePacketSocketFactory() override { } + + AsyncPacketSocket* CreateUdpSocket(const SocketAddress& address, + uint16_t min_port, + uint16_t max_port) override { + EXPECT_TRUE(next_udp_socket_ != NULL); + AsyncPacketSocket* result = next_udp_socket_; + next_udp_socket_ = NULL; + return result; + } + + AsyncPacketSocket* CreateServerTcpSocket(const SocketAddress& local_address, + uint16_t min_port, + uint16_t max_port, + int opts) override { + EXPECT_TRUE(next_server_tcp_socket_ != NULL); + AsyncPacketSocket* result = next_server_tcp_socket_; + next_server_tcp_socket_ = NULL; + return result; + } + + // TODO: |proxy_info| and |user_agent| should be set + // per-factory and not when socket is created. + AsyncPacketSocket* CreateClientTcpSocket(const SocketAddress& local_address, + const SocketAddress& remote_address, + const rtc::ProxyInfo& proxy_info, + const std::string& user_agent, + int opts) override { + EXPECT_TRUE(next_client_tcp_socket_ != NULL); + AsyncPacketSocket* result = next_client_tcp_socket_; + next_client_tcp_socket_ = NULL; + return result; + } + + void set_next_udp_socket(AsyncPacketSocket* next_udp_socket) { + next_udp_socket_ = next_udp_socket; + } + void set_next_server_tcp_socket(AsyncPacketSocket* next_server_tcp_socket) { + next_server_tcp_socket_ = next_server_tcp_socket; + } + void set_next_client_tcp_socket(AsyncPacketSocket* next_client_tcp_socket) { + next_client_tcp_socket_ = next_client_tcp_socket; + } + rtc::AsyncResolverInterface* CreateAsyncResolver() { + return NULL; + } + + private: + AsyncPacketSocket* next_udp_socket_; + AsyncPacketSocket* next_server_tcp_socket_; + AsyncPacketSocket* next_client_tcp_socket_; +}; + +class FakeAsyncPacketSocket : public AsyncPacketSocket { + public: + // Returns current local address. Address may be set to NULL if the + // socket is not bound yet (GetState() returns STATE_BINDING). + virtual SocketAddress GetLocalAddress() const { + return SocketAddress(); + } + + // Returns remote address. Returns zeroes if this is not a client TCP socket. + virtual SocketAddress GetRemoteAddress() const { + return SocketAddress(); + } + + // Send a packet. + virtual int Send(const void *pv, size_t cb, + const rtc::PacketOptions& options) { + return static_cast<int>(cb); + } + virtual int SendTo(const void *pv, size_t cb, const SocketAddress& addr, + const rtc::PacketOptions& options) { + return static_cast<int>(cb); + } + virtual int Close() { + return 0; + } + + virtual State GetState() const { return state_; } + virtual int GetOption(Socket::Option opt, int* value) { return 0; } + virtual int SetOption(Socket::Option opt, int value) { return 0; } + virtual int GetError() const { return 0; } + virtual void SetError(int error) { } + + void set_state(State state) { state_ = state; } + + private: + State state_; +}; + +// Local -> XXXX +TEST_F(PortTest, TestLocalToLocal) { + TestLocalToLocal(); +} + +TEST_F(PortTest, TestLocalToConeNat) { + TestLocalToStun(NAT_OPEN_CONE); +} + +TEST_F(PortTest, TestLocalToARNat) { + TestLocalToStun(NAT_ADDR_RESTRICTED); +} + +TEST_F(PortTest, TestLocalToPRNat) { + TestLocalToStun(NAT_PORT_RESTRICTED); +} + +TEST_F(PortTest, TestLocalToSymNat) { + TestLocalToStun(NAT_SYMMETRIC); +} + +// Flaky: https://code.google.com/p/webrtc/issues/detail?id=3316. +TEST_F(PortTest, DISABLED_TestLocalToTurn) { + TestLocalToRelay(RELAY_TURN, PROTO_UDP); +} + +TEST_F(PortTest, TestLocalToGturn) { + TestLocalToRelay(RELAY_GTURN, PROTO_UDP); +} + +TEST_F(PortTest, TestLocalToTcpGturn) { + TestLocalToRelay(RELAY_GTURN, PROTO_TCP); +} + +TEST_F(PortTest, TestLocalToSslTcpGturn) { + TestLocalToRelay(RELAY_GTURN, PROTO_SSLTCP); +} + +// Cone NAT -> XXXX +TEST_F(PortTest, TestConeNatToLocal) { + TestStunToLocal(NAT_OPEN_CONE); +} + +TEST_F(PortTest, TestConeNatToConeNat) { + TestStunToStun(NAT_OPEN_CONE, NAT_OPEN_CONE); +} + +TEST_F(PortTest, TestConeNatToARNat) { + TestStunToStun(NAT_OPEN_CONE, NAT_ADDR_RESTRICTED); +} + +TEST_F(PortTest, TestConeNatToPRNat) { + TestStunToStun(NAT_OPEN_CONE, NAT_PORT_RESTRICTED); +} + +TEST_F(PortTest, TestConeNatToSymNat) { + TestStunToStun(NAT_OPEN_CONE, NAT_SYMMETRIC); +} + +TEST_F(PortTest, TestConeNatToTurn) { + TestStunToRelay(NAT_OPEN_CONE, RELAY_TURN, PROTO_UDP); +} + +TEST_F(PortTest, TestConeNatToGturn) { + TestStunToRelay(NAT_OPEN_CONE, RELAY_GTURN, PROTO_UDP); +} + +TEST_F(PortTest, TestConeNatToTcpGturn) { + TestStunToRelay(NAT_OPEN_CONE, RELAY_GTURN, PROTO_TCP); +} + +// Address-restricted NAT -> XXXX +TEST_F(PortTest, TestARNatToLocal) { + TestStunToLocal(NAT_ADDR_RESTRICTED); +} + +TEST_F(PortTest, TestARNatToConeNat) { + TestStunToStun(NAT_ADDR_RESTRICTED, NAT_OPEN_CONE); +} + +TEST_F(PortTest, TestARNatToARNat) { + TestStunToStun(NAT_ADDR_RESTRICTED, NAT_ADDR_RESTRICTED); +} + +TEST_F(PortTest, TestARNatToPRNat) { + TestStunToStun(NAT_ADDR_RESTRICTED, NAT_PORT_RESTRICTED); +} + +TEST_F(PortTest, TestARNatToSymNat) { + TestStunToStun(NAT_ADDR_RESTRICTED, NAT_SYMMETRIC); +} + +TEST_F(PortTest, TestARNatToTurn) { + TestStunToRelay(NAT_ADDR_RESTRICTED, RELAY_TURN, PROTO_UDP); +} + +TEST_F(PortTest, TestARNatToGturn) { + TestStunToRelay(NAT_ADDR_RESTRICTED, RELAY_GTURN, PROTO_UDP); +} + +TEST_F(PortTest, TestARNATNatToTcpGturn) { + TestStunToRelay(NAT_ADDR_RESTRICTED, RELAY_GTURN, PROTO_TCP); +} + +// Port-restricted NAT -> XXXX +TEST_F(PortTest, TestPRNatToLocal) { + TestStunToLocal(NAT_PORT_RESTRICTED); +} + +TEST_F(PortTest, TestPRNatToConeNat) { + TestStunToStun(NAT_PORT_RESTRICTED, NAT_OPEN_CONE); +} + +TEST_F(PortTest, TestPRNatToARNat) { + TestStunToStun(NAT_PORT_RESTRICTED, NAT_ADDR_RESTRICTED); +} + +TEST_F(PortTest, TestPRNatToPRNat) { + TestStunToStun(NAT_PORT_RESTRICTED, NAT_PORT_RESTRICTED); +} + +TEST_F(PortTest, TestPRNatToSymNat) { + // Will "fail" + TestStunToStun(NAT_PORT_RESTRICTED, NAT_SYMMETRIC); +} + +TEST_F(PortTest, TestPRNatToTurn) { + TestStunToRelay(NAT_PORT_RESTRICTED, RELAY_TURN, PROTO_UDP); +} + +TEST_F(PortTest, TestPRNatToGturn) { + TestStunToRelay(NAT_PORT_RESTRICTED, RELAY_GTURN, PROTO_UDP); +} + +TEST_F(PortTest, TestPRNatToTcpGturn) { + TestStunToRelay(NAT_PORT_RESTRICTED, RELAY_GTURN, PROTO_TCP); +} + +// Symmetric NAT -> XXXX +TEST_F(PortTest, TestSymNatToLocal) { + TestStunToLocal(NAT_SYMMETRIC); +} + +TEST_F(PortTest, TestSymNatToConeNat) { + TestStunToStun(NAT_SYMMETRIC, NAT_OPEN_CONE); +} + +TEST_F(PortTest, TestSymNatToARNat) { + TestStunToStun(NAT_SYMMETRIC, NAT_ADDR_RESTRICTED); +} + +TEST_F(PortTest, TestSymNatToPRNat) { + // Will "fail" + TestStunToStun(NAT_SYMMETRIC, NAT_PORT_RESTRICTED); +} + +TEST_F(PortTest, TestSymNatToSymNat) { + // Will "fail" + TestStunToStun(NAT_SYMMETRIC, NAT_SYMMETRIC); +} + +TEST_F(PortTest, TestSymNatToTurn) { + TestStunToRelay(NAT_SYMMETRIC, RELAY_TURN, PROTO_UDP); +} + +TEST_F(PortTest, TestSymNatToGturn) { + TestStunToRelay(NAT_SYMMETRIC, RELAY_GTURN, PROTO_UDP); +} + +TEST_F(PortTest, TestSymNatToTcpGturn) { + TestStunToRelay(NAT_SYMMETRIC, RELAY_GTURN, PROTO_TCP); +} + +// Outbound TCP -> XXXX +TEST_F(PortTest, TestTcpToTcp) { + TestTcpToTcp(); +} + +TEST_F(PortTest, TestTcpReconnectOnSendPacket) { + TestTcpReconnect(false /* ping */, true /* send */); +} + +TEST_F(PortTest, TestTcpReconnectOnPing) { + TestTcpReconnect(true /* ping */, false /* send */); +} + +TEST_F(PortTest, TestTcpReconnectTimeout) { + TestTcpReconnect(false /* ping */, false /* send */); +} + +// Test when TcpConnection never connects, the OnClose() will be called to +// destroy the connection. +TEST_F(PortTest, TestTcpNeverConnect) { + Port* port1 = CreateTcpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + port1->set_component(cricket::ICE_CANDIDATE_COMPONENT_DEFAULT); + + // Set up a channel and ensure the port will be deleted. + TestChannel ch1(port1); + EXPECT_EQ(0, ch1.complete_count()); + + ch1.Start(); + ASSERT_EQ_WAIT(1, ch1.complete_count(), kTimeout); + + rtc::scoped_ptr<rtc::AsyncSocket> server( + vss()->CreateAsyncSocket(kLocalAddr2.family(), SOCK_STREAM)); + // Bind but not listen. + EXPECT_EQ(0, server->Bind(kLocalAddr2)); + + Candidate c = GetCandidate(port1); + c.set_address(server->GetLocalAddress()); + + ch1.CreateConnection(c); + EXPECT_TRUE(ch1.conn()); + EXPECT_TRUE_WAIT(!ch1.conn(), kTimeout); // for TCP connect +} + +/* TODO: Enable these once testrelayserver can accept external TCP. +TEST_F(PortTest, TestTcpToTcpRelay) { + TestTcpToRelay(PROTO_TCP); +} + +TEST_F(PortTest, TestTcpToSslTcpRelay) { + TestTcpToRelay(PROTO_SSLTCP); +} +*/ + +// Outbound SSLTCP -> XXXX +/* TODO: Enable these once testrelayserver can accept external SSL. +TEST_F(PortTest, TestSslTcpToTcpRelay) { + TestSslTcpToRelay(PROTO_TCP); +} + +TEST_F(PortTest, TestSslTcpToSslTcpRelay) { + TestSslTcpToRelay(PROTO_SSLTCP); +} +*/ + +// This test case verifies standard ICE features in STUN messages. Currently it +// verifies Message Integrity attribute in STUN messages and username in STUN +// binding request will have colon (":") between remote and local username. +TEST_F(PortTest, TestLocalToLocalStandard) { + UDPPort* port1 = CreateUdpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + port1->SetIceTiebreaker(kTiebreaker1); + UDPPort* port2 = CreateUdpPort(kLocalAddr2); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + port2->SetIceTiebreaker(kTiebreaker2); + // Same parameters as TestLocalToLocal above. + TestConnectivity("udp", port1, "udp", port2, true, true, true, true); +} + +// This test is trying to validate a successful and failure scenario in a +// loopback test when protocol is RFC5245. For success IceTiebreaker, username +// should remain equal to the request generated by the port and role of port +// must be in controlling. +TEST_F(PortTest, TestLoopbackCal) { + rtc::scoped_ptr<TestPort> lport( + CreateTestPort(kLocalAddr1, "lfrag", "lpass")); + lport->SetIceRole(cricket::ICEROLE_CONTROLLING); + lport->SetIceTiebreaker(kTiebreaker1); + lport->PrepareAddress(); + ASSERT_FALSE(lport->Candidates().empty()); + Connection* conn = lport->CreateConnection(lport->Candidates()[0], + Port::ORIGIN_MESSAGE); + conn->Ping(0); + + ASSERT_TRUE_WAIT(lport->last_stun_msg() != NULL, 1000); + IceMessage* msg = lport->last_stun_msg(); + EXPECT_EQ(STUN_BINDING_REQUEST, msg->type()); + conn->OnReadPacket(lport->last_stun_buf()->Data(), + lport->last_stun_buf()->Length(), + rtc::PacketTime()); + ASSERT_TRUE_WAIT(lport->last_stun_msg() != NULL, 1000); + msg = lport->last_stun_msg(); + EXPECT_EQ(STUN_BINDING_RESPONSE, msg->type()); + + // If the tiebreaker value is different from port, we expect a error + // response. + lport->Reset(); + lport->AddCandidateAddress(kLocalAddr2); + // Creating a different connection as |conn| is receiving. + Connection* conn1 = lport->CreateConnection(lport->Candidates()[1], + Port::ORIGIN_MESSAGE); + conn1->Ping(0); + + ASSERT_TRUE_WAIT(lport->last_stun_msg() != NULL, 1000); + msg = lport->last_stun_msg(); + EXPECT_EQ(STUN_BINDING_REQUEST, msg->type()); + rtc::scoped_ptr<IceMessage> modified_req( + CreateStunMessage(STUN_BINDING_REQUEST)); + const StunByteStringAttribute* username_attr = msg->GetByteString( + STUN_ATTR_USERNAME); + modified_req->AddAttribute(new StunByteStringAttribute( + STUN_ATTR_USERNAME, username_attr->GetString())); + // To make sure we receive error response, adding tiebreaker less than + // what's present in request. + modified_req->AddAttribute(new StunUInt64Attribute( + STUN_ATTR_ICE_CONTROLLING, kTiebreaker1 - 1)); + modified_req->AddMessageIntegrity("lpass"); + modified_req->AddFingerprint(); + + lport->Reset(); + rtc::scoped_ptr<ByteBuffer> buf(new ByteBuffer()); + WriteStunMessage(modified_req.get(), buf.get()); + conn1->OnReadPacket(buf->Data(), buf->Length(), rtc::PacketTime()); + ASSERT_TRUE_WAIT(lport->last_stun_msg() != NULL, 1000); + msg = lport->last_stun_msg(); + EXPECT_EQ(STUN_BINDING_ERROR_RESPONSE, msg->type()); +} + +// This test verifies role conflict signal is received when there is +// conflict in the role. In this case both ports are in controlling and +// |rport| has higher tiebreaker value than |lport|. Since |lport| has lower +// value of tiebreaker, when it receives ping request from |rport| it will +// send role conflict signal. +TEST_F(PortTest, TestIceRoleConflict) { + rtc::scoped_ptr<TestPort> lport( + CreateTestPort(kLocalAddr1, "lfrag", "lpass")); + lport->SetIceRole(cricket::ICEROLE_CONTROLLING); + lport->SetIceTiebreaker(kTiebreaker1); + rtc::scoped_ptr<TestPort> rport( + CreateTestPort(kLocalAddr2, "rfrag", "rpass")); + rport->SetIceRole(cricket::ICEROLE_CONTROLLING); + rport->SetIceTiebreaker(kTiebreaker2); + + lport->PrepareAddress(); + rport->PrepareAddress(); + ASSERT_FALSE(lport->Candidates().empty()); + ASSERT_FALSE(rport->Candidates().empty()); + Connection* lconn = lport->CreateConnection(rport->Candidates()[0], + Port::ORIGIN_MESSAGE); + Connection* rconn = rport->CreateConnection(lport->Candidates()[0], + Port::ORIGIN_MESSAGE); + rconn->Ping(0); + + ASSERT_TRUE_WAIT(rport->last_stun_msg() != NULL, 1000); + IceMessage* msg = rport->last_stun_msg(); + EXPECT_EQ(STUN_BINDING_REQUEST, msg->type()); + // Send rport binding request to lport. + lconn->OnReadPacket(rport->last_stun_buf()->Data(), + rport->last_stun_buf()->Length(), + rtc::PacketTime()); + + ASSERT_TRUE_WAIT(lport->last_stun_msg() != NULL, 1000); + EXPECT_EQ(STUN_BINDING_RESPONSE, lport->last_stun_msg()->type()); + EXPECT_TRUE(role_conflict()); +} + +TEST_F(PortTest, TestTcpNoDelay) { + TCPPort* port1 = CreateTcpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + int option_value = -1; + int success = port1->GetOption(rtc::Socket::OPT_NODELAY, + &option_value); + ASSERT_EQ(0, success); // GetOption() should complete successfully w/ 0 + ASSERT_EQ(1, option_value); + delete port1; +} + +TEST_F(PortTest, TestDelayedBindingUdp) { + FakeAsyncPacketSocket *socket = new FakeAsyncPacketSocket(); + FakePacketSocketFactory socket_factory; + + socket_factory.set_next_udp_socket(socket); + scoped_ptr<UDPPort> port( + CreateUdpPort(kLocalAddr1, &socket_factory)); + + socket->set_state(AsyncPacketSocket::STATE_BINDING); + port->PrepareAddress(); + + EXPECT_EQ(0U, port->Candidates().size()); + socket->SignalAddressReady(socket, kLocalAddr2); + + EXPECT_EQ(1U, port->Candidates().size()); +} + +TEST_F(PortTest, TestDelayedBindingTcp) { + FakeAsyncPacketSocket *socket = new FakeAsyncPacketSocket(); + FakePacketSocketFactory socket_factory; + + socket_factory.set_next_server_tcp_socket(socket); + scoped_ptr<TCPPort> port( + CreateTcpPort(kLocalAddr1, &socket_factory)); + + socket->set_state(AsyncPacketSocket::STATE_BINDING); + port->PrepareAddress(); + + EXPECT_EQ(0U, port->Candidates().size()); + socket->SignalAddressReady(socket, kLocalAddr2); + + EXPECT_EQ(1U, port->Candidates().size()); +} + +void PortTest::TestCrossFamilyPorts(int type) { + FakePacketSocketFactory factory; + scoped_ptr<Port> ports[4]; + SocketAddress addresses[4] = {SocketAddress("192.168.1.3", 0), + SocketAddress("192.168.1.4", 0), + SocketAddress("2001:db8::1", 0), + SocketAddress("2001:db8::2", 0)}; + for (int i = 0; i < 4; i++) { + FakeAsyncPacketSocket *socket = new FakeAsyncPacketSocket(); + if (type == SOCK_DGRAM) { + factory.set_next_udp_socket(socket); + ports[i].reset(CreateUdpPort(addresses[i], &factory)); + } else if (type == SOCK_STREAM) { + factory.set_next_server_tcp_socket(socket); + ports[i].reset(CreateTcpPort(addresses[i], &factory)); + } + socket->set_state(AsyncPacketSocket::STATE_BINDING); + socket->SignalAddressReady(socket, addresses[i]); + ports[i]->PrepareAddress(); + } + + // IPv4 Port, connects to IPv6 candidate and then to IPv4 candidate. + if (type == SOCK_STREAM) { + FakeAsyncPacketSocket* clientsocket = new FakeAsyncPacketSocket(); + factory.set_next_client_tcp_socket(clientsocket); + } + Connection* c = ports[0]->CreateConnection(GetCandidate(ports[2].get()), + Port::ORIGIN_MESSAGE); + EXPECT_TRUE(NULL == c); + EXPECT_EQ(0U, ports[0]->connections().size()); + c = ports[0]->CreateConnection(GetCandidate(ports[1].get()), + Port::ORIGIN_MESSAGE); + EXPECT_FALSE(NULL == c); + EXPECT_EQ(1U, ports[0]->connections().size()); + + // IPv6 Port, connects to IPv4 candidate and to IPv6 candidate. + if (type == SOCK_STREAM) { + FakeAsyncPacketSocket* clientsocket = new FakeAsyncPacketSocket(); + factory.set_next_client_tcp_socket(clientsocket); + } + c = ports[2]->CreateConnection(GetCandidate(ports[0].get()), + Port::ORIGIN_MESSAGE); + EXPECT_TRUE(NULL == c); + EXPECT_EQ(0U, ports[2]->connections().size()); + c = ports[2]->CreateConnection(GetCandidate(ports[3].get()), + Port::ORIGIN_MESSAGE); + EXPECT_FALSE(NULL == c); + EXPECT_EQ(1U, ports[2]->connections().size()); +} + +TEST_F(PortTest, TestSkipCrossFamilyTcp) { + TestCrossFamilyPorts(SOCK_STREAM); +} + +TEST_F(PortTest, TestSkipCrossFamilyUdp) { + TestCrossFamilyPorts(SOCK_DGRAM); +} + +void PortTest::ExpectPortsCanConnect(bool can_connect, Port* p1, Port* p2) { + Connection* c = p1->CreateConnection(GetCandidate(p2), + Port::ORIGIN_MESSAGE); + if (can_connect) { + EXPECT_FALSE(NULL == c); + EXPECT_EQ(1U, p1->connections().size()); + } else { + EXPECT_TRUE(NULL == c); + EXPECT_EQ(0U, p1->connections().size()); + } +} + +TEST_F(PortTest, TestUdpV6CrossTypePorts) { + FakePacketSocketFactory factory; + scoped_ptr<Port> ports[4]; + SocketAddress addresses[4] = {SocketAddress("2001:db8::1", 0), + SocketAddress("fe80::1", 0), + SocketAddress("fe80::2", 0), + SocketAddress("::1", 0)}; + for (int i = 0; i < 4; i++) { + FakeAsyncPacketSocket *socket = new FakeAsyncPacketSocket(); + factory.set_next_udp_socket(socket); + ports[i].reset(CreateUdpPort(addresses[i], &factory)); + socket->set_state(AsyncPacketSocket::STATE_BINDING); + socket->SignalAddressReady(socket, addresses[i]); + ports[i]->PrepareAddress(); + } + + Port* standard = ports[0].get(); + Port* link_local1 = ports[1].get(); + Port* link_local2 = ports[2].get(); + Port* localhost = ports[3].get(); + + ExpectPortsCanConnect(false, link_local1, standard); + ExpectPortsCanConnect(false, standard, link_local1); + ExpectPortsCanConnect(false, link_local1, localhost); + ExpectPortsCanConnect(false, localhost, link_local1); + + ExpectPortsCanConnect(true, link_local1, link_local2); + ExpectPortsCanConnect(true, localhost, standard); + ExpectPortsCanConnect(true, standard, localhost); +} + +// This test verifies DSCP value set through SetOption interface can be +// get through DefaultDscpValue. +TEST_F(PortTest, TestDefaultDscpValue) { + int dscp; + rtc::scoped_ptr<UDPPort> udpport(CreateUdpPort(kLocalAddr1)); + EXPECT_EQ(0, udpport->SetOption(rtc::Socket::OPT_DSCP, + rtc::DSCP_CS6)); + EXPECT_EQ(0, udpport->GetOption(rtc::Socket::OPT_DSCP, &dscp)); + rtc::scoped_ptr<TCPPort> tcpport(CreateTcpPort(kLocalAddr1)); + EXPECT_EQ(0, tcpport->SetOption(rtc::Socket::OPT_DSCP, + rtc::DSCP_AF31)); + EXPECT_EQ(0, tcpport->GetOption(rtc::Socket::OPT_DSCP, &dscp)); + EXPECT_EQ(rtc::DSCP_AF31, dscp); + rtc::scoped_ptr<StunPort> stunport( + CreateStunPort(kLocalAddr1, nat_socket_factory1())); + EXPECT_EQ(0, stunport->SetOption(rtc::Socket::OPT_DSCP, + rtc::DSCP_AF41)); + EXPECT_EQ(0, stunport->GetOption(rtc::Socket::OPT_DSCP, &dscp)); + EXPECT_EQ(rtc::DSCP_AF41, dscp); + rtc::scoped_ptr<TurnPort> turnport1(CreateTurnPort( + kLocalAddr1, nat_socket_factory1(), PROTO_UDP, PROTO_UDP)); + // Socket is created in PrepareAddress. + turnport1->PrepareAddress(); + EXPECT_EQ(0, turnport1->SetOption(rtc::Socket::OPT_DSCP, + rtc::DSCP_CS7)); + EXPECT_EQ(0, turnport1->GetOption(rtc::Socket::OPT_DSCP, &dscp)); + EXPECT_EQ(rtc::DSCP_CS7, dscp); + // This will verify correct value returned without the socket. + rtc::scoped_ptr<TurnPort> turnport2(CreateTurnPort( + kLocalAddr1, nat_socket_factory1(), PROTO_UDP, PROTO_UDP)); + EXPECT_EQ(0, turnport2->SetOption(rtc::Socket::OPT_DSCP, + rtc::DSCP_CS6)); + EXPECT_EQ(0, turnport2->GetOption(rtc::Socket::OPT_DSCP, &dscp)); + EXPECT_EQ(rtc::DSCP_CS6, dscp); +} + +// Test sending STUN messages. +TEST_F(PortTest, TestSendStunMessage) { + rtc::scoped_ptr<TestPort> lport( + CreateTestPort(kLocalAddr1, "lfrag", "lpass")); + rtc::scoped_ptr<TestPort> rport( + CreateTestPort(kLocalAddr2, "rfrag", "rpass")); + lport->SetIceRole(cricket::ICEROLE_CONTROLLING); + lport->SetIceTiebreaker(kTiebreaker1); + rport->SetIceRole(cricket::ICEROLE_CONTROLLED); + rport->SetIceTiebreaker(kTiebreaker2); + + // Send a fake ping from lport to rport. + lport->PrepareAddress(); + rport->PrepareAddress(); + ASSERT_FALSE(rport->Candidates().empty()); + Connection* lconn = lport->CreateConnection( + rport->Candidates()[0], Port::ORIGIN_MESSAGE); + Connection* rconn = rport->CreateConnection( + lport->Candidates()[0], Port::ORIGIN_MESSAGE); + lconn->Ping(0); + + // Check that it's a proper BINDING-REQUEST. + ASSERT_TRUE_WAIT(lport->last_stun_msg() != NULL, 1000); + IceMessage* msg = lport->last_stun_msg(); + EXPECT_EQ(STUN_BINDING_REQUEST, msg->type()); + EXPECT_FALSE(msg->IsLegacy()); + const StunByteStringAttribute* username_attr = + msg->GetByteString(STUN_ATTR_USERNAME); + ASSERT_TRUE(username_attr != NULL); + const StunUInt32Attribute* priority_attr = msg->GetUInt32(STUN_ATTR_PRIORITY); + ASSERT_TRUE(priority_attr != NULL); + EXPECT_EQ(kDefaultPrflxPriority, priority_attr->value()); + EXPECT_EQ("rfrag:lfrag", username_attr->GetString()); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_MESSAGE_INTEGRITY) != NULL); + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + lport->last_stun_buf()->Data(), lport->last_stun_buf()->Length(), + "rpass")); + const StunUInt64Attribute* ice_controlling_attr = + msg->GetUInt64(STUN_ATTR_ICE_CONTROLLING); + ASSERT_TRUE(ice_controlling_attr != NULL); + EXPECT_EQ(lport->IceTiebreaker(), ice_controlling_attr->value()); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_ICE_CONTROLLED) == NULL); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_USE_CANDIDATE) != NULL); + EXPECT_TRUE(msg->GetUInt32(STUN_ATTR_FINGERPRINT) != NULL); + EXPECT_TRUE(StunMessage::ValidateFingerprint( + lport->last_stun_buf()->Data(), lport->last_stun_buf()->Length())); + + // Request should not include ping count. + ASSERT_TRUE(msg->GetUInt32(STUN_ATTR_RETRANSMIT_COUNT) == NULL); + + // Save a copy of the BINDING-REQUEST for use below. + rtc::scoped_ptr<IceMessage> request(CopyStunMessage(msg)); + + // Respond with a BINDING-RESPONSE. + rport->SendBindingResponse(request.get(), lport->Candidates()[0].address()); + msg = rport->last_stun_msg(); + ASSERT_TRUE(msg != NULL); + EXPECT_EQ(STUN_BINDING_RESPONSE, msg->type()); + + + EXPECT_FALSE(msg->IsLegacy()); + const StunAddressAttribute* addr_attr = msg->GetAddress( + STUN_ATTR_XOR_MAPPED_ADDRESS); + ASSERT_TRUE(addr_attr != NULL); + EXPECT_EQ(lport->Candidates()[0].address(), addr_attr->GetAddress()); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_MESSAGE_INTEGRITY) != NULL); + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + rport->last_stun_buf()->Data(), rport->last_stun_buf()->Length(), + "rpass")); + EXPECT_TRUE(msg->GetUInt32(STUN_ATTR_FINGERPRINT) != NULL); + EXPECT_TRUE(StunMessage::ValidateFingerprint( + lport->last_stun_buf()->Data(), lport->last_stun_buf()->Length())); + // No USERNAME or PRIORITY in ICE responses. + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_USERNAME) == NULL); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_PRIORITY) == NULL); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_MAPPED_ADDRESS) == NULL); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_ICE_CONTROLLING) == NULL); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_ICE_CONTROLLED) == NULL); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_USE_CANDIDATE) == NULL); + + // Response should not include ping count. + ASSERT_TRUE(msg->GetUInt32(STUN_ATTR_RETRANSMIT_COUNT) == NULL); + + // Respond with a BINDING-ERROR-RESPONSE. This wouldn't happen in real life, + // but we can do it here. + rport->SendBindingErrorResponse(request.get(), + lport->Candidates()[0].address(), + STUN_ERROR_SERVER_ERROR, + STUN_ERROR_REASON_SERVER_ERROR); + msg = rport->last_stun_msg(); + ASSERT_TRUE(msg != NULL); + EXPECT_EQ(STUN_BINDING_ERROR_RESPONSE, msg->type()); + EXPECT_FALSE(msg->IsLegacy()); + const StunErrorCodeAttribute* error_attr = msg->GetErrorCode(); + ASSERT_TRUE(error_attr != NULL); + EXPECT_EQ(STUN_ERROR_SERVER_ERROR, error_attr->code()); + EXPECT_EQ(std::string(STUN_ERROR_REASON_SERVER_ERROR), error_attr->reason()); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_MESSAGE_INTEGRITY) != NULL); + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + rport->last_stun_buf()->Data(), rport->last_stun_buf()->Length(), + "rpass")); + EXPECT_TRUE(msg->GetUInt32(STUN_ATTR_FINGERPRINT) != NULL); + EXPECT_TRUE(StunMessage::ValidateFingerprint( + lport->last_stun_buf()->Data(), lport->last_stun_buf()->Length())); + // No USERNAME with ICE. + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_USERNAME) == NULL); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_PRIORITY) == NULL); + + // Testing STUN binding requests from rport --> lport, having ICE_CONTROLLED + // and (incremented) RETRANSMIT_COUNT attributes. + rport->Reset(); + rport->set_send_retransmit_count_attribute(true); + rconn->Ping(0); + rconn->Ping(0); + rconn->Ping(0); + ASSERT_TRUE_WAIT(rport->last_stun_msg() != NULL, 1000); + msg = rport->last_stun_msg(); + EXPECT_EQ(STUN_BINDING_REQUEST, msg->type()); + const StunUInt64Attribute* ice_controlled_attr = + msg->GetUInt64(STUN_ATTR_ICE_CONTROLLED); + ASSERT_TRUE(ice_controlled_attr != NULL); + EXPECT_EQ(rport->IceTiebreaker(), ice_controlled_attr->value()); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_USE_CANDIDATE) == NULL); + + // Request should include ping count. + const StunUInt32Attribute* retransmit_attr = + msg->GetUInt32(STUN_ATTR_RETRANSMIT_COUNT); + ASSERT_TRUE(retransmit_attr != NULL); + EXPECT_EQ(2U, retransmit_attr->value()); + + // Respond with a BINDING-RESPONSE. + request.reset(CopyStunMessage(msg)); + lport->SendBindingResponse(request.get(), rport->Candidates()[0].address()); + msg = lport->last_stun_msg(); + + // Response should include same ping count. + retransmit_attr = msg->GetUInt32(STUN_ATTR_RETRANSMIT_COUNT); + ASSERT_TRUE(retransmit_attr != NULL); + EXPECT_EQ(2U, retransmit_attr->value()); +} + +TEST_F(PortTest, TestUseCandidateAttribute) { + rtc::scoped_ptr<TestPort> lport( + CreateTestPort(kLocalAddr1, "lfrag", "lpass")); + rtc::scoped_ptr<TestPort> rport( + CreateTestPort(kLocalAddr2, "rfrag", "rpass")); + lport->SetIceRole(cricket::ICEROLE_CONTROLLING); + lport->SetIceTiebreaker(kTiebreaker1); + rport->SetIceRole(cricket::ICEROLE_CONTROLLED); + rport->SetIceTiebreaker(kTiebreaker2); + + // Send a fake ping from lport to rport. + lport->PrepareAddress(); + rport->PrepareAddress(); + ASSERT_FALSE(rport->Candidates().empty()); + Connection* lconn = lport->CreateConnection( + rport->Candidates()[0], Port::ORIGIN_MESSAGE); + lconn->Ping(0); + ASSERT_TRUE_WAIT(lport->last_stun_msg() != NULL, 1000); + IceMessage* msg = lport->last_stun_msg(); + const StunUInt64Attribute* ice_controlling_attr = + msg->GetUInt64(STUN_ATTR_ICE_CONTROLLING); + ASSERT_TRUE(ice_controlling_attr != NULL); + const StunByteStringAttribute* use_candidate_attr = msg->GetByteString( + STUN_ATTR_USE_CANDIDATE); + ASSERT_TRUE(use_candidate_attr != NULL); +} + +// Test handling STUN messages. +TEST_F(PortTest, TestHandleStunMessage) { + // Our port will act as the "remote" port. + rtc::scoped_ptr<TestPort> port( + CreateTestPort(kLocalAddr2, "rfrag", "rpass")); + + rtc::scoped_ptr<IceMessage> in_msg, out_msg; + rtc::scoped_ptr<ByteBuffer> buf(new ByteBuffer()); + rtc::SocketAddress addr(kLocalAddr1); + std::string username; + + // BINDING-REQUEST from local to remote with valid ICE username, + // MESSAGE-INTEGRITY, and FINGERPRINT. + in_msg.reset(CreateStunMessageWithUsername(STUN_BINDING_REQUEST, + "rfrag:lfrag")); + in_msg->AddMessageIntegrity("rpass"); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() != NULL); + EXPECT_EQ("lfrag", username); + + // BINDING-RESPONSE without username, with MESSAGE-INTEGRITY and FINGERPRINT. + in_msg.reset(CreateStunMessage(STUN_BINDING_RESPONSE)); + in_msg->AddAttribute( + new StunXorAddressAttribute(STUN_ATTR_XOR_MAPPED_ADDRESS, kLocalAddr2)); + in_msg->AddMessageIntegrity("rpass"); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() != NULL); + EXPECT_EQ("", username); + + // BINDING-ERROR-RESPONSE without username, with error, M-I, and FINGERPRINT. + in_msg.reset(CreateStunMessage(STUN_BINDING_ERROR_RESPONSE)); + in_msg->AddAttribute(new StunErrorCodeAttribute(STUN_ATTR_ERROR_CODE, + STUN_ERROR_SERVER_ERROR, STUN_ERROR_REASON_SERVER_ERROR)); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() != NULL); + EXPECT_EQ("", username); + ASSERT_TRUE(out_msg->GetErrorCode() != NULL); + EXPECT_EQ(STUN_ERROR_SERVER_ERROR, out_msg->GetErrorCode()->code()); + EXPECT_EQ(std::string(STUN_ERROR_REASON_SERVER_ERROR), + out_msg->GetErrorCode()->reason()); +} + +// Tests handling of ICE binding requests with missing or incorrect usernames. +TEST_F(PortTest, TestHandleStunMessageBadUsername) { + rtc::scoped_ptr<TestPort> port( + CreateTestPort(kLocalAddr2, "rfrag", "rpass")); + + rtc::scoped_ptr<IceMessage> in_msg, out_msg; + rtc::scoped_ptr<ByteBuffer> buf(new ByteBuffer()); + rtc::SocketAddress addr(kLocalAddr1); + std::string username; + + // BINDING-REQUEST with no username. + in_msg.reset(CreateStunMessage(STUN_BINDING_REQUEST)); + in_msg->AddMessageIntegrity("rpass"); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() == NULL); + EXPECT_EQ("", username); + EXPECT_EQ(STUN_ERROR_BAD_REQUEST, port->last_stun_error_code()); + + // BINDING-REQUEST with empty username. + in_msg.reset(CreateStunMessageWithUsername(STUN_BINDING_REQUEST, "")); + in_msg->AddMessageIntegrity("rpass"); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() == NULL); + EXPECT_EQ("", username); + EXPECT_EQ(STUN_ERROR_UNAUTHORIZED, port->last_stun_error_code()); + + // BINDING-REQUEST with too-short username. + in_msg.reset(CreateStunMessageWithUsername(STUN_BINDING_REQUEST, "rfra")); + in_msg->AddMessageIntegrity("rpass"); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() == NULL); + EXPECT_EQ("", username); + EXPECT_EQ(STUN_ERROR_UNAUTHORIZED, port->last_stun_error_code()); + + // BINDING-REQUEST with reversed username. + in_msg.reset(CreateStunMessageWithUsername(STUN_BINDING_REQUEST, + "lfrag:rfrag")); + in_msg->AddMessageIntegrity("rpass"); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() == NULL); + EXPECT_EQ("", username); + EXPECT_EQ(STUN_ERROR_UNAUTHORIZED, port->last_stun_error_code()); + + // BINDING-REQUEST with garbage username. + in_msg.reset(CreateStunMessageWithUsername(STUN_BINDING_REQUEST, + "abcd:efgh")); + in_msg->AddMessageIntegrity("rpass"); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() == NULL); + EXPECT_EQ("", username); + EXPECT_EQ(STUN_ERROR_UNAUTHORIZED, port->last_stun_error_code()); +} + +// Test handling STUN messages with missing or malformed M-I. +TEST_F(PortTest, TestHandleStunMessageBadMessageIntegrity) { + // Our port will act as the "remote" port. + rtc::scoped_ptr<TestPort> port( + CreateTestPort(kLocalAddr2, "rfrag", "rpass")); + + rtc::scoped_ptr<IceMessage> in_msg, out_msg; + rtc::scoped_ptr<ByteBuffer> buf(new ByteBuffer()); + rtc::SocketAddress addr(kLocalAddr1); + std::string username; + + // BINDING-REQUEST from local to remote with valid ICE username and + // FINGERPRINT, but no MESSAGE-INTEGRITY. + in_msg.reset(CreateStunMessageWithUsername(STUN_BINDING_REQUEST, + "rfrag:lfrag")); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() == NULL); + EXPECT_EQ("", username); + EXPECT_EQ(STUN_ERROR_BAD_REQUEST, port->last_stun_error_code()); + + // BINDING-REQUEST from local to remote with valid ICE username and + // FINGERPRINT, but invalid MESSAGE-INTEGRITY. + in_msg.reset(CreateStunMessageWithUsername(STUN_BINDING_REQUEST, + "rfrag:lfrag")); + in_msg->AddMessageIntegrity("invalid"); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() == NULL); + EXPECT_EQ("", username); + EXPECT_EQ(STUN_ERROR_UNAUTHORIZED, port->last_stun_error_code()); + + // TODO: BINDING-RESPONSES and BINDING-ERROR-RESPONSES are checked + // by the Connection, not the Port, since they require the remote username. + // Change this test to pass in data via Connection::OnReadPacket instead. +} + +// Test handling STUN messages with missing or malformed FINGERPRINT. +TEST_F(PortTest, TestHandleStunMessageBadFingerprint) { + // Our port will act as the "remote" port. + rtc::scoped_ptr<TestPort> port( + CreateTestPort(kLocalAddr2, "rfrag", "rpass")); + + rtc::scoped_ptr<IceMessage> in_msg, out_msg; + rtc::scoped_ptr<ByteBuffer> buf(new ByteBuffer()); + rtc::SocketAddress addr(kLocalAddr1); + std::string username; + + // BINDING-REQUEST from local to remote with valid ICE username and + // MESSAGE-INTEGRITY, but no FINGERPRINT; GetStunMessage should fail. + in_msg.reset(CreateStunMessageWithUsername(STUN_BINDING_REQUEST, + "rfrag:lfrag")); + in_msg->AddMessageIntegrity("rpass"); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_FALSE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_EQ(0, port->last_stun_error_code()); + + // Now, add a fingerprint, but munge the message so it's not valid. + in_msg->AddFingerprint(); + in_msg->SetTransactionID("TESTTESTBADD"); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_FALSE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_EQ(0, port->last_stun_error_code()); + + // Valid BINDING-RESPONSE, except no FINGERPRINT. + in_msg.reset(CreateStunMessage(STUN_BINDING_RESPONSE)); + in_msg->AddAttribute( + new StunXorAddressAttribute(STUN_ATTR_XOR_MAPPED_ADDRESS, kLocalAddr2)); + in_msg->AddMessageIntegrity("rpass"); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_FALSE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_EQ(0, port->last_stun_error_code()); + + // Now, add a fingerprint, but munge the message so it's not valid. + in_msg->AddFingerprint(); + in_msg->SetTransactionID("TESTTESTBADD"); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_FALSE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_EQ(0, port->last_stun_error_code()); + + // Valid BINDING-ERROR-RESPONSE, except no FINGERPRINT. + in_msg.reset(CreateStunMessage(STUN_BINDING_ERROR_RESPONSE)); + in_msg->AddAttribute(new StunErrorCodeAttribute(STUN_ATTR_ERROR_CODE, + STUN_ERROR_SERVER_ERROR, STUN_ERROR_REASON_SERVER_ERROR)); + in_msg->AddMessageIntegrity("rpass"); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_FALSE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_EQ(0, port->last_stun_error_code()); + + // Now, add a fingerprint, but munge the message so it's not valid. + in_msg->AddFingerprint(); + in_msg->SetTransactionID("TESTTESTBADD"); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_FALSE(port->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_EQ(0, port->last_stun_error_code()); +} + +// Test handling of STUN binding indication messages . STUN binding +// indications are allowed only to the connection which is in read mode. +TEST_F(PortTest, TestHandleStunBindingIndication) { + rtc::scoped_ptr<TestPort> lport( + CreateTestPort(kLocalAddr2, "lfrag", "lpass")); + lport->SetIceRole(cricket::ICEROLE_CONTROLLING); + lport->SetIceTiebreaker(kTiebreaker1); + + // Verifying encoding and decoding STUN indication message. + rtc::scoped_ptr<IceMessage> in_msg, out_msg; + rtc::scoped_ptr<ByteBuffer> buf(new ByteBuffer()); + rtc::SocketAddress addr(kLocalAddr1); + std::string username; + + in_msg.reset(CreateStunMessage(STUN_BINDING_INDICATION)); + in_msg->AddFingerprint(); + WriteStunMessage(in_msg.get(), buf.get()); + EXPECT_TRUE(lport->GetStunMessage(buf->Data(), buf->Length(), addr, + out_msg.accept(), &username)); + EXPECT_TRUE(out_msg.get() != NULL); + EXPECT_EQ(out_msg->type(), STUN_BINDING_INDICATION); + EXPECT_EQ("", username); + + // Verify connection can handle STUN indication and updates + // last_ping_received. + rtc::scoped_ptr<TestPort> rport( + CreateTestPort(kLocalAddr2, "rfrag", "rpass")); + rport->SetIceRole(cricket::ICEROLE_CONTROLLED); + rport->SetIceTiebreaker(kTiebreaker2); + + lport->PrepareAddress(); + rport->PrepareAddress(); + ASSERT_FALSE(lport->Candidates().empty()); + ASSERT_FALSE(rport->Candidates().empty()); + + Connection* lconn = lport->CreateConnection(rport->Candidates()[0], + Port::ORIGIN_MESSAGE); + Connection* rconn = rport->CreateConnection(lport->Candidates()[0], + Port::ORIGIN_MESSAGE); + rconn->Ping(0); + + ASSERT_TRUE_WAIT(rport->last_stun_msg() != NULL, 1000); + IceMessage* msg = rport->last_stun_msg(); + EXPECT_EQ(STUN_BINDING_REQUEST, msg->type()); + // Send rport binding request to lport. + lconn->OnReadPacket(rport->last_stun_buf()->Data(), + rport->last_stun_buf()->Length(), + rtc::PacketTime()); + ASSERT_TRUE_WAIT(lport->last_stun_msg() != NULL, 1000); + EXPECT_EQ(STUN_BINDING_RESPONSE, lport->last_stun_msg()->type()); + uint32_t last_ping_received1 = lconn->last_ping_received(); + + // Adding a delay of 100ms. + rtc::Thread::Current()->ProcessMessages(100); + // Pinging lconn using stun indication message. + lconn->OnReadPacket(buf->Data(), buf->Length(), rtc::PacketTime()); + uint32_t last_ping_received2 = lconn->last_ping_received(); + EXPECT_GT(last_ping_received2, last_ping_received1); +} + +TEST_F(PortTest, TestComputeCandidatePriority) { + rtc::scoped_ptr<TestPort> port( + CreateTestPort(kLocalAddr1, "name", "pass")); + port->set_type_preference(90); + port->set_component(177); + port->AddCandidateAddress(SocketAddress("192.168.1.4", 1234)); + port->AddCandidateAddress(SocketAddress("2001:db8::1234", 1234)); + port->AddCandidateAddress(SocketAddress("fc12:3456::1234", 1234)); + port->AddCandidateAddress(SocketAddress("::ffff:192.168.1.4", 1234)); + port->AddCandidateAddress(SocketAddress("::192.168.1.4", 1234)); + port->AddCandidateAddress(SocketAddress("2002::1234:5678", 1234)); + port->AddCandidateAddress(SocketAddress("2001::1234:5678", 1234)); + port->AddCandidateAddress(SocketAddress("fecf::1234:5678", 1234)); + port->AddCandidateAddress(SocketAddress("3ffe::1234:5678", 1234)); + // These should all be: + // (90 << 24) | ([rfc3484 pref value] << 8) | (256 - 177) + uint32_t expected_priority_v4 = 1509957199U; + uint32_t expected_priority_v6 = 1509959759U; + uint32_t expected_priority_ula = 1509962319U; + uint32_t expected_priority_v4mapped = expected_priority_v4; + uint32_t expected_priority_v4compat = 1509949775U; + uint32_t expected_priority_6to4 = 1509954639U; + uint32_t expected_priority_teredo = 1509952079U; + uint32_t expected_priority_sitelocal = 1509949775U; + uint32_t expected_priority_6bone = 1509949775U; + ASSERT_EQ(expected_priority_v4, port->Candidates()[0].priority()); + ASSERT_EQ(expected_priority_v6, port->Candidates()[1].priority()); + ASSERT_EQ(expected_priority_ula, port->Candidates()[2].priority()); + ASSERT_EQ(expected_priority_v4mapped, port->Candidates()[3].priority()); + ASSERT_EQ(expected_priority_v4compat, port->Candidates()[4].priority()); + ASSERT_EQ(expected_priority_6to4, port->Candidates()[5].priority()); + ASSERT_EQ(expected_priority_teredo, port->Candidates()[6].priority()); + ASSERT_EQ(expected_priority_sitelocal, port->Candidates()[7].priority()); + ASSERT_EQ(expected_priority_6bone, port->Candidates()[8].priority()); +} + +// In the case of shared socket, one port may be shared by local and stun. +// Test that candidates with different types will have different foundation. +TEST_F(PortTest, TestFoundation) { + rtc::scoped_ptr<TestPort> testport( + CreateTestPort(kLocalAddr1, "name", "pass")); + testport->AddCandidateAddress(kLocalAddr1, kLocalAddr1, + LOCAL_PORT_TYPE, + cricket::ICE_TYPE_PREFERENCE_HOST, false); + testport->AddCandidateAddress(kLocalAddr2, kLocalAddr1, + STUN_PORT_TYPE, + cricket::ICE_TYPE_PREFERENCE_SRFLX, true); + EXPECT_NE(testport->Candidates()[0].foundation(), + testport->Candidates()[1].foundation()); +} + +// This test verifies the foundation of different types of ICE candidates. +TEST_F(PortTest, TestCandidateFoundation) { + rtc::scoped_ptr<rtc::NATServer> nat_server( + CreateNatServer(kNatAddr1, NAT_OPEN_CONE)); + rtc::scoped_ptr<UDPPort> udpport1(CreateUdpPort(kLocalAddr1)); + udpport1->PrepareAddress(); + rtc::scoped_ptr<UDPPort> udpport2(CreateUdpPort(kLocalAddr1)); + udpport2->PrepareAddress(); + EXPECT_EQ(udpport1->Candidates()[0].foundation(), + udpport2->Candidates()[0].foundation()); + rtc::scoped_ptr<TCPPort> tcpport1(CreateTcpPort(kLocalAddr1)); + tcpport1->PrepareAddress(); + rtc::scoped_ptr<TCPPort> tcpport2(CreateTcpPort(kLocalAddr1)); + tcpport2->PrepareAddress(); + EXPECT_EQ(tcpport1->Candidates()[0].foundation(), + tcpport2->Candidates()[0].foundation()); + rtc::scoped_ptr<Port> stunport( + CreateStunPort(kLocalAddr1, nat_socket_factory1())); + stunport->PrepareAddress(); + ASSERT_EQ_WAIT(1U, stunport->Candidates().size(), kTimeout); + EXPECT_NE(tcpport1->Candidates()[0].foundation(), + stunport->Candidates()[0].foundation()); + EXPECT_NE(tcpport2->Candidates()[0].foundation(), + stunport->Candidates()[0].foundation()); + EXPECT_NE(udpport1->Candidates()[0].foundation(), + stunport->Candidates()[0].foundation()); + EXPECT_NE(udpport2->Candidates()[0].foundation(), + stunport->Candidates()[0].foundation()); + // Verify GTURN candidate foundation. + rtc::scoped_ptr<RelayPort> relayport( + CreateGturnPort(kLocalAddr1)); + relayport->AddServerAddress( + cricket::ProtocolAddress(kRelayUdpIntAddr, cricket::PROTO_UDP)); + relayport->PrepareAddress(); + ASSERT_EQ_WAIT(1U, relayport->Candidates().size(), kTimeout); + EXPECT_NE(udpport1->Candidates()[0].foundation(), + relayport->Candidates()[0].foundation()); + EXPECT_NE(udpport2->Candidates()[0].foundation(), + relayport->Candidates()[0].foundation()); + // Verifying TURN candidate foundation. + rtc::scoped_ptr<Port> turnport1(CreateTurnPort( + kLocalAddr1, nat_socket_factory1(), PROTO_UDP, PROTO_UDP)); + turnport1->PrepareAddress(); + ASSERT_EQ_WAIT(1U, turnport1->Candidates().size(), kTimeout); + EXPECT_NE(udpport1->Candidates()[0].foundation(), + turnport1->Candidates()[0].foundation()); + EXPECT_NE(udpport2->Candidates()[0].foundation(), + turnport1->Candidates()[0].foundation()); + EXPECT_NE(stunport->Candidates()[0].foundation(), + turnport1->Candidates()[0].foundation()); + rtc::scoped_ptr<Port> turnport2(CreateTurnPort( + kLocalAddr1, nat_socket_factory1(), PROTO_UDP, PROTO_UDP)); + turnport2->PrepareAddress(); + ASSERT_EQ_WAIT(1U, turnport2->Candidates().size(), kTimeout); + EXPECT_EQ(turnport1->Candidates()[0].foundation(), + turnport2->Candidates()[0].foundation()); + + // Running a second turn server, to get different base IP address. + SocketAddress kTurnUdpIntAddr2("99.99.98.4", STUN_SERVER_PORT); + SocketAddress kTurnUdpExtAddr2("99.99.98.5", 0); + TestTurnServer turn_server2( + rtc::Thread::Current(), kTurnUdpIntAddr2, kTurnUdpExtAddr2); + rtc::scoped_ptr<Port> turnport3(CreateTurnPort( + kLocalAddr1, nat_socket_factory1(), PROTO_UDP, PROTO_UDP, + kTurnUdpIntAddr2)); + turnport3->PrepareAddress(); + ASSERT_EQ_WAIT(1U, turnport3->Candidates().size(), kTimeout); + EXPECT_NE(turnport3->Candidates()[0].foundation(), + turnport2->Candidates()[0].foundation()); +} + +// This test verifies the related addresses of different types of +// ICE candiates. +TEST_F(PortTest, TestCandidateRelatedAddress) { + rtc::scoped_ptr<rtc::NATServer> nat_server( + CreateNatServer(kNatAddr1, NAT_OPEN_CONE)); + rtc::scoped_ptr<UDPPort> udpport(CreateUdpPort(kLocalAddr1)); + udpport->PrepareAddress(); + // For UDPPort, related address will be empty. + EXPECT_TRUE(udpport->Candidates()[0].related_address().IsNil()); + // Testing related address for stun candidates. + // For stun candidate related address must be equal to the base + // socket address. + rtc::scoped_ptr<StunPort> stunport( + CreateStunPort(kLocalAddr1, nat_socket_factory1())); + stunport->PrepareAddress(); + ASSERT_EQ_WAIT(1U, stunport->Candidates().size(), kTimeout); + // Check STUN candidate address. + EXPECT_EQ(stunport->Candidates()[0].address().ipaddr(), + kNatAddr1.ipaddr()); + // Check STUN candidate related address. + EXPECT_EQ(stunport->Candidates()[0].related_address(), + stunport->GetLocalAddress()); + // Verifying the related address for the GTURN candidates. + // NOTE: In case of GTURN related address will be equal to the mapped + // address, but address(mapped) will not be XOR. + rtc::scoped_ptr<RelayPort> relayport( + CreateGturnPort(kLocalAddr1)); + relayport->AddServerAddress( + cricket::ProtocolAddress(kRelayUdpIntAddr, cricket::PROTO_UDP)); + relayport->PrepareAddress(); + ASSERT_EQ_WAIT(1U, relayport->Candidates().size(), kTimeout); + // For Gturn related address is set to "0.0.0.0:0" + EXPECT_EQ(rtc::SocketAddress(), + relayport->Candidates()[0].related_address()); + // Verifying the related address for TURN candidate. + // For TURN related address must be equal to the mapped address. + rtc::scoped_ptr<Port> turnport(CreateTurnPort( + kLocalAddr1, nat_socket_factory1(), PROTO_UDP, PROTO_UDP)); + turnport->PrepareAddress(); + ASSERT_EQ_WAIT(1U, turnport->Candidates().size(), kTimeout); + EXPECT_EQ(kTurnUdpExtAddr.ipaddr(), + turnport->Candidates()[0].address().ipaddr()); + EXPECT_EQ(kNatAddr1.ipaddr(), + turnport->Candidates()[0].related_address().ipaddr()); +} + +// Test priority value overflow handling when preference is set to 3. +TEST_F(PortTest, TestCandidatePriority) { + cricket::Candidate cand1; + cand1.set_priority(3); + cricket::Candidate cand2; + cand2.set_priority(1); + EXPECT_TRUE(cand1.priority() > cand2.priority()); +} + +// Test the Connection priority is calculated correctly. +TEST_F(PortTest, TestConnectionPriority) { + rtc::scoped_ptr<TestPort> lport( + CreateTestPort(kLocalAddr1, "lfrag", "lpass")); + lport->set_type_preference(cricket::ICE_TYPE_PREFERENCE_HOST); + rtc::scoped_ptr<TestPort> rport( + CreateTestPort(kLocalAddr2, "rfrag", "rpass")); + rport->set_type_preference(cricket::ICE_TYPE_PREFERENCE_RELAY); + lport->set_component(123); + lport->AddCandidateAddress(SocketAddress("192.168.1.4", 1234)); + rport->set_component(23); + rport->AddCandidateAddress(SocketAddress("10.1.1.100", 1234)); + + EXPECT_EQ(0x7E001E85U, lport->Candidates()[0].priority()); + EXPECT_EQ(0x2001EE9U, rport->Candidates()[0].priority()); + + // RFC 5245 + // pair priority = 2^32*MIN(G,D) + 2*MAX(G,D) + (G>D?1:0) + lport->SetIceRole(cricket::ICEROLE_CONTROLLING); + rport->SetIceRole(cricket::ICEROLE_CONTROLLED); + Connection* lconn = lport->CreateConnection( + rport->Candidates()[0], Port::ORIGIN_MESSAGE); +#if defined(WEBRTC_WIN) + EXPECT_EQ(0x2001EE9FC003D0BU, lconn->priority()); +#else + EXPECT_EQ(0x2001EE9FC003D0BLLU, lconn->priority()); +#endif + + lport->SetIceRole(cricket::ICEROLE_CONTROLLED); + rport->SetIceRole(cricket::ICEROLE_CONTROLLING); + Connection* rconn = rport->CreateConnection( + lport->Candidates()[0], Port::ORIGIN_MESSAGE); +#if defined(WEBRTC_WIN) + EXPECT_EQ(0x2001EE9FC003D0AU, rconn->priority()); +#else + EXPECT_EQ(0x2001EE9FC003D0ALLU, rconn->priority()); +#endif +} + +TEST_F(PortTest, TestWritableState) { + UDPPort* port1 = CreateUdpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + UDPPort* port2 = CreateUdpPort(kLocalAddr2); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + + // Set up channels. + TestChannel ch1(port1); + TestChannel ch2(port2); + + // Acquire addresses. + ch1.Start(); + ch2.Start(); + ASSERT_EQ_WAIT(1, ch1.complete_count(), kTimeout); + ASSERT_EQ_WAIT(1, ch2.complete_count(), kTimeout); + + // Send a ping from src to dst. + ch1.CreateConnection(GetCandidate(port2)); + ASSERT_TRUE(ch1.conn() != NULL); + EXPECT_EQ(Connection::STATE_WRITE_INIT, ch1.conn()->write_state()); + EXPECT_TRUE_WAIT(ch1.conn()->connected(), kTimeout); // for TCP connect + ch1.Ping(); + WAIT(!ch2.remote_address().IsNil(), kTimeout); + + // Data should be unsendable until the connection is accepted. + char data[] = "abcd"; + int data_size = ARRAY_SIZE(data); + rtc::PacketOptions options; + EXPECT_EQ(SOCKET_ERROR, ch1.conn()->Send(data, data_size, options)); + + // Accept the connection to return the binding response, transition to + // writable, and allow data to be sent. + ch2.AcceptConnection(GetCandidate(port1)); + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, ch1.conn()->write_state(), + kTimeout); + EXPECT_EQ(data_size, ch1.conn()->Send(data, data_size, options)); + + // Ask the connection to update state as if enough time has passed to lose + // full writability and 5 pings went unresponded to. We'll accomplish the + // latter by sending pings but not pumping messages. + for (uint32_t i = 1; i <= CONNECTION_WRITE_CONNECT_FAILURES; ++i) { + ch1.Ping(i); + } + uint32_t unreliable_timeout_delay = CONNECTION_WRITE_CONNECT_TIMEOUT + 500u; + ch1.conn()->UpdateState(unreliable_timeout_delay); + EXPECT_EQ(Connection::STATE_WRITE_UNRELIABLE, ch1.conn()->write_state()); + + // Data should be able to be sent in this state. + EXPECT_EQ(data_size, ch1.conn()->Send(data, data_size, options)); + + // And now allow the other side to process the pings and send binding + // responses. + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, ch1.conn()->write_state(), + kTimeout); + + // Wait long enough for a full timeout (past however long we've already + // waited). + for (uint32_t i = 1; i <= CONNECTION_WRITE_CONNECT_FAILURES; ++i) { + ch1.Ping(unreliable_timeout_delay + i); + } + ch1.conn()->UpdateState(unreliable_timeout_delay + CONNECTION_WRITE_TIMEOUT + + 500u); + EXPECT_EQ(Connection::STATE_WRITE_TIMEOUT, ch1.conn()->write_state()); + + // Now that the connection has completely timed out, data send should fail. + EXPECT_EQ(SOCKET_ERROR, ch1.conn()->Send(data, data_size, options)); + + ch1.Stop(); + ch2.Stop(); +} + +TEST_F(PortTest, TestTimeoutForNeverWritable) { + UDPPort* port1 = CreateUdpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + UDPPort* port2 = CreateUdpPort(kLocalAddr2); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + + // Set up channels. + TestChannel ch1(port1); + TestChannel ch2(port2); + + // Acquire addresses. + ch1.Start(); + ch2.Start(); + + ch1.CreateConnection(GetCandidate(port2)); + ASSERT_TRUE(ch1.conn() != NULL); + EXPECT_EQ(Connection::STATE_WRITE_INIT, ch1.conn()->write_state()); + + // Attempt to go directly to write timeout. + for (uint32_t i = 1; i <= CONNECTION_WRITE_CONNECT_FAILURES; ++i) { + ch1.Ping(i); + } + ch1.conn()->UpdateState(CONNECTION_WRITE_TIMEOUT + 500u); + EXPECT_EQ(Connection::STATE_WRITE_TIMEOUT, ch1.conn()->write_state()); +} + +// This test verifies the connection setup between ICEMODE_FULL +// and ICEMODE_LITE. +// In this test |ch1| behaves like FULL mode client and we have created +// port which responds to the ping message just like LITE client. +TEST_F(PortTest, TestIceLiteConnectivity) { + TestPort* ice_full_port = CreateTestPort( + kLocalAddr1, "lfrag", "lpass", + cricket::ICEROLE_CONTROLLING, kTiebreaker1); + + rtc::scoped_ptr<TestPort> ice_lite_port(CreateTestPort( + kLocalAddr2, "rfrag", "rpass", + cricket::ICEROLE_CONTROLLED, kTiebreaker2)); + // Setup TestChannel. This behaves like FULL mode client. + TestChannel ch1(ice_full_port); + ch1.SetIceMode(ICEMODE_FULL); + + // Start gathering candidates. + ch1.Start(); + ice_lite_port->PrepareAddress(); + + ASSERT_EQ_WAIT(1, ch1.complete_count(), kTimeout); + ASSERT_FALSE(ice_lite_port->Candidates().empty()); + + ch1.CreateConnection(GetCandidate(ice_lite_port.get())); + ASSERT_TRUE(ch1.conn() != NULL); + EXPECT_EQ(Connection::STATE_WRITE_INIT, ch1.conn()->write_state()); + + // Send ping from full mode client. + // This ping must not have USE_CANDIDATE_ATTR. + ch1.Ping(); + + // Verify stun ping is without USE_CANDIDATE_ATTR. Getting message directly + // from port. + ASSERT_TRUE_WAIT(ice_full_port->last_stun_msg() != NULL, 1000); + IceMessage* msg = ice_full_port->last_stun_msg(); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_USE_CANDIDATE) == NULL); + + // Respond with a BINDING-RESPONSE from litemode client. + // NOTE: Ideally we should't create connection at this stage from lite + // port, as it should be done only after receiving ping with USE_CANDIDATE. + // But we need a connection to send a response message. + ice_lite_port->CreateConnection( + ice_full_port->Candidates()[0], cricket::Port::ORIGIN_MESSAGE); + rtc::scoped_ptr<IceMessage> request(CopyStunMessage(msg)); + ice_lite_port->SendBindingResponse( + request.get(), ice_full_port->Candidates()[0].address()); + + // Feeding the respone message from litemode to the full mode connection. + ch1.conn()->OnReadPacket(ice_lite_port->last_stun_buf()->Data(), + ice_lite_port->last_stun_buf()->Length(), + rtc::PacketTime()); + // Verifying full mode connection becomes writable from the response. + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, ch1.conn()->write_state(), + kTimeout); + EXPECT_TRUE_WAIT(ch1.nominated(), kTimeout); + + // Clear existing stun messsages. Otherwise we will process old stun + // message right after we send ping. + ice_full_port->Reset(); + // Send ping. This must have USE_CANDIDATE_ATTR. + ch1.Ping(); + ASSERT_TRUE_WAIT(ice_full_port->last_stun_msg() != NULL, 1000); + msg = ice_full_port->last_stun_msg(); + EXPECT_TRUE(msg->GetByteString(STUN_ATTR_USE_CANDIDATE) != NULL); + ch1.Stop(); +} + +// This test case verifies that the CONTROLLING port does not time out. +TEST_F(PortTest, TestControllingNoTimeout) { + UDPPort* port1 = CreateUdpPort(kLocalAddr1); + ConnectToSignalDestroyed(port1); + port1->set_timeout_delay(10); // milliseconds + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + port1->SetIceTiebreaker(kTiebreaker1); + + UDPPort* port2 = CreateUdpPort(kLocalAddr2); + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + port2->SetIceTiebreaker(kTiebreaker2); + + // Set up channels and ensure both ports will be deleted. + TestChannel ch1(port1); + TestChannel ch2(port2); + + // Simulate a connection that succeeds, and then is destroyed. + StartConnectAndStopChannels(&ch1, &ch2); + + // After the connection is destroyed, the port should not be destroyed. + rtc::Thread::Current()->ProcessMessages(kTimeout); + EXPECT_FALSE(destroyed()); +} + +// This test case verifies that the CONTROLLED port does time out, but only +// after connectivity is lost. +TEST_F(PortTest, TestControlledTimeout) { + UDPPort* port1 = CreateUdpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + port1->SetIceTiebreaker(kTiebreaker1); + + UDPPort* port2 = CreateUdpPort(kLocalAddr2); + ConnectToSignalDestroyed(port2); + port2->set_timeout_delay(10); // milliseconds + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + port2->SetIceTiebreaker(kTiebreaker2); + + // The connection must not be destroyed before a connection is attempted. + EXPECT_FALSE(destroyed()); + + port1->set_component(cricket::ICE_CANDIDATE_COMPONENT_DEFAULT); + port2->set_component(cricket::ICE_CANDIDATE_COMPONENT_DEFAULT); + + // Set up channels and ensure both ports will be deleted. + TestChannel ch1(port1); + TestChannel ch2(port2); + + // Simulate a connection that succeeds, and then is destroyed. + StartConnectAndStopChannels(&ch1, &ch2); + + // The controlled port should be destroyed after 10 milliseconds. + EXPECT_TRUE_WAIT(destroyed(), kTimeout); +} + +// This test case verifies that if the role of a port changes from controlled +// to controlling after all connections fail, the port will not be destroyed. +TEST_F(PortTest, TestControlledToControllingNotDestroyed) { + UDPPort* port1 = CreateUdpPort(kLocalAddr1); + port1->SetIceRole(cricket::ICEROLE_CONTROLLING); + port1->SetIceTiebreaker(kTiebreaker1); + + UDPPort* port2 = CreateUdpPort(kLocalAddr2); + ConnectToSignalDestroyed(port2); + port2->set_timeout_delay(10); // milliseconds + port2->SetIceRole(cricket::ICEROLE_CONTROLLED); + port2->SetIceTiebreaker(kTiebreaker2); + + // The connection must not be destroyed before a connection is attempted. + EXPECT_FALSE(destroyed()); + + port1->set_component(cricket::ICE_CANDIDATE_COMPONENT_DEFAULT); + port2->set_component(cricket::ICE_CANDIDATE_COMPONENT_DEFAULT); + + // Set up channels and ensure both ports will be deleted. + TestChannel ch1(port1); + TestChannel ch2(port2); + + // Simulate a connection that succeeds, and then is destroyed. + StartConnectAndStopChannels(&ch1, &ch2); + // Switch the role after all connections are destroyed. + EXPECT_TRUE_WAIT(ch2.conn() == nullptr, kTimeout); + port1->SetIceRole(cricket::ICEROLE_CONTROLLED); + port2->SetIceRole(cricket::ICEROLE_CONTROLLING); + + // After the connection is destroyed, the port should not be destroyed. + rtc::Thread::Current()->ProcessMessages(kTimeout); + EXPECT_FALSE(destroyed()); +} diff --git a/webrtc/p2p/base/portallocator.cc b/webrtc/p2p/base/portallocator.cc new file mode 100644 index 0000000000..5c4243abf6 --- /dev/null +++ b/webrtc/p2p/base/portallocator.cc @@ -0,0 +1,40 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/base/checks.h" +#include "webrtc/p2p/base/portallocator.h" + +namespace cricket { + +PortAllocatorSession::PortAllocatorSession(const std::string& content_name, + int component, + const std::string& ice_ufrag, + const std::string& ice_pwd, + uint32_t flags) + : content_name_(content_name), + component_(component), + flags_(flags), + generation_(0), + ice_ufrag_(ice_ufrag), + ice_pwd_(ice_pwd) { + RTC_DCHECK(!ice_ufrag.empty()); + RTC_DCHECK(!ice_pwd.empty()); +} + +PortAllocatorSession* PortAllocator::CreateSession( + const std::string& sid, + const std::string& content_name, + int component, + const std::string& ice_ufrag, + const std::string& ice_pwd) { + return CreateSessionInternal(content_name, component, ice_ufrag, ice_pwd); +} + +} // namespace cricket diff --git a/webrtc/p2p/base/portallocator.h b/webrtc/p2p/base/portallocator.h new file mode 100644 index 0000000000..4f8ec2fbe6 --- /dev/null +++ b/webrtc/p2p/base/portallocator.h @@ -0,0 +1,209 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_PORTALLOCATOR_H_ +#define WEBRTC_P2P_BASE_PORTALLOCATOR_H_ + +#include <string> +#include <vector> + +#include "webrtc/p2p/base/portinterface.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/proxyinfo.h" +#include "webrtc/base/sigslot.h" + +namespace cricket { + +// PortAllocator is responsible for allocating Port types for a given +// P2PSocket. It also handles port freeing. +// +// Clients can override this class to control port allocation, including +// what kinds of ports are allocated. + +enum { + // Disable local UDP ports. This doesn't impact how we connect to relay + // servers. + PORTALLOCATOR_DISABLE_UDP = 0x01, + PORTALLOCATOR_DISABLE_STUN = 0x02, + PORTALLOCATOR_DISABLE_RELAY = 0x04, + // Disable local TCP ports. This doesn't impact how we connect to relay + // servers. + PORTALLOCATOR_DISABLE_TCP = 0x08, + PORTALLOCATOR_ENABLE_SHAKER = 0x10, + PORTALLOCATOR_ENABLE_IPV6 = 0x40, + // TODO(pthatcher): Remove this once it's no longer used in: + // remoting/client/plugin/pepper_port_allocator.cc + // remoting/protocol/chromium_port_allocator.cc + // remoting/test/fake_port_allocator.cc + // It's a no-op and is no longer needed. + PORTALLOCATOR_ENABLE_SHARED_UFRAG = 0x80, + PORTALLOCATOR_ENABLE_SHARED_SOCKET = 0x100, + PORTALLOCATOR_ENABLE_STUN_RETRANSMIT_ATTRIBUTE = 0x200, + PORTALLOCATOR_DISABLE_ADAPTER_ENUMERATION = 0x400, + // When specified, a loopback candidate will be generated if + // PORTALLOCATOR_DISABLE_ADAPTER_ENUMERATION is specified. + PORTALLOCATOR_ENABLE_LOCALHOST_CANDIDATE = 0x800, + // Disallow use of UDP when connecting to a relay server. Since proxy servers + // usually don't handle UDP, using UDP will leak the IP address. + PORTALLOCATOR_DISABLE_UDP_RELAY = 0x1000, +}; + +const uint32_t kDefaultPortAllocatorFlags = 0; + +const uint32_t kDefaultStepDelay = 1000; // 1 sec step delay. +// As per RFC 5245 Appendix B.1, STUN transactions need to be paced at certain +// internal. Less than 20ms is not acceptable. We choose 50ms as our default. +const uint32_t kMinimumStepDelay = 50; + +// CF = CANDIDATE FILTER +enum { + CF_NONE = 0x0, + CF_HOST = 0x1, + CF_REFLEXIVE = 0x2, + CF_RELAY = 0x4, + CF_ALL = 0x7, +}; + +class PortAllocatorSession : public sigslot::has_slots<> { + public: + // Content name passed in mostly for logging and debugging. + PortAllocatorSession(const std::string& content_name, + int component, + const std::string& ice_ufrag, + const std::string& ice_pwd, + uint32_t flags); + + // Subclasses should clean up any ports created. + virtual ~PortAllocatorSession() {} + + uint32_t flags() const { return flags_; } + void set_flags(uint32_t flags) { flags_ = flags; } + std::string content_name() const { return content_name_; } + int component() const { return component_; } + + // Starts gathering STUN and Relay configurations. + virtual void StartGettingPorts() = 0; + virtual void StopGettingPorts() = 0; + // Only stop the existing gathering process but may start new ones if needed. + virtual void ClearGettingPorts() = 0; + // Whether the process of getting ports has been stopped. + virtual bool IsGettingPorts() = 0; + + sigslot::signal2<PortAllocatorSession*, PortInterface*> SignalPortReady; + sigslot::signal2<PortAllocatorSession*, + const std::vector<Candidate>&> SignalCandidatesReady; + sigslot::signal1<PortAllocatorSession*> SignalCandidatesAllocationDone; + + virtual uint32_t generation() { return generation_; } + virtual void set_generation(uint32_t generation) { generation_ = generation; } + sigslot::signal1<PortAllocatorSession*> SignalDestroyed; + + const std::string& ice_ufrag() const { return ice_ufrag_; } + const std::string& ice_pwd() const { return ice_pwd_; } + + protected: + // TODO(deadbeef): Get rid of these when everyone switches to ice_ufrag and + // ice_pwd. + const std::string& username() const { return ice_ufrag_; } + const std::string& password() const { return ice_pwd_; } + + std::string content_name_; + int component_; + + private: + uint32_t flags_; + uint32_t generation_; + std::string ice_ufrag_; + std::string ice_pwd_; +}; + +class PortAllocator : public sigslot::has_slots<> { + public: + PortAllocator() : + flags_(kDefaultPortAllocatorFlags), + min_port_(0), + max_port_(0), + step_delay_(kDefaultStepDelay), + allow_tcp_listen_(true), + candidate_filter_(CF_ALL) { + // This will allow us to have old behavior on non webrtc clients. + } + virtual ~PortAllocator() {} + + PortAllocatorSession* CreateSession( + const std::string& sid, + const std::string& content_name, + int component, + const std::string& ice_ufrag, + const std::string& ice_pwd); + + uint32_t flags() const { return flags_; } + void set_flags(uint32_t flags) { flags_ = flags; } + + const std::string& user_agent() const { return agent_; } + const rtc::ProxyInfo& proxy() const { return proxy_; } + void set_proxy(const std::string& agent, const rtc::ProxyInfo& proxy) { + agent_ = agent; + proxy_ = proxy; + } + + // Gets/Sets the port range to use when choosing client ports. + int min_port() const { return min_port_; } + int max_port() const { return max_port_; } + bool SetPortRange(int min_port, int max_port) { + if (min_port > max_port) { + return false; + } + + min_port_ = min_port; + max_port_ = max_port; + return true; + } + + uint32_t step_delay() const { return step_delay_; } + void set_step_delay(uint32_t delay) { step_delay_ = delay; } + + bool allow_tcp_listen() const { return allow_tcp_listen_; } + void set_allow_tcp_listen(bool allow_tcp_listen) { + allow_tcp_listen_ = allow_tcp_listen; + } + + uint32_t candidate_filter() { return candidate_filter_; } + bool set_candidate_filter(uint32_t filter) { + // TODO(mallinath) - Do transition check? + candidate_filter_ = filter; + return true; + } + + // Gets/Sets the Origin value used for WebRTC STUN requests. + const std::string& origin() const { return origin_; } + void set_origin(const std::string& origin) { origin_ = origin; } + + protected: + virtual PortAllocatorSession* CreateSessionInternal( + const std::string& content_name, + int component, + const std::string& ice_ufrag, + const std::string& ice_pwd) = 0; + + uint32_t flags_; + std::string agent_; + rtc::ProxyInfo proxy_; + int min_port_; + int max_port_; + uint32_t step_delay_; + bool allow_tcp_listen_; + uint32_t candidate_filter_; + std::string origin_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_PORTALLOCATOR_H_ diff --git a/webrtc/p2p/base/portinterface.h b/webrtc/p2p/base/portinterface.h new file mode 100644 index 0000000000..0f77036ac1 --- /dev/null +++ b/webrtc/p2p/base/portinterface.h @@ -0,0 +1,127 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_PORTINTERFACE_H_ +#define WEBRTC_P2P_BASE_PORTINTERFACE_H_ + +#include <string> + +#include "webrtc/p2p/base/transport.h" +#include "webrtc/base/asyncpacketsocket.h" +#include "webrtc/base/socketaddress.h" + +namespace rtc { +class Network; +struct PacketOptions; +} + +namespace cricket { +class Connection; +class IceMessage; +class StunMessage; + +enum ProtocolType { + PROTO_UDP, + PROTO_TCP, + PROTO_SSLTCP, + PROTO_LAST = PROTO_SSLTCP +}; + +// Defines the interface for a port, which represents a local communication +// mechanism that can be used to create connections to similar mechanisms of +// the other client. Various types of ports will implement this interface. +class PortInterface { + public: + virtual ~PortInterface() {} + + virtual const std::string& Type() const = 0; + virtual rtc::Network* Network() const = 0; + + // Methods to set/get ICE role and tiebreaker values. + virtual void SetIceRole(IceRole role) = 0; + virtual IceRole GetIceRole() const = 0; + + virtual void SetIceTiebreaker(uint64_t tiebreaker) = 0; + virtual uint64_t IceTiebreaker() const = 0; + + virtual bool SharedSocket() const = 0; + + // PrepareAddress will attempt to get an address for this port that other + // clients can send to. It may take some time before the address is ready. + // Once it is ready, we will send SignalAddressReady. If errors are + // preventing the port from getting an address, it may send + // SignalAddressError. + virtual void PrepareAddress() = 0; + + // Returns the connection to the given address or NULL if none exists. + virtual Connection* GetConnection( + const rtc::SocketAddress& remote_addr) = 0; + + // Creates a new connection to the given address. + enum CandidateOrigin { ORIGIN_THIS_PORT, ORIGIN_OTHER_PORT, ORIGIN_MESSAGE }; + virtual Connection* CreateConnection( + const Candidate& remote_candidate, CandidateOrigin origin) = 0; + + // Functions on the underlying socket(s). + virtual int SetOption(rtc::Socket::Option opt, int value) = 0; + virtual int GetOption(rtc::Socket::Option opt, int* value) = 0; + virtual int GetError() = 0; + + virtual const std::vector<Candidate>& Candidates() const = 0; + + // Sends the given packet to the given address, provided that the address is + // that of a connection or an address that has sent to us already. + virtual int SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options, bool payload) = 0; + + // Indicates that we received a successful STUN binding request from an + // address that doesn't correspond to any current connection. To turn this + // into a real connection, call CreateConnection. + sigslot::signal6<PortInterface*, const rtc::SocketAddress&, + ProtocolType, IceMessage*, const std::string&, + bool> SignalUnknownAddress; + + // Sends a response message (normal or error) to the given request. One of + // these methods should be called as a response to SignalUnknownAddress. + // NOTE: You MUST call CreateConnection BEFORE SendBindingResponse. + virtual void SendBindingResponse(StunMessage* request, + const rtc::SocketAddress& addr) = 0; + virtual void SendBindingErrorResponse( + StunMessage* request, const rtc::SocketAddress& addr, + int error_code, const std::string& reason) = 0; + + // Signaled when this port decides to delete itself because it no longer has + // any usefulness. + sigslot::signal1<PortInterface*> SignalDestroyed; + + // Signaled when Port discovers ice role conflict with the peer. + sigslot::signal1<PortInterface*> SignalRoleConflict; + + // Normally, packets arrive through a connection (or they result signaling of + // unknown address). Calling this method turns off delivery of packets + // through their respective connection and instead delivers every packet + // through this port. + virtual void EnablePortPackets() = 0; + sigslot::signal4<PortInterface*, const char*, size_t, + const rtc::SocketAddress&> SignalReadPacket; + + // Emitted each time a packet is sent on this port. + sigslot::signal2<PortInterface*, const rtc::SentPacket&> SignalSentPacket; + + virtual std::string ToString() const = 0; + + protected: + PortInterface() {} +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_PORTINTERFACE_H_ diff --git a/webrtc/p2p/base/pseudotcp.cc b/webrtc/p2p/base/pseudotcp.cc new file mode 100644 index 0000000000..5f035ca652 --- /dev/null +++ b/webrtc/p2p/base/pseudotcp.cc @@ -0,0 +1,1282 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/pseudotcp.h" + +#include <stdio.h> +#include <stdlib.h> + +#include <algorithm> +#include <set> + +#include "webrtc/base/basictypes.h" +#include "webrtc/base/bytebuffer.h" +#include "webrtc/base/byteorder.h" +#include "webrtc/base/common.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/socket.h" +#include "webrtc/base/stringutils.h" +#include "webrtc/base/timeutils.h" + +// The following logging is for detailed (packet-level) analysis only. +#define _DBG_NONE 0 +#define _DBG_NORMAL 1 +#define _DBG_VERBOSE 2 +#define _DEBUGMSG _DBG_NONE + +namespace cricket { + +////////////////////////////////////////////////////////////////////// +// Network Constants +////////////////////////////////////////////////////////////////////// + +// Standard MTUs +const uint16_t PACKET_MAXIMUMS[] = { + 65535, // Theoretical maximum, Hyperchannel + 32000, // Nothing + 17914, // 16Mb IBM Token Ring + 8166, // IEEE 802.4 + // 4464, // IEEE 802.5 (4Mb max) + 4352, // FDDI + // 2048, // Wideband Network + 2002, // IEEE 802.5 (4Mb recommended) + // 1536, // Expermental Ethernet Networks + // 1500, // Ethernet, Point-to-Point (default) + 1492, // IEEE 802.3 + 1006, // SLIP, ARPANET + // 576, // X.25 Networks + // 544, // DEC IP Portal + // 512, // NETBIOS + 508, // IEEE 802/Source-Rt Bridge, ARCNET + 296, // Point-to-Point (low delay) + // 68, // Official minimum + 0, // End of list marker +}; + +const uint32_t MAX_PACKET = 65535; +// Note: we removed lowest level because packet overhead was larger! +const uint32_t MIN_PACKET = 296; + +const uint32_t IP_HEADER_SIZE = 20; // (+ up to 40 bytes of options?) +const uint32_t UDP_HEADER_SIZE = 8; +// TODO: Make JINGLE_HEADER_SIZE transparent to this code? +const uint32_t JINGLE_HEADER_SIZE = 64; // when relay framing is in use + +// Default size for receive and send buffer. +const uint32_t DEFAULT_RCV_BUF_SIZE = 60 * 1024; +const uint32_t DEFAULT_SND_BUF_SIZE = 90 * 1024; + +////////////////////////////////////////////////////////////////////// +// Global Constants and Functions +////////////////////////////////////////////////////////////////////// +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 0 | Conversation Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 4 | Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 8 | Acknowledgment Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | |U|A|P|R|S|F| | +// 12 | Control | |R|C|S|S|Y|I| Window | +// | | |G|K|H|T|N|N| | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 16 | Timestamp sending | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 20 | Timestamp receiving | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 24 | data | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// +////////////////////////////////////////////////////////////////////// + +#define PSEUDO_KEEPALIVE 0 + +const uint32_t HEADER_SIZE = 24; +const uint32_t PACKET_OVERHEAD = + HEADER_SIZE + UDP_HEADER_SIZE + IP_HEADER_SIZE + JINGLE_HEADER_SIZE; + +const uint32_t MIN_RTO = + 250; // 250 ms (RFC1122, Sec 4.2.3.1 "fractions of a second") +const uint32_t DEF_RTO = 3000; // 3 seconds (RFC1122, Sec 4.2.3.1) +const uint32_t MAX_RTO = 60000; // 60 seconds +const uint32_t DEF_ACK_DELAY = 100; // 100 milliseconds + +const uint8_t FLAG_CTL = 0x02; +const uint8_t FLAG_RST = 0x04; + +const uint8_t CTL_CONNECT = 0; + +// TCP options. +const uint8_t TCP_OPT_EOL = 0; // End of list. +const uint8_t TCP_OPT_NOOP = 1; // No-op. +const uint8_t TCP_OPT_MSS = 2; // Maximum segment size. +const uint8_t TCP_OPT_WND_SCALE = 3; // Window scale factor. + +const long DEFAULT_TIMEOUT = 4000; // If there are no pending clocks, wake up every 4 seconds +const long CLOSED_TIMEOUT = 60 * 1000; // If the connection is closed, once per minute + +#if PSEUDO_KEEPALIVE +// !?! Rethink these times +const uint32_t IDLE_PING = + 20 * + 1000; // 20 seconds (note: WinXP SP2 firewall udp timeout is 90 seconds) +const uint32_t IDLE_TIMEOUT = 90 * 1000; // 90 seconds; +#endif // PSEUDO_KEEPALIVE + +////////////////////////////////////////////////////////////////////// +// Helper Functions +////////////////////////////////////////////////////////////////////// + +inline void long_to_bytes(uint32_t val, void* buf) { + *static_cast<uint32_t*>(buf) = rtc::HostToNetwork32(val); +} + +inline void short_to_bytes(uint16_t val, void* buf) { + *static_cast<uint16_t*>(buf) = rtc::HostToNetwork16(val); +} + +inline uint32_t bytes_to_long(const void* buf) { + return rtc::NetworkToHost32(*static_cast<const uint32_t*>(buf)); +} + +inline uint16_t bytes_to_short(const void* buf) { + return rtc::NetworkToHost16(*static_cast<const uint16_t*>(buf)); +} + +uint32_t bound(uint32_t lower, uint32_t middle, uint32_t upper) { + return std::min(std::max(lower, middle), upper); +} + +////////////////////////////////////////////////////////////////////// +// Debugging Statistics +////////////////////////////////////////////////////////////////////// + +#if 0 // Not used yet + +enum Stat { + S_SENT_PACKET, // All packet sends + S_RESENT_PACKET, // All packet sends that are retransmits + S_RECV_PACKET, // All packet receives + S_RECV_NEW, // All packet receives that are too new + S_RECV_OLD, // All packet receives that are too old + S_NUM_STATS +}; + +const char* const STAT_NAMES[S_NUM_STATS] = { + "snt", + "snt-r", + "rcv" + "rcv-n", + "rcv-o" +}; + +int g_stats[S_NUM_STATS]; +inline void Incr(Stat s) { ++g_stats[s]; } +void ReportStats() { + char buffer[256]; + size_t len = 0; + for (int i = 0; i < S_NUM_STATS; ++i) { + len += rtc::sprintfn(buffer, ARRAY_SIZE(buffer), "%s%s:%d", + (i == 0) ? "" : ",", STAT_NAMES[i], g_stats[i]); + g_stats[i] = 0; + } + LOG(LS_INFO) << "Stats[" << buffer << "]"; +} + +#endif + +////////////////////////////////////////////////////////////////////// +// PseudoTcp +////////////////////////////////////////////////////////////////////// + +uint32_t PseudoTcp::Now() { +#if 0 // Use this to synchronize timers with logging timestamps (easier debug) + return rtc::TimeSince(StartTime()); +#else + return rtc::Time(); +#endif +} + +PseudoTcp::PseudoTcp(IPseudoTcpNotify* notify, uint32_t conv) + : m_notify(notify), + m_shutdown(SD_NONE), + m_error(0), + m_rbuf_len(DEFAULT_RCV_BUF_SIZE), + m_rbuf(m_rbuf_len), + m_sbuf_len(DEFAULT_SND_BUF_SIZE), + m_sbuf(m_sbuf_len) { + // Sanity check on buffer sizes (needed for OnTcpWriteable notification logic) + ASSERT(m_rbuf_len + MIN_PACKET < m_sbuf_len); + + uint32_t now = Now(); + + m_state = TCP_LISTEN; + m_conv = conv; + m_rcv_wnd = m_rbuf_len; + m_rwnd_scale = m_swnd_scale = 0; + m_snd_nxt = 0; + m_snd_wnd = 1; + m_snd_una = m_rcv_nxt = 0; + m_bReadEnable = true; + m_bWriteEnable = false; + m_t_ack = 0; + + m_msslevel = 0; + m_largest = 0; + ASSERT(MIN_PACKET > PACKET_OVERHEAD); + m_mss = MIN_PACKET - PACKET_OVERHEAD; + m_mtu_advise = MAX_PACKET; + + m_rto_base = 0; + + m_cwnd = 2 * m_mss; + m_ssthresh = m_rbuf_len; + m_lastrecv = m_lastsend = m_lasttraffic = now; + m_bOutgoing = false; + + m_dup_acks = 0; + m_recover = 0; + + m_ts_recent = m_ts_lastack = 0; + + m_rx_rto = DEF_RTO; + m_rx_srtt = m_rx_rttvar = 0; + + m_use_nagling = true; + m_ack_delay = DEF_ACK_DELAY; + m_support_wnd_scale = true; +} + +PseudoTcp::~PseudoTcp() { +} + +int PseudoTcp::Connect() { + if (m_state != TCP_LISTEN) { + m_error = EINVAL; + return -1; + } + + m_state = TCP_SYN_SENT; + LOG(LS_INFO) << "State: TCP_SYN_SENT"; + + queueConnectMessage(); + attemptSend(); + + return 0; +} + +void PseudoTcp::NotifyMTU(uint16_t mtu) { + m_mtu_advise = mtu; + if (m_state == TCP_ESTABLISHED) { + adjustMTU(); + } +} + +void PseudoTcp::NotifyClock(uint32_t now) { + if (m_state == TCP_CLOSED) + return; + + // Check if it's time to retransmit a segment + if (m_rto_base && (rtc::TimeDiff(m_rto_base + m_rx_rto, now) <= 0)) { + if (m_slist.empty()) { + ASSERT(false); + } else { + // Note: (m_slist.front().xmit == 0)) { + // retransmit segments +#if _DEBUGMSG >= _DBG_NORMAL + LOG(LS_INFO) << "timeout retransmit (rto: " << m_rx_rto + << ") (rto_base: " << m_rto_base + << ") (now: " << now + << ") (dup_acks: " << static_cast<unsigned>(m_dup_acks) + << ")"; +#endif // _DEBUGMSG + if (!transmit(m_slist.begin(), now)) { + closedown(ECONNABORTED); + return; + } + + uint32_t nInFlight = m_snd_nxt - m_snd_una; + m_ssthresh = std::max(nInFlight / 2, 2 * m_mss); + //LOG(LS_INFO) << "m_ssthresh: " << m_ssthresh << " nInFlight: " << nInFlight << " m_mss: " << m_mss; + m_cwnd = m_mss; + + // Back off retransmit timer. Note: the limit is lower when connecting. + uint32_t rto_limit = (m_state < TCP_ESTABLISHED) ? DEF_RTO : MAX_RTO; + m_rx_rto = std::min(rto_limit, m_rx_rto * 2); + m_rto_base = now; + } + } + + // Check if it's time to probe closed windows + if ((m_snd_wnd == 0) + && (rtc::TimeDiff(m_lastsend + m_rx_rto, now) <= 0)) { + if (rtc::TimeDiff(now, m_lastrecv) >= 15000) { + closedown(ECONNABORTED); + return; + } + + // probe the window + packet(m_snd_nxt - 1, 0, 0, 0); + m_lastsend = now; + + // back off retransmit timer + m_rx_rto = std::min(MAX_RTO, m_rx_rto * 2); + } + + // Check if it's time to send delayed acks + if (m_t_ack && (rtc::TimeDiff(m_t_ack + m_ack_delay, now) <= 0)) { + packet(m_snd_nxt, 0, 0, 0); + } + +#if PSEUDO_KEEPALIVE + // Check for idle timeout + if ((m_state == TCP_ESTABLISHED) && (TimeDiff(m_lastrecv + IDLE_TIMEOUT, now) <= 0)) { + closedown(ECONNABORTED); + return; + } + + // Check for ping timeout (to keep udp mapping open) + if ((m_state == TCP_ESTABLISHED) && (TimeDiff(m_lasttraffic + (m_bOutgoing ? IDLE_PING * 3/2 : IDLE_PING), now) <= 0)) { + packet(m_snd_nxt, 0, 0, 0); + } +#endif // PSEUDO_KEEPALIVE +} + +bool PseudoTcp::NotifyPacket(const char* buffer, size_t len) { + if (len > MAX_PACKET) { + LOG_F(WARNING) << "packet too large"; + return false; + } + return parse(reinterpret_cast<const uint8_t*>(buffer), uint32_t(len)); +} + +bool PseudoTcp::GetNextClock(uint32_t now, long& timeout) { + return clock_check(now, timeout); +} + +void PseudoTcp::GetOption(Option opt, int* value) { + if (opt == OPT_NODELAY) { + *value = m_use_nagling ? 0 : 1; + } else if (opt == OPT_ACKDELAY) { + *value = m_ack_delay; + } else if (opt == OPT_SNDBUF) { + *value = m_sbuf_len; + } else if (opt == OPT_RCVBUF) { + *value = m_rbuf_len; + } else { + ASSERT(false); + } +} +void PseudoTcp::SetOption(Option opt, int value) { + if (opt == OPT_NODELAY) { + m_use_nagling = value == 0; + } else if (opt == OPT_ACKDELAY) { + m_ack_delay = value; + } else if (opt == OPT_SNDBUF) { + ASSERT(m_state == TCP_LISTEN); + resizeSendBuffer(value); + } else if (opt == OPT_RCVBUF) { + ASSERT(m_state == TCP_LISTEN); + resizeReceiveBuffer(value); + } else { + ASSERT(false); + } +} + +uint32_t PseudoTcp::GetCongestionWindow() const { + return m_cwnd; +} + +uint32_t PseudoTcp::GetBytesInFlight() const { + return m_snd_nxt - m_snd_una; +} + +uint32_t PseudoTcp::GetBytesBufferedNotSent() const { + size_t buffered_bytes = 0; + m_sbuf.GetBuffered(&buffered_bytes); + return static_cast<uint32_t>(m_snd_una + buffered_bytes - m_snd_nxt); +} + +uint32_t PseudoTcp::GetRoundTripTimeEstimateMs() const { + return m_rx_srtt; +} + +// +// IPStream Implementation +// + +int PseudoTcp::Recv(char* buffer, size_t len) { + if (m_state != TCP_ESTABLISHED) { + m_error = ENOTCONN; + return SOCKET_ERROR; + } + + size_t read = 0; + rtc::StreamResult result = m_rbuf.Read(buffer, len, &read, NULL); + + // If there's no data in |m_rbuf|. + if (result == rtc::SR_BLOCK) { + m_bReadEnable = true; + m_error = EWOULDBLOCK; + return SOCKET_ERROR; + } + ASSERT(result == rtc::SR_SUCCESS); + + size_t available_space = 0; + m_rbuf.GetWriteRemaining(&available_space); + + if (uint32_t(available_space) - m_rcv_wnd >= + std::min<uint32_t>(m_rbuf_len / 2, m_mss)) { + // TODO(jbeda): !?! Not sure about this was closed business + bool bWasClosed = (m_rcv_wnd == 0); + m_rcv_wnd = static_cast<uint32_t>(available_space); + + if (bWasClosed) { + attemptSend(sfImmediateAck); + } + } + + return static_cast<int>(read); +} + +int PseudoTcp::Send(const char* buffer, size_t len) { + if (m_state != TCP_ESTABLISHED) { + m_error = ENOTCONN; + return SOCKET_ERROR; + } + + size_t available_space = 0; + m_sbuf.GetWriteRemaining(&available_space); + + if (!available_space) { + m_bWriteEnable = true; + m_error = EWOULDBLOCK; + return SOCKET_ERROR; + } + + int written = queue(buffer, uint32_t(len), false); + attemptSend(); + return written; +} + +void PseudoTcp::Close(bool force) { + LOG_F(LS_VERBOSE) << "(" << (force ? "true" : "false") << ")"; + m_shutdown = force ? SD_FORCEFUL : SD_GRACEFUL; +} + +int PseudoTcp::GetError() { + return m_error; +} + +// +// Internal Implementation +// + +uint32_t PseudoTcp::queue(const char* data, uint32_t len, bool bCtrl) { + size_t available_space = 0; + m_sbuf.GetWriteRemaining(&available_space); + + if (len > static_cast<uint32_t>(available_space)) { + ASSERT(!bCtrl); + len = static_cast<uint32_t>(available_space); + } + + // We can concatenate data if the last segment is the same type + // (control v. regular data), and has not been transmitted yet + if (!m_slist.empty() && (m_slist.back().bCtrl == bCtrl) && + (m_slist.back().xmit == 0)) { + m_slist.back().len += len; + } else { + size_t snd_buffered = 0; + m_sbuf.GetBuffered(&snd_buffered); + SSegment sseg(static_cast<uint32_t>(m_snd_una + snd_buffered), len, bCtrl); + m_slist.push_back(sseg); + } + + size_t written = 0; + m_sbuf.Write(data, len, &written, NULL); + return static_cast<uint32_t>(written); +} + +IPseudoTcpNotify::WriteResult PseudoTcp::packet(uint32_t seq, + uint8_t flags, + uint32_t offset, + uint32_t len) { + ASSERT(HEADER_SIZE + len <= MAX_PACKET); + + uint32_t now = Now(); + + rtc::scoped_ptr<uint8_t[]> buffer(new uint8_t[MAX_PACKET]); + long_to_bytes(m_conv, buffer.get()); + long_to_bytes(seq, buffer.get() + 4); + long_to_bytes(m_rcv_nxt, buffer.get() + 8); + buffer[12] = 0; + buffer[13] = flags; + short_to_bytes(static_cast<uint16_t>(m_rcv_wnd >> m_rwnd_scale), + buffer.get() + 14); + + // Timestamp computations + long_to_bytes(now, buffer.get() + 16); + long_to_bytes(m_ts_recent, buffer.get() + 20); + m_ts_lastack = m_rcv_nxt; + + if (len) { + size_t bytes_read = 0; + rtc::StreamResult result = m_sbuf.ReadOffset( + buffer.get() + HEADER_SIZE, len, offset, &bytes_read); + RTC_UNUSED(result); + ASSERT(result == rtc::SR_SUCCESS); + ASSERT(static_cast<uint32_t>(bytes_read) == len); + } + +#if _DEBUGMSG >= _DBG_VERBOSE + LOG(LS_INFO) << "<-- <CONV=" << m_conv + << "><FLG=" << static_cast<unsigned>(flags) + << "><SEQ=" << seq << ":" << seq + len + << "><ACK=" << m_rcv_nxt + << "><WND=" << m_rcv_wnd + << "><TS=" << (now % 10000) + << "><TSR=" << (m_ts_recent % 10000) + << "><LEN=" << len << ">"; +#endif // _DEBUGMSG + + IPseudoTcpNotify::WriteResult wres = m_notify->TcpWritePacket( + this, reinterpret_cast<char *>(buffer.get()), len + HEADER_SIZE); + // Note: When len is 0, this is an ACK packet. We don't read the return value for those, + // and thus we won't retry. So go ahead and treat the packet as a success (basically simulate + // as if it were dropped), which will prevent our timers from being messed up. + if ((wres != IPseudoTcpNotify::WR_SUCCESS) && (0 != len)) + return wres; + + m_t_ack = 0; + if (len > 0) { + m_lastsend = now; + } + m_lasttraffic = now; + m_bOutgoing = true; + + return IPseudoTcpNotify::WR_SUCCESS; +} + +bool PseudoTcp::parse(const uint8_t* buffer, uint32_t size) { + if (size < 12) + return false; + + Segment seg; + seg.conv = bytes_to_long(buffer); + seg.seq = bytes_to_long(buffer + 4); + seg.ack = bytes_to_long(buffer + 8); + seg.flags = buffer[13]; + seg.wnd = bytes_to_short(buffer + 14); + + seg.tsval = bytes_to_long(buffer + 16); + seg.tsecr = bytes_to_long(buffer + 20); + + seg.data = reinterpret_cast<const char *>(buffer) + HEADER_SIZE; + seg.len = size - HEADER_SIZE; + +#if _DEBUGMSG >= _DBG_VERBOSE + LOG(LS_INFO) << "--> <CONV=" << seg.conv + << "><FLG=" << static_cast<unsigned>(seg.flags) + << "><SEQ=" << seg.seq << ":" << seg.seq + seg.len + << "><ACK=" << seg.ack + << "><WND=" << seg.wnd + << "><TS=" << (seg.tsval % 10000) + << "><TSR=" << (seg.tsecr % 10000) + << "><LEN=" << seg.len << ">"; +#endif // _DEBUGMSG + + return process(seg); +} + +bool PseudoTcp::clock_check(uint32_t now, long& nTimeout) { + if (m_shutdown == SD_FORCEFUL) + return false; + + size_t snd_buffered = 0; + m_sbuf.GetBuffered(&snd_buffered); + if ((m_shutdown == SD_GRACEFUL) + && ((m_state != TCP_ESTABLISHED) + || ((snd_buffered == 0) && (m_t_ack == 0)))) { + return false; + } + + if (m_state == TCP_CLOSED) { + nTimeout = CLOSED_TIMEOUT; + return true; + } + + nTimeout = DEFAULT_TIMEOUT; + + if (m_t_ack) { + nTimeout = + std::min<int32_t>(nTimeout, rtc::TimeDiff(m_t_ack + m_ack_delay, now)); + } + if (m_rto_base) { + nTimeout = + std::min<int32_t>(nTimeout, rtc::TimeDiff(m_rto_base + m_rx_rto, now)); + } + if (m_snd_wnd == 0) { + nTimeout = + std::min<int32_t>(nTimeout, rtc::TimeDiff(m_lastsend + m_rx_rto, now)); + } +#if PSEUDO_KEEPALIVE + if (m_state == TCP_ESTABLISHED) { + nTimeout = std::min<int32_t>( + nTimeout, rtc::TimeDiff(m_lasttraffic + (m_bOutgoing ? IDLE_PING * 3 / 2 + : IDLE_PING), + now)); + } +#endif // PSEUDO_KEEPALIVE + return true; +} + +bool PseudoTcp::process(Segment& seg) { + // If this is the wrong conversation, send a reset!?! (with the correct conversation?) + if (seg.conv != m_conv) { + //if ((seg.flags & FLAG_RST) == 0) { + // packet(tcb, seg.ack, 0, FLAG_RST, 0, 0); + //} + LOG_F(LS_ERROR) << "wrong conversation"; + return false; + } + + uint32_t now = Now(); + m_lasttraffic = m_lastrecv = now; + m_bOutgoing = false; + + if (m_state == TCP_CLOSED) { + // !?! send reset? + LOG_F(LS_ERROR) << "closed"; + return false; + } + + // Check if this is a reset segment + if (seg.flags & FLAG_RST) { + closedown(ECONNRESET); + return false; + } + + // Check for control data + bool bConnect = false; + if (seg.flags & FLAG_CTL) { + if (seg.len == 0) { + LOG_F(LS_ERROR) << "Missing control code"; + return false; + } else if (seg.data[0] == CTL_CONNECT) { + bConnect = true; + + // TCP options are in the remainder of the payload after CTL_CONNECT. + parseOptions(&seg.data[1], seg.len - 1); + + if (m_state == TCP_LISTEN) { + m_state = TCP_SYN_RECEIVED; + LOG(LS_INFO) << "State: TCP_SYN_RECEIVED"; + //m_notify->associate(addr); + queueConnectMessage(); + } else if (m_state == TCP_SYN_SENT) { + m_state = TCP_ESTABLISHED; + LOG(LS_INFO) << "State: TCP_ESTABLISHED"; + adjustMTU(); + if (m_notify) { + m_notify->OnTcpOpen(this); + } + //notify(evOpen); + } + } else { + LOG_F(LS_WARNING) << "Unknown control code: " << seg.data[0]; + return false; + } + } + + // Update timestamp + if ((seg.seq <= m_ts_lastack) && (m_ts_lastack < seg.seq + seg.len)) { + m_ts_recent = seg.tsval; + } + + // Check if this is a valuable ack + if ((seg.ack > m_snd_una) && (seg.ack <= m_snd_nxt)) { + // Calculate round-trip time + if (seg.tsecr) { + int32_t rtt = rtc::TimeDiff(now, seg.tsecr); + if (rtt >= 0) { + if (m_rx_srtt == 0) { + m_rx_srtt = rtt; + m_rx_rttvar = rtt / 2; + } else { + uint32_t unsigned_rtt = static_cast<uint32_t>(rtt); + uint32_t abs_err = unsigned_rtt > m_rx_srtt + ? unsigned_rtt - m_rx_srtt + : m_rx_srtt - unsigned_rtt; + m_rx_rttvar = (3 * m_rx_rttvar + abs_err) / 4; + m_rx_srtt = (7 * m_rx_srtt + rtt) / 8; + } + m_rx_rto = + bound(MIN_RTO, m_rx_srtt + std::max<uint32_t>(1, 4 * m_rx_rttvar), + MAX_RTO); +#if _DEBUGMSG >= _DBG_VERBOSE + LOG(LS_INFO) << "rtt: " << rtt + << " srtt: " << m_rx_srtt + << " rto: " << m_rx_rto; +#endif // _DEBUGMSG + } else { + ASSERT(false); + } + } + + m_snd_wnd = static_cast<uint32_t>(seg.wnd) << m_swnd_scale; + + uint32_t nAcked = seg.ack - m_snd_una; + m_snd_una = seg.ack; + + m_rto_base = (m_snd_una == m_snd_nxt) ? 0 : now; + + m_sbuf.ConsumeReadData(nAcked); + + for (uint32_t nFree = nAcked; nFree > 0;) { + ASSERT(!m_slist.empty()); + if (nFree < m_slist.front().len) { + m_slist.front().len -= nFree; + nFree = 0; + } else { + if (m_slist.front().len > m_largest) { + m_largest = m_slist.front().len; + } + nFree -= m_slist.front().len; + m_slist.pop_front(); + } + } + + if (m_dup_acks >= 3) { + if (m_snd_una >= m_recover) { // NewReno + uint32_t nInFlight = m_snd_nxt - m_snd_una; + m_cwnd = std::min(m_ssthresh, nInFlight + m_mss); // (Fast Retransmit) +#if _DEBUGMSG >= _DBG_NORMAL + LOG(LS_INFO) << "exit recovery"; +#endif // _DEBUGMSG + m_dup_acks = 0; + } else { +#if _DEBUGMSG >= _DBG_NORMAL + LOG(LS_INFO) << "recovery retransmit"; +#endif // _DEBUGMSG + if (!transmit(m_slist.begin(), now)) { + closedown(ECONNABORTED); + return false; + } + m_cwnd += m_mss - std::min(nAcked, m_cwnd); + } + } else { + m_dup_acks = 0; + // Slow start, congestion avoidance + if (m_cwnd < m_ssthresh) { + m_cwnd += m_mss; + } else { + m_cwnd += std::max<uint32_t>(1, m_mss * m_mss / m_cwnd); + } + } + } else if (seg.ack == m_snd_una) { + // !?! Note, tcp says don't do this... but otherwise how does a closed window become open? + m_snd_wnd = static_cast<uint32_t>(seg.wnd) << m_swnd_scale; + + // Check duplicate acks + if (seg.len > 0) { + // it's a dup ack, but with a data payload, so don't modify m_dup_acks + } else if (m_snd_una != m_snd_nxt) { + m_dup_acks += 1; + if (m_dup_acks == 3) { // (Fast Retransmit) +#if _DEBUGMSG >= _DBG_NORMAL + LOG(LS_INFO) << "enter recovery"; + LOG(LS_INFO) << "recovery retransmit"; +#endif // _DEBUGMSG + if (!transmit(m_slist.begin(), now)) { + closedown(ECONNABORTED); + return false; + } + m_recover = m_snd_nxt; + uint32_t nInFlight = m_snd_nxt - m_snd_una; + m_ssthresh = std::max(nInFlight / 2, 2 * m_mss); + //LOG(LS_INFO) << "m_ssthresh: " << m_ssthresh << " nInFlight: " << nInFlight << " m_mss: " << m_mss; + m_cwnd = m_ssthresh + 3 * m_mss; + } else if (m_dup_acks > 3) { + m_cwnd += m_mss; + } + } else { + m_dup_acks = 0; + } + } + + // !?! A bit hacky + if ((m_state == TCP_SYN_RECEIVED) && !bConnect) { + m_state = TCP_ESTABLISHED; + LOG(LS_INFO) << "State: TCP_ESTABLISHED"; + adjustMTU(); + if (m_notify) { + m_notify->OnTcpOpen(this); + } + //notify(evOpen); + } + + // If we make room in the send queue, notify the user + // The goal it to make sure we always have at least enough data to fill the + // window. We'd like to notify the app when we are halfway to that point. + const uint32_t kIdealRefillSize = (m_sbuf_len + m_rbuf_len) / 2; + size_t snd_buffered = 0; + m_sbuf.GetBuffered(&snd_buffered); + if (m_bWriteEnable && + static_cast<uint32_t>(snd_buffered) < kIdealRefillSize) { + m_bWriteEnable = false; + if (m_notify) { + m_notify->OnTcpWriteable(this); + } + //notify(evWrite); + } + + // Conditions were acks must be sent: + // 1) Segment is too old (they missed an ACK) (immediately) + // 2) Segment is too new (we missed a segment) (immediately) + // 3) Segment has data (so we need to ACK!) (delayed) + // ... so the only time we don't need to ACK, is an empty segment that points to rcv_nxt! + + SendFlags sflags = sfNone; + if (seg.seq != m_rcv_nxt) { + sflags = sfImmediateAck; // (Fast Recovery) + } else if (seg.len != 0) { + if (m_ack_delay == 0) { + sflags = sfImmediateAck; + } else { + sflags = sfDelayedAck; + } + } +#if _DEBUGMSG >= _DBG_NORMAL + if (sflags == sfImmediateAck) { + if (seg.seq > m_rcv_nxt) { + LOG_F(LS_INFO) << "too new"; + } else if (seg.seq + seg.len <= m_rcv_nxt) { + LOG_F(LS_INFO) << "too old"; + } + } +#endif // _DEBUGMSG + + // Adjust the incoming segment to fit our receive buffer + if (seg.seq < m_rcv_nxt) { + uint32_t nAdjust = m_rcv_nxt - seg.seq; + if (nAdjust < seg.len) { + seg.seq += nAdjust; + seg.data += nAdjust; + seg.len -= nAdjust; + } else { + seg.len = 0; + } + } + + size_t available_space = 0; + m_rbuf.GetWriteRemaining(&available_space); + + if ((seg.seq + seg.len - m_rcv_nxt) > + static_cast<uint32_t>(available_space)) { + uint32_t nAdjust = + seg.seq + seg.len - m_rcv_nxt - static_cast<uint32_t>(available_space); + if (nAdjust < seg.len) { + seg.len -= nAdjust; + } else { + seg.len = 0; + } + } + + bool bIgnoreData = (seg.flags & FLAG_CTL) || (m_shutdown != SD_NONE); + bool bNewData = false; + + if (seg.len > 0) { + if (bIgnoreData) { + if (seg.seq == m_rcv_nxt) { + m_rcv_nxt += seg.len; + } + } else { + uint32_t nOffset = seg.seq - m_rcv_nxt; + + rtc::StreamResult result = m_rbuf.WriteOffset(seg.data, seg.len, + nOffset, NULL); + ASSERT(result == rtc::SR_SUCCESS); + RTC_UNUSED(result); + + if (seg.seq == m_rcv_nxt) { + m_rbuf.ConsumeWriteBuffer(seg.len); + m_rcv_nxt += seg.len; + m_rcv_wnd -= seg.len; + bNewData = true; + + RList::iterator it = m_rlist.begin(); + while ((it != m_rlist.end()) && (it->seq <= m_rcv_nxt)) { + if (it->seq + it->len > m_rcv_nxt) { + sflags = sfImmediateAck; // (Fast Recovery) + uint32_t nAdjust = (it->seq + it->len) - m_rcv_nxt; +#if _DEBUGMSG >= _DBG_NORMAL + LOG(LS_INFO) << "Recovered " << nAdjust << " bytes (" << m_rcv_nxt << " -> " << m_rcv_nxt + nAdjust << ")"; +#endif // _DEBUGMSG + m_rbuf.ConsumeWriteBuffer(nAdjust); + m_rcv_nxt += nAdjust; + m_rcv_wnd -= nAdjust; + } + it = m_rlist.erase(it); + } + } else { +#if _DEBUGMSG >= _DBG_NORMAL + LOG(LS_INFO) << "Saving " << seg.len << " bytes (" << seg.seq << " -> " << seg.seq + seg.len << ")"; +#endif // _DEBUGMSG + RSegment rseg; + rseg.seq = seg.seq; + rseg.len = seg.len; + RList::iterator it = m_rlist.begin(); + while ((it != m_rlist.end()) && (it->seq < rseg.seq)) { + ++it; + } + m_rlist.insert(it, rseg); + } + } + } + + attemptSend(sflags); + + // If we have new data, notify the user + if (bNewData && m_bReadEnable) { + m_bReadEnable = false; + if (m_notify) { + m_notify->OnTcpReadable(this); + } + //notify(evRead); + } + + return true; +} + +bool PseudoTcp::transmit(const SList::iterator& seg, uint32_t now) { + if (seg->xmit >= ((m_state == TCP_ESTABLISHED) ? 15 : 30)) { + LOG_F(LS_VERBOSE) << "too many retransmits"; + return false; + } + + uint32_t nTransmit = std::min(seg->len, m_mss); + + while (true) { + uint32_t seq = seg->seq; + uint8_t flags = (seg->bCtrl ? FLAG_CTL : 0); + IPseudoTcpNotify::WriteResult wres = packet(seq, + flags, + seg->seq - m_snd_una, + nTransmit); + + if (wres == IPseudoTcpNotify::WR_SUCCESS) + break; + + if (wres == IPseudoTcpNotify::WR_FAIL) { + LOG_F(LS_VERBOSE) << "packet failed"; + return false; + } + + ASSERT(wres == IPseudoTcpNotify::WR_TOO_LARGE); + + while (true) { + if (PACKET_MAXIMUMS[m_msslevel + 1] == 0) { + LOG_F(LS_VERBOSE) << "MTU too small"; + return false; + } + // !?! We need to break up all outstanding and pending packets and then retransmit!?! + + m_mss = PACKET_MAXIMUMS[++m_msslevel] - PACKET_OVERHEAD; + m_cwnd = 2 * m_mss; // I added this... haven't researched actual formula + if (m_mss < nTransmit) { + nTransmit = m_mss; + break; + } + } +#if _DEBUGMSG >= _DBG_NORMAL + LOG(LS_INFO) << "Adjusting mss to " << m_mss << " bytes"; +#endif // _DEBUGMSG + } + + if (nTransmit < seg->len) { + LOG_F(LS_VERBOSE) << "mss reduced to " << m_mss; + + SSegment subseg(seg->seq + nTransmit, seg->len - nTransmit, seg->bCtrl); + //subseg.tstamp = seg->tstamp; + subseg.xmit = seg->xmit; + seg->len = nTransmit; + + SList::iterator next = seg; + m_slist.insert(++next, subseg); + } + + if (seg->xmit == 0) { + m_snd_nxt += seg->len; + } + seg->xmit += 1; + //seg->tstamp = now; + if (m_rto_base == 0) { + m_rto_base = now; + } + + return true; +} + +void PseudoTcp::attemptSend(SendFlags sflags) { + uint32_t now = Now(); + + if (rtc::TimeDiff(now, m_lastsend) > static_cast<long>(m_rx_rto)) { + m_cwnd = m_mss; + } + +#if _DEBUGMSG + bool bFirst = true; + RTC_UNUSED(bFirst); +#endif // _DEBUGMSG + + while (true) { + uint32_t cwnd = m_cwnd; + if ((m_dup_acks == 1) || (m_dup_acks == 2)) { // Limited Transmit + cwnd += m_dup_acks * m_mss; + } + uint32_t nWindow = std::min(m_snd_wnd, cwnd); + uint32_t nInFlight = m_snd_nxt - m_snd_una; + uint32_t nUseable = (nInFlight < nWindow) ? (nWindow - nInFlight) : 0; + + size_t snd_buffered = 0; + m_sbuf.GetBuffered(&snd_buffered); + uint32_t nAvailable = + std::min(static_cast<uint32_t>(snd_buffered) - nInFlight, m_mss); + + if (nAvailable > nUseable) { + if (nUseable * 4 < nWindow) { + // RFC 813 - avoid SWS + nAvailable = 0; + } else { + nAvailable = nUseable; + } + } + +#if _DEBUGMSG >= _DBG_VERBOSE + if (bFirst) { + size_t available_space = 0; + m_sbuf.GetWriteRemaining(&available_space); + + bFirst = false; + LOG(LS_INFO) << "[cwnd: " << m_cwnd + << " nWindow: " << nWindow + << " nInFlight: " << nInFlight + << " nAvailable: " << nAvailable + << " nQueued: " << snd_buffered + << " nEmpty: " << available_space + << " ssthresh: " << m_ssthresh << "]"; + } +#endif // _DEBUGMSG + + if (nAvailable == 0) { + if (sflags == sfNone) + return; + + // If this is an immediate ack, or the second delayed ack + if ((sflags == sfImmediateAck) || m_t_ack) { + packet(m_snd_nxt, 0, 0, 0); + } else { + m_t_ack = Now(); + } + return; + } + + // Nagle's algorithm. + // If there is data already in-flight, and we haven't a full segment of + // data ready to send then hold off until we get more to send, or the + // in-flight data is acknowledged. + if (m_use_nagling && (m_snd_nxt > m_snd_una) && (nAvailable < m_mss)) { + return; + } + + // Find the next segment to transmit + SList::iterator it = m_slist.begin(); + while (it->xmit > 0) { + ++it; + ASSERT(it != m_slist.end()); + } + SList::iterator seg = it; + + // If the segment is too large, break it into two + if (seg->len > nAvailable) { + SSegment subseg(seg->seq + nAvailable, seg->len - nAvailable, seg->bCtrl); + seg->len = nAvailable; + m_slist.insert(++it, subseg); + } + + if (!transmit(seg, now)) { + LOG_F(LS_VERBOSE) << "transmit failed"; + // TODO: consider closing socket + return; + } + + sflags = sfNone; + } +} + +void PseudoTcp::closedown(uint32_t err) { + LOG(LS_INFO) << "State: TCP_CLOSED"; + m_state = TCP_CLOSED; + if (m_notify) { + m_notify->OnTcpClosed(this, err); + } + //notify(evClose, err); +} + +void +PseudoTcp::adjustMTU() { + // Determine our current mss level, so that we can adjust appropriately later + for (m_msslevel = 0; PACKET_MAXIMUMS[m_msslevel + 1] > 0; ++m_msslevel) { + if (static_cast<uint16_t>(PACKET_MAXIMUMS[m_msslevel]) <= m_mtu_advise) { + break; + } + } + m_mss = m_mtu_advise - PACKET_OVERHEAD; + // !?! Should we reset m_largest here? +#if _DEBUGMSG >= _DBG_NORMAL + LOG(LS_INFO) << "Adjusting mss to " << m_mss << " bytes"; +#endif // _DEBUGMSG + // Enforce minimums on ssthresh and cwnd + m_ssthresh = std::max(m_ssthresh, 2 * m_mss); + m_cwnd = std::max(m_cwnd, m_mss); +} + +bool +PseudoTcp::isReceiveBufferFull() const { + size_t available_space = 0; + m_rbuf.GetWriteRemaining(&available_space); + return !available_space; +} + +void +PseudoTcp::disableWindowScale() { + m_support_wnd_scale = false; +} + +void +PseudoTcp::queueConnectMessage() { + rtc::ByteBuffer buf(rtc::ByteBuffer::ORDER_NETWORK); + + buf.WriteUInt8(CTL_CONNECT); + if (m_support_wnd_scale) { + buf.WriteUInt8(TCP_OPT_WND_SCALE); + buf.WriteUInt8(1); + buf.WriteUInt8(m_rwnd_scale); + } + m_snd_wnd = static_cast<uint32_t>(buf.Length()); + queue(buf.Data(), static_cast<uint32_t>(buf.Length()), true); +} + +void PseudoTcp::parseOptions(const char* data, uint32_t len) { + std::set<uint8_t> options_specified; + + // See http://www.freesoft.org/CIE/Course/Section4/8.htm for + // parsing the options list. + rtc::ByteBuffer buf(data, len); + while (buf.Length()) { + uint8_t kind = TCP_OPT_EOL; + buf.ReadUInt8(&kind); + + if (kind == TCP_OPT_EOL) { + // End of option list. + break; + } else if (kind == TCP_OPT_NOOP) { + // No op. + continue; + } + + // Length of this option. + ASSERT(len != 0); + RTC_UNUSED(len); + uint8_t opt_len = 0; + buf.ReadUInt8(&opt_len); + + // Content of this option. + if (opt_len <= buf.Length()) { + applyOption(kind, buf.Data(), opt_len); + buf.Consume(opt_len); + } else { + LOG(LS_ERROR) << "Invalid option length received."; + return; + } + options_specified.insert(kind); + } + + if (options_specified.find(TCP_OPT_WND_SCALE) == options_specified.end()) { + LOG(LS_WARNING) << "Peer doesn't support window scaling"; + + if (m_rwnd_scale > 0) { + // Peer doesn't support TCP options and window scaling. + // Revert receive buffer size to default value. + resizeReceiveBuffer(DEFAULT_RCV_BUF_SIZE); + m_swnd_scale = 0; + } + } +} + +void PseudoTcp::applyOption(char kind, const char* data, uint32_t len) { + if (kind == TCP_OPT_MSS) { + LOG(LS_WARNING) << "Peer specified MSS option which is not supported."; + // TODO: Implement. + } else if (kind == TCP_OPT_WND_SCALE) { + // Window scale factor. + // http://www.ietf.org/rfc/rfc1323.txt + if (len != 1) { + LOG_F(WARNING) << "Invalid window scale option received."; + return; + } + applyWindowScaleOption(data[0]); + } +} + +void PseudoTcp::applyWindowScaleOption(uint8_t scale_factor) { + m_swnd_scale = scale_factor; +} + +void PseudoTcp::resizeSendBuffer(uint32_t new_size) { + m_sbuf_len = new_size; + m_sbuf.SetCapacity(new_size); +} + +void PseudoTcp::resizeReceiveBuffer(uint32_t new_size) { + uint8_t scale_factor = 0; + + // Determine the scale factor such that the scaled window size can fit + // in a 16-bit unsigned integer. + while (new_size > 0xFFFF) { + ++scale_factor; + new_size >>= 1; + } + + // Determine the proper size of the buffer. + new_size <<= scale_factor; + bool result = m_rbuf.SetCapacity(new_size); + + // Make sure the new buffer is large enough to contain data in the old + // buffer. This should always be true because this method is called either + // before connection is established or when peers are exchanging connect + // messages. + ASSERT(result); + RTC_UNUSED(result); + m_rbuf_len = new_size; + m_rwnd_scale = scale_factor; + m_ssthresh = new_size; + + size_t available_space = 0; + m_rbuf.GetWriteRemaining(&available_space); + m_rcv_wnd = static_cast<uint32_t>(available_space); +} + +} // namespace cricket diff --git a/webrtc/p2p/base/pseudotcp.h b/webrtc/p2p/base/pseudotcp.h new file mode 100644 index 0000000000..6d402daa6f --- /dev/null +++ b/webrtc/p2p/base/pseudotcp.h @@ -0,0 +1,242 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_PSEUDOTCP_H_ +#define WEBRTC_P2P_BASE_PSEUDOTCP_H_ + +#include <list> + +#include "webrtc/base/basictypes.h" +#include "webrtc/base/stream.h" + +namespace cricket { + +////////////////////////////////////////////////////////////////////// +// IPseudoTcpNotify +////////////////////////////////////////////////////////////////////// + +class PseudoTcp; + +class IPseudoTcpNotify { + public: + // Notification of tcp events + virtual void OnTcpOpen(PseudoTcp* tcp) = 0; + virtual void OnTcpReadable(PseudoTcp* tcp) = 0; + virtual void OnTcpWriteable(PseudoTcp* tcp) = 0; + virtual void OnTcpClosed(PseudoTcp* tcp, uint32_t error) = 0; + + // Write the packet onto the network + enum WriteResult { WR_SUCCESS, WR_TOO_LARGE, WR_FAIL }; + virtual WriteResult TcpWritePacket(PseudoTcp* tcp, + const char* buffer, size_t len) = 0; + + protected: + virtual ~IPseudoTcpNotify() {} +}; + +////////////////////////////////////////////////////////////////////// +// PseudoTcp +////////////////////////////////////////////////////////////////////// + +class PseudoTcp { + public: + static uint32_t Now(); + + PseudoTcp(IPseudoTcpNotify* notify, uint32_t conv); + virtual ~PseudoTcp(); + + int Connect(); + int Recv(char* buffer, size_t len); + int Send(const char* buffer, size_t len); + void Close(bool force); + int GetError(); + + enum TcpState { + TCP_LISTEN, TCP_SYN_SENT, TCP_SYN_RECEIVED, TCP_ESTABLISHED, TCP_CLOSED + }; + TcpState State() const { return m_state; } + + // Call this when the PMTU changes. + void NotifyMTU(uint16_t mtu); + + // Call this based on timeout value returned from GetNextClock. + // It's ok to call this too frequently. + void NotifyClock(uint32_t now); + + // Call this whenever a packet arrives. + // Returns true if the packet was processed successfully. + bool NotifyPacket(const char * buffer, size_t len); + + // Call this to determine the next time NotifyClock should be called. + // Returns false if the socket is ready to be destroyed. + bool GetNextClock(uint32_t now, long& timeout); + + // Call these to get/set option values to tailor this PseudoTcp + // instance's behaviour for the kind of data it will carry. + // If an unrecognized option is set or got, an assertion will fire. + // + // Setting options for OPT_RCVBUF or OPT_SNDBUF after Connect() is called + // will result in an assertion. + enum Option { + OPT_NODELAY, // Whether to enable Nagle's algorithm (0 == off) + OPT_ACKDELAY, // The Delayed ACK timeout (0 == off). + OPT_RCVBUF, // Set the receive buffer size, in bytes. + OPT_SNDBUF, // Set the send buffer size, in bytes. + }; + void GetOption(Option opt, int* value); + void SetOption(Option opt, int value); + + // Returns current congestion window in bytes. + uint32_t GetCongestionWindow() const; + + // Returns amount of data in bytes that has been sent, but haven't + // been acknowledged. + uint32_t GetBytesInFlight() const; + + // Returns number of bytes that were written in buffer and haven't + // been sent. + uint32_t GetBytesBufferedNotSent() const; + + // Returns current round-trip time estimate in milliseconds. + uint32_t GetRoundTripTimeEstimateMs() const; + + protected: + enum SendFlags { sfNone, sfDelayedAck, sfImmediateAck }; + + struct Segment { + uint32_t conv, seq, ack; + uint8_t flags; + uint16_t wnd; + const char * data; + uint32_t len; + uint32_t tsval, tsecr; + }; + + struct SSegment { + SSegment(uint32_t s, uint32_t l, bool c) + : seq(s), len(l), /*tstamp(0),*/ xmit(0), bCtrl(c) {} + uint32_t seq, len; + // uint32_t tstamp; + uint8_t xmit; + bool bCtrl; + }; + typedef std::list<SSegment> SList; + + struct RSegment { + uint32_t seq, len; + }; + + uint32_t queue(const char* data, uint32_t len, bool bCtrl); + + // Creates a packet and submits it to the network. This method can either + // send payload or just an ACK packet. + // + // |seq| is the sequence number of this packet. + // |flags| is the flags for sending this packet. + // |offset| is the offset to read from |m_sbuf|. + // |len| is the number of bytes to read from |m_sbuf| as payload. If this + // value is 0 then this is an ACK packet, otherwise this packet has payload. + IPseudoTcpNotify::WriteResult packet(uint32_t seq, + uint8_t flags, + uint32_t offset, + uint32_t len); + bool parse(const uint8_t* buffer, uint32_t size); + + void attemptSend(SendFlags sflags = sfNone); + + void closedown(uint32_t err = 0); + + bool clock_check(uint32_t now, long& nTimeout); + + bool process(Segment& seg); + bool transmit(const SList::iterator& seg, uint32_t now); + + void adjustMTU(); + + protected: + // This method is used in test only to query receive buffer state. + bool isReceiveBufferFull() const; + + // This method is only used in tests, to disable window scaling + // support for testing backward compatibility. + void disableWindowScale(); + + private: + // Queue the connect message with TCP options. + void queueConnectMessage(); + + // Parse TCP options in the header. + void parseOptions(const char* data, uint32_t len); + + // Apply a TCP option that has been read from the header. + void applyOption(char kind, const char* data, uint32_t len); + + // Apply window scale option. + void applyWindowScaleOption(uint8_t scale_factor); + + // Resize the send buffer with |new_size| in bytes. + void resizeSendBuffer(uint32_t new_size); + + // Resize the receive buffer with |new_size| in bytes. This call adjusts + // window scale factor |m_swnd_scale| accordingly. + void resizeReceiveBuffer(uint32_t new_size); + + IPseudoTcpNotify* m_notify; + enum Shutdown { SD_NONE, SD_GRACEFUL, SD_FORCEFUL } m_shutdown; + int m_error; + + // TCB data + TcpState m_state; + uint32_t m_conv; + bool m_bReadEnable, m_bWriteEnable, m_bOutgoing; + uint32_t m_lasttraffic; + + // Incoming data + typedef std::list<RSegment> RList; + RList m_rlist; + uint32_t m_rbuf_len, m_rcv_nxt, m_rcv_wnd, m_lastrecv; + uint8_t m_rwnd_scale; // Window scale factor. + rtc::FifoBuffer m_rbuf; + + // Outgoing data + SList m_slist; + uint32_t m_sbuf_len, m_snd_nxt, m_snd_wnd, m_lastsend, m_snd_una; + uint8_t m_swnd_scale; // Window scale factor. + rtc::FifoBuffer m_sbuf; + + // Maximum segment size, estimated protocol level, largest segment sent + uint32_t m_mss, m_msslevel, m_largest, m_mtu_advise; + // Retransmit timer + uint32_t m_rto_base; + + // Timestamp tracking + uint32_t m_ts_recent, m_ts_lastack; + + // Round-trip calculation + uint32_t m_rx_rttvar, m_rx_srtt, m_rx_rto; + + // Congestion avoidance, Fast retransmit/recovery, Delayed ACKs + uint32_t m_ssthresh, m_cwnd; + uint8_t m_dup_acks; + uint32_t m_recover; + uint32_t m_t_ack; + + // Configuration options + bool m_use_nagling; + uint32_t m_ack_delay; + + // This is used by unit tests to test backward compatibility of + // PseudoTcp implementations that don't support window scaling. + bool m_support_wnd_scale; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_PSEUDOTCP_H_ diff --git a/webrtc/p2p/base/pseudotcp_unittest.cc b/webrtc/p2p/base/pseudotcp_unittest.cc new file mode 100644 index 0000000000..c9ccbca1d9 --- /dev/null +++ b/webrtc/p2p/base/pseudotcp_unittest.cc @@ -0,0 +1,840 @@ +/* + * Copyright 2011 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <algorithm> +#include <vector> + +#include "webrtc/p2p/base/pseudotcp.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/messagehandler.h" +#include "webrtc/base/stream.h" +#include "webrtc/base/thread.h" +#include "webrtc/base/timeutils.h" + +using cricket::PseudoTcp; + +static const int kConnectTimeoutMs = 10000; // ~3 * default RTO of 3000ms +static const int kTransferTimeoutMs = 15000; +static const int kBlockSize = 4096; + +class PseudoTcpForTest : public cricket::PseudoTcp { + public: + PseudoTcpForTest(cricket::IPseudoTcpNotify* notify, uint32_t conv) + : PseudoTcp(notify, conv) {} + + bool isReceiveBufferFull() const { + return PseudoTcp::isReceiveBufferFull(); + } + + void disableWindowScale() { + PseudoTcp::disableWindowScale(); + } +}; + +class PseudoTcpTestBase : public testing::Test, + public rtc::MessageHandler, + public cricket::IPseudoTcpNotify { + public: + PseudoTcpTestBase() + : local_(this, 1), + remote_(this, 1), + have_connected_(false), + have_disconnected_(false), + local_mtu_(65535), + remote_mtu_(65535), + delay_(0), + loss_(0) { + // Set use of the test RNG to get predictable loss patterns. + rtc::SetRandomTestMode(true); + } + ~PseudoTcpTestBase() { + // Put it back for the next test. + rtc::SetRandomTestMode(false); + } + void SetLocalMtu(int mtu) { + local_.NotifyMTU(mtu); + local_mtu_ = mtu; + } + void SetRemoteMtu(int mtu) { + remote_.NotifyMTU(mtu); + remote_mtu_ = mtu; + } + void SetDelay(int delay) { + delay_ = delay; + } + void SetLoss(int percent) { + loss_ = percent; + } + void SetOptNagling(bool enable_nagles) { + local_.SetOption(PseudoTcp::OPT_NODELAY, !enable_nagles); + remote_.SetOption(PseudoTcp::OPT_NODELAY, !enable_nagles); + } + void SetOptAckDelay(int ack_delay) { + local_.SetOption(PseudoTcp::OPT_ACKDELAY, ack_delay); + remote_.SetOption(PseudoTcp::OPT_ACKDELAY, ack_delay); + } + void SetOptSndBuf(int size) { + local_.SetOption(PseudoTcp::OPT_SNDBUF, size); + remote_.SetOption(PseudoTcp::OPT_SNDBUF, size); + } + void SetRemoteOptRcvBuf(int size) { + remote_.SetOption(PseudoTcp::OPT_RCVBUF, size); + } + void SetLocalOptRcvBuf(int size) { + local_.SetOption(PseudoTcp::OPT_RCVBUF, size); + } + void DisableRemoteWindowScale() { + remote_.disableWindowScale(); + } + void DisableLocalWindowScale() { + local_.disableWindowScale(); + } + + protected: + int Connect() { + int ret = local_.Connect(); + if (ret == 0) { + UpdateLocalClock(); + } + return ret; + } + void Close() { + local_.Close(false); + UpdateLocalClock(); + } + + enum { MSG_LPACKET, MSG_RPACKET, MSG_LCLOCK, MSG_RCLOCK, MSG_IOCOMPLETE, + MSG_WRITE}; + virtual void OnTcpOpen(PseudoTcp* tcp) { + // Consider ourselves connected when the local side gets OnTcpOpen. + // OnTcpWriteable isn't fired at open, so we trigger it now. + LOG(LS_VERBOSE) << "Opened"; + if (tcp == &local_) { + have_connected_ = true; + OnTcpWriteable(tcp); + } + } + // Test derived from the base should override + // virtual void OnTcpReadable(PseudoTcp* tcp) + // and + // virtual void OnTcpWritable(PseudoTcp* tcp) + virtual void OnTcpClosed(PseudoTcp* tcp, uint32_t error) { + // Consider ourselves closed when the remote side gets OnTcpClosed. + // TODO: OnTcpClosed is only ever notified in case of error in + // the current implementation. Solicited close is not (yet) supported. + LOG(LS_VERBOSE) << "Closed"; + EXPECT_EQ(0U, error); + if (tcp == &remote_) { + have_disconnected_ = true; + } + } + virtual WriteResult TcpWritePacket(PseudoTcp* tcp, + const char* buffer, size_t len) { + // Randomly drop the desired percentage of packets. + // Also drop packets that are larger than the configured MTU. + if (rtc::CreateRandomId() % 100 < static_cast<uint32_t>(loss_)) { + LOG(LS_VERBOSE) << "Randomly dropping packet, size=" << len; + } else if (len > static_cast<size_t>(std::min(local_mtu_, remote_mtu_))) { + LOG(LS_VERBOSE) << "Dropping packet that exceeds path MTU, size=" << len; + } else { + int id = (tcp == &local_) ? MSG_RPACKET : MSG_LPACKET; + std::string packet(buffer, len); + rtc::Thread::Current()->PostDelayed(delay_, this, id, + rtc::WrapMessageData(packet)); + } + return WR_SUCCESS; + } + + void UpdateLocalClock() { UpdateClock(&local_, MSG_LCLOCK); } + void UpdateRemoteClock() { UpdateClock(&remote_, MSG_RCLOCK); } + void UpdateClock(PseudoTcp* tcp, uint32_t message) { + long interval = 0; // NOLINT + tcp->GetNextClock(PseudoTcp::Now(), interval); + interval = std::max<int>(interval, 0L); // sometimes interval is < 0 + rtc::Thread::Current()->Clear(this, message); + rtc::Thread::Current()->PostDelayed(interval, this, message); + } + + virtual void OnMessage(rtc::Message* message) { + switch (message->message_id) { + case MSG_LPACKET: { + const std::string& s( + rtc::UseMessageData<std::string>(message->pdata)); + local_.NotifyPacket(s.c_str(), s.size()); + UpdateLocalClock(); + break; + } + case MSG_RPACKET: { + const std::string& s( + rtc::UseMessageData<std::string>(message->pdata)); + remote_.NotifyPacket(s.c_str(), s.size()); + UpdateRemoteClock(); + break; + } + case MSG_LCLOCK: + local_.NotifyClock(PseudoTcp::Now()); + UpdateLocalClock(); + break; + case MSG_RCLOCK: + remote_.NotifyClock(PseudoTcp::Now()); + UpdateRemoteClock(); + break; + default: + break; + } + delete message->pdata; + } + + PseudoTcpForTest local_; + PseudoTcpForTest remote_; + rtc::MemoryStream send_stream_; + rtc::MemoryStream recv_stream_; + bool have_connected_; + bool have_disconnected_; + int local_mtu_; + int remote_mtu_; + int delay_; + int loss_; +}; + +class PseudoTcpTest : public PseudoTcpTestBase { + public: + void TestTransfer(int size) { + uint32_t start, elapsed; + size_t received; + // Create some dummy data to send. + send_stream_.ReserveSize(size); + for (int i = 0; i < size; ++i) { + char ch = static_cast<char>(i); + send_stream_.Write(&ch, 1, NULL, NULL); + } + send_stream_.Rewind(); + // Prepare the receive stream. + recv_stream_.ReserveSize(size); + // Connect and wait until connected. + start = rtc::Time(); + EXPECT_EQ(0, Connect()); + EXPECT_TRUE_WAIT(have_connected_, kConnectTimeoutMs); + // Sending will start from OnTcpWriteable and complete when all data has + // been received. + EXPECT_TRUE_WAIT(have_disconnected_, kTransferTimeoutMs); + elapsed = rtc::TimeSince(start); + recv_stream_.GetSize(&received); + // Ensure we closed down OK and we got the right data. + // TODO: Ensure the errors are cleared properly. + //EXPECT_EQ(0, local_.GetError()); + //EXPECT_EQ(0, remote_.GetError()); + EXPECT_EQ(static_cast<size_t>(size), received); + EXPECT_EQ(0, memcmp(send_stream_.GetBuffer(), + recv_stream_.GetBuffer(), size)); + LOG(LS_INFO) << "Transferred " << received << " bytes in " << elapsed + << " ms (" << size * 8 / elapsed << " Kbps)"; + } + + private: + // IPseudoTcpNotify interface + + virtual void OnTcpReadable(PseudoTcp* tcp) { + // Stream bytes to the recv stream as they arrive. + if (tcp == &remote_) { + ReadData(); + + // TODO: OnTcpClosed() is currently only notified on error - + // there is no on-the-wire equivalent of TCP FIN. + // So we fake the notification when all the data has been read. + size_t received, required; + recv_stream_.GetPosition(&received); + send_stream_.GetSize(&required); + if (received == required) + OnTcpClosed(&remote_, 0); + } + } + virtual void OnTcpWriteable(PseudoTcp* tcp) { + // Write bytes from the send stream when we can. + // Shut down when we've sent everything. + if (tcp == &local_) { + LOG(LS_VERBOSE) << "Flow Control Lifted"; + bool done; + WriteData(&done); + if (done) { + Close(); + } + } + } + + void ReadData() { + char block[kBlockSize]; + size_t position; + int rcvd; + do { + rcvd = remote_.Recv(block, sizeof(block)); + if (rcvd != -1) { + recv_stream_.Write(block, rcvd, NULL, NULL); + recv_stream_.GetPosition(&position); + LOG(LS_VERBOSE) << "Received: " << position; + } + } while (rcvd > 0); + } + void WriteData(bool* done) { + size_t position, tosend; + int sent; + char block[kBlockSize]; + do { + send_stream_.GetPosition(&position); + if (send_stream_.Read(block, sizeof(block), &tosend, NULL) != + rtc::SR_EOS) { + sent = local_.Send(block, tosend); + UpdateLocalClock(); + if (sent != -1) { + send_stream_.SetPosition(position + sent); + LOG(LS_VERBOSE) << "Sent: " << position + sent; + } else { + send_stream_.SetPosition(position); + LOG(LS_VERBOSE) << "Flow Controlled"; + } + } else { + sent = static_cast<int>(tosend = 0); + } + } while (sent > 0); + *done = (tosend == 0); + } + + private: + rtc::MemoryStream send_stream_; + rtc::MemoryStream recv_stream_; +}; + + +class PseudoTcpTestPingPong : public PseudoTcpTestBase { + public: + PseudoTcpTestPingPong() + : iterations_remaining_(0), + sender_(NULL), + receiver_(NULL), + bytes_per_send_(0) { + } + void SetBytesPerSend(int bytes) { + bytes_per_send_ = bytes; + } + void TestPingPong(int size, int iterations) { + uint32_t start, elapsed; + iterations_remaining_ = iterations; + receiver_ = &remote_; + sender_ = &local_; + // Create some dummy data to send. + send_stream_.ReserveSize(size); + for (int i = 0; i < size; ++i) { + char ch = static_cast<char>(i); + send_stream_.Write(&ch, 1, NULL, NULL); + } + send_stream_.Rewind(); + // Prepare the receive stream. + recv_stream_.ReserveSize(size); + // Connect and wait until connected. + start = rtc::Time(); + EXPECT_EQ(0, Connect()); + EXPECT_TRUE_WAIT(have_connected_, kConnectTimeoutMs); + // Sending will start from OnTcpWriteable and stop when the required + // number of iterations have completed. + EXPECT_TRUE_WAIT(have_disconnected_, kTransferTimeoutMs); + elapsed = rtc::TimeSince(start); + LOG(LS_INFO) << "Performed " << iterations << " pings in " + << elapsed << " ms"; + } + + private: + // IPseudoTcpNotify interface + + virtual void OnTcpReadable(PseudoTcp* tcp) { + if (tcp != receiver_) { + LOG_F(LS_ERROR) << "unexpected OnTcpReadable"; + return; + } + // Stream bytes to the recv stream as they arrive. + ReadData(); + // If we've received the desired amount of data, rewind things + // and send it back the other way! + size_t position, desired; + recv_stream_.GetPosition(&position); + send_stream_.GetSize(&desired); + if (position == desired) { + if (receiver_ == &local_ && --iterations_remaining_ == 0) { + Close(); + // TODO: Fake OnTcpClosed() on the receiver for now. + OnTcpClosed(&remote_, 0); + return; + } + PseudoTcp* tmp = receiver_; + receiver_ = sender_; + sender_ = tmp; + recv_stream_.Rewind(); + send_stream_.Rewind(); + OnTcpWriteable(sender_); + } + } + virtual void OnTcpWriteable(PseudoTcp* tcp) { + if (tcp != sender_) + return; + // Write bytes from the send stream when we can. + // Shut down when we've sent everything. + LOG(LS_VERBOSE) << "Flow Control Lifted"; + WriteData(); + } + + void ReadData() { + char block[kBlockSize]; + size_t position; + int rcvd; + do { + rcvd = receiver_->Recv(block, sizeof(block)); + if (rcvd != -1) { + recv_stream_.Write(block, rcvd, NULL, NULL); + recv_stream_.GetPosition(&position); + LOG(LS_VERBOSE) << "Received: " << position; + } + } while (rcvd > 0); + } + void WriteData() { + size_t position, tosend; + int sent; + char block[kBlockSize]; + do { + send_stream_.GetPosition(&position); + tosend = bytes_per_send_ ? bytes_per_send_ : sizeof(block); + if (send_stream_.Read(block, tosend, &tosend, NULL) != + rtc::SR_EOS) { + sent = sender_->Send(block, tosend); + UpdateLocalClock(); + if (sent != -1) { + send_stream_.SetPosition(position + sent); + LOG(LS_VERBOSE) << "Sent: " << position + sent; + } else { + send_stream_.SetPosition(position); + LOG(LS_VERBOSE) << "Flow Controlled"; + } + } else { + sent = static_cast<int>(tosend = 0); + } + } while (sent > 0); + } + + private: + int iterations_remaining_; + PseudoTcp* sender_; + PseudoTcp* receiver_; + int bytes_per_send_; +}; + +// Fill the receiver window until it is full, drain it and then +// fill it with the same amount. This is to test that receiver window +// contracts and enlarges correctly. +class PseudoTcpTestReceiveWindow : public PseudoTcpTestBase { + public: + // Not all the data are transfered, |size| just need to be big enough + // to fill up the receiver window twice. + void TestTransfer(int size) { + // Create some dummy data to send. + send_stream_.ReserveSize(size); + for (int i = 0; i < size; ++i) { + char ch = static_cast<char>(i); + send_stream_.Write(&ch, 1, NULL, NULL); + } + send_stream_.Rewind(); + + // Prepare the receive stream. + recv_stream_.ReserveSize(size); + + // Connect and wait until connected. + EXPECT_EQ(0, Connect()); + EXPECT_TRUE_WAIT(have_connected_, kConnectTimeoutMs); + + rtc::Thread::Current()->Post(this, MSG_WRITE); + EXPECT_TRUE_WAIT(have_disconnected_, kTransferTimeoutMs); + + ASSERT_EQ(2u, send_position_.size()); + ASSERT_EQ(2u, recv_position_.size()); + + const size_t estimated_recv_window = EstimateReceiveWindowSize(); + + // The difference in consecutive send positions should equal the + // receive window size or match very closely. This verifies that receive + // window is open after receiver drained all the data. + const size_t send_position_diff = send_position_[1] - send_position_[0]; + EXPECT_GE(1024u, estimated_recv_window - send_position_diff); + + // Receiver drained the receive window twice. + EXPECT_EQ(2 * estimated_recv_window, recv_position_[1]); + } + + virtual void OnMessage(rtc::Message* message) { + int message_id = message->message_id; + PseudoTcpTestBase::OnMessage(message); + + switch (message_id) { + case MSG_WRITE: { + WriteData(); + break; + } + default: + break; + } + } + + uint32_t EstimateReceiveWindowSize() const { + return static_cast<uint32_t>(recv_position_[0]); + } + + uint32_t EstimateSendWindowSize() const { + return static_cast<uint32_t>(send_position_[0] - recv_position_[0]); + } + + private: + // IPseudoTcpNotify interface + virtual void OnTcpReadable(PseudoTcp* tcp) { + } + + virtual void OnTcpWriteable(PseudoTcp* tcp) { + } + + void ReadUntilIOPending() { + char block[kBlockSize]; + size_t position; + int rcvd; + + do { + rcvd = remote_.Recv(block, sizeof(block)); + if (rcvd != -1) { + recv_stream_.Write(block, rcvd, NULL, NULL); + recv_stream_.GetPosition(&position); + LOG(LS_VERBOSE) << "Received: " << position; + } + } while (rcvd > 0); + + recv_stream_.GetPosition(&position); + recv_position_.push_back(position); + + // Disconnect if we have done two transfers. + if (recv_position_.size() == 2u) { + Close(); + OnTcpClosed(&remote_, 0); + } else { + WriteData(); + } + } + + void WriteData() { + size_t position, tosend; + int sent; + char block[kBlockSize]; + do { + send_stream_.GetPosition(&position); + if (send_stream_.Read(block, sizeof(block), &tosend, NULL) != + rtc::SR_EOS) { + sent = local_.Send(block, tosend); + UpdateLocalClock(); + if (sent != -1) { + send_stream_.SetPosition(position + sent); + LOG(LS_VERBOSE) << "Sent: " << position + sent; + } else { + send_stream_.SetPosition(position); + LOG(LS_VERBOSE) << "Flow Controlled"; + } + } else { + sent = static_cast<int>(tosend = 0); + } + } while (sent > 0); + // At this point, we've filled up the available space in the send queue. + + int message_queue_size = + static_cast<int>(rtc::Thread::Current()->size()); + // The message queue will always have at least 2 messages, an RCLOCK and + // an LCLOCK, since they are added back on the delay queue at the same time + // they are pulled off and therefore are never really removed. + if (message_queue_size > 2) { + // If there are non-clock messages remaining, attempt to continue sending + // after giving those messages time to process, which should free up the + // send buffer. + rtc::Thread::Current()->PostDelayed(10, this, MSG_WRITE); + } else { + if (!remote_.isReceiveBufferFull()) { + LOG(LS_ERROR) << "This shouldn't happen - the send buffer is full, " + << "the receive buffer is not, and there are no " + << "remaining messages to process."; + } + send_stream_.GetPosition(&position); + send_position_.push_back(position); + + // Drain the receiver buffer. + ReadUntilIOPending(); + } + } + + private: + rtc::MemoryStream send_stream_; + rtc::MemoryStream recv_stream_; + + std::vector<size_t> send_position_; + std::vector<size_t> recv_position_; +}; + +// Basic end-to-end data transfer tests + +// Test the normal case of sending data from one side to the other. +TEST_F(PseudoTcpTest, TestSend) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + TestTransfer(1000000); +} + +// Test sending data with a 50 ms RTT. Transmission should take longer due +// to a slower ramp-up in send rate. +TEST_F(PseudoTcpTest, TestSendWithDelay) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetDelay(50); + TestTransfer(1000000); +} + +// Test sending data with packet loss. Transmission should take much longer due +// to send back-off when loss occurs. +TEST_F(PseudoTcpTest, TestSendWithLoss) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetLoss(10); + TestTransfer(100000); // less data so test runs faster +} + +// Test sending data with a 50 ms RTT and 10% packet loss. Transmission should +// take much longer due to send back-off and slower detection of loss. +TEST_F(PseudoTcpTest, TestSendWithDelayAndLoss) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetDelay(50); + SetLoss(10); + TestTransfer(100000); // less data so test runs faster +} + +// Test sending data with 10% packet loss and Nagling disabled. Transmission +// should take about the same time as with Nagling enabled. +TEST_F(PseudoTcpTest, TestSendWithLossAndOptNaglingOff) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetLoss(10); + SetOptNagling(false); + TestTransfer(100000); // less data so test runs faster +} + +// Test sending data with 10% packet loss and Delayed ACK disabled. +// Transmission should be slightly faster than with it enabled. +TEST_F(PseudoTcpTest, TestSendWithLossAndOptAckDelayOff) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetLoss(10); + SetOptAckDelay(0); + TestTransfer(100000); +} + +// Test sending data with 50ms delay and Nagling disabled. +TEST_F(PseudoTcpTest, TestSendWithDelayAndOptNaglingOff) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetDelay(50); + SetOptNagling(false); + TestTransfer(100000); // less data so test runs faster +} + +// Test sending data with 50ms delay and Delayed ACK disabled. +TEST_F(PseudoTcpTest, TestSendWithDelayAndOptAckDelayOff) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetDelay(50); + SetOptAckDelay(0); + TestTransfer(100000); // less data so test runs faster +} + +// Test a large receive buffer with a sender that doesn't support scaling. +TEST_F(PseudoTcpTest, TestSendRemoteNoWindowScale) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetLocalOptRcvBuf(100000); + DisableRemoteWindowScale(); + TestTransfer(1000000); +} + +// Test a large sender-side receive buffer with a receiver that doesn't support +// scaling. +TEST_F(PseudoTcpTest, TestSendLocalNoWindowScale) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetRemoteOptRcvBuf(100000); + DisableLocalWindowScale(); + TestTransfer(1000000); +} + +// Test when both sides use window scaling. +TEST_F(PseudoTcpTest, TestSendBothUseWindowScale) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetRemoteOptRcvBuf(100000); + SetLocalOptRcvBuf(100000); + TestTransfer(1000000); +} + +// Test using a large window scale value. +TEST_F(PseudoTcpTest, TestSendLargeInFlight) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetRemoteOptRcvBuf(100000); + SetLocalOptRcvBuf(100000); + SetOptSndBuf(150000); + TestTransfer(1000000); +} + +TEST_F(PseudoTcpTest, TestSendBothUseLargeWindowScale) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetRemoteOptRcvBuf(1000000); + SetLocalOptRcvBuf(1000000); + TestTransfer(10000000); +} + +// Test using a small receive buffer. +TEST_F(PseudoTcpTest, TestSendSmallReceiveBuffer) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetRemoteOptRcvBuf(10000); + SetLocalOptRcvBuf(10000); + TestTransfer(1000000); +} + +// Test using a very small receive buffer. +TEST_F(PseudoTcpTest, TestSendVerySmallReceiveBuffer) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetRemoteOptRcvBuf(100); + SetLocalOptRcvBuf(100); + TestTransfer(100000); +} + +// Ping-pong (request/response) tests + +// Test sending <= 1x MTU of data in each ping/pong. Should take <10ms. +TEST_F(PseudoTcpTestPingPong, TestPingPong1xMtu) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + TestPingPong(100, 100); +} + +// Test sending 2x-3x MTU of data in each ping/pong. Should take <10ms. +TEST_F(PseudoTcpTestPingPong, TestPingPong3xMtu) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + TestPingPong(400, 100); +} + +// Test sending 1x-2x MTU of data in each ping/pong. +// Should take ~1s, due to interaction between Nagling and Delayed ACK. +TEST_F(PseudoTcpTestPingPong, TestPingPong2xMtu) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + TestPingPong(2000, 5); +} + +// Test sending 1x-2x MTU of data in each ping/pong with Delayed ACK off. +// Should take <10ms. +TEST_F(PseudoTcpTestPingPong, TestPingPong2xMtuWithAckDelayOff) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetOptAckDelay(0); + TestPingPong(2000, 100); +} + +// Test sending 1x-2x MTU of data in each ping/pong with Nagling off. +// Should take <10ms. +TEST_F(PseudoTcpTestPingPong, TestPingPong2xMtuWithNaglingOff) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetOptNagling(false); + TestPingPong(2000, 5); +} + +// Test sending a ping as pair of short (non-full) segments. +// Should take ~1s, due to Delayed ACK interaction with Nagling. +TEST_F(PseudoTcpTestPingPong, TestPingPongShortSegments) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetOptAckDelay(5000); + SetBytesPerSend(50); // i.e. two Send calls per payload + TestPingPong(100, 5); +} + +// Test sending ping as a pair of short (non-full) segments, with Nagling off. +// Should take <10ms. +TEST_F(PseudoTcpTestPingPong, TestPingPongShortSegmentsWithNaglingOff) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetOptNagling(false); + SetBytesPerSend(50); // i.e. two Send calls per payload + TestPingPong(100, 5); +} + +// Test sending <= 1x MTU of data ping/pong, in two segments, no Delayed ACK. +// Should take ~1s. +TEST_F(PseudoTcpTestPingPong, TestPingPongShortSegmentsWithAckDelayOff) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetBytesPerSend(50); // i.e. two Send calls per payload + SetOptAckDelay(0); + TestPingPong(100, 5); +} + +// Test that receive window expands and contract correctly. +TEST_F(PseudoTcpTestReceiveWindow, TestReceiveWindow) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetOptNagling(false); + SetOptAckDelay(0); + TestTransfer(1024 * 1000); +} + +// Test setting send window size to a very small value. +TEST_F(PseudoTcpTestReceiveWindow, TestSetVerySmallSendWindowSize) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetOptNagling(false); + SetOptAckDelay(0); + SetOptSndBuf(900); + TestTransfer(1024 * 1000); + EXPECT_EQ(900u, EstimateSendWindowSize()); +} + +// Test setting receive window size to a value other than default. +TEST_F(PseudoTcpTestReceiveWindow, TestSetReceiveWindowSize) { + SetLocalMtu(1500); + SetRemoteMtu(1500); + SetOptNagling(false); + SetOptAckDelay(0); + SetRemoteOptRcvBuf(100000); + SetLocalOptRcvBuf(100000); + TestTransfer(1024 * 1000); + EXPECT_EQ(100000u, EstimateReceiveWindowSize()); +} + +/* Test sending data with mismatched MTUs. We should detect this and reduce +// our packet size accordingly. +// TODO: This doesn't actually work right now. The current code +// doesn't detect if the MTU is set too high on either side. +TEST_F(PseudoTcpTest, TestSendWithMismatchedMtus) { + SetLocalMtu(1500); + SetRemoteMtu(1280); + TestTransfer(1000000); +} +*/ diff --git a/webrtc/p2p/base/rawtransport.cc b/webrtc/p2p/base/rawtransport.cc new file mode 100644 index 0000000000..cb700ae4a0 --- /dev/null +++ b/webrtc/p2p/base/rawtransport.cc @@ -0,0 +1,2 @@ +// TODO(pthatcher): Remove this file once Chrome's build files no +// longer refer to it. diff --git a/webrtc/p2p/base/rawtransport.h b/webrtc/p2p/base/rawtransport.h new file mode 100644 index 0000000000..cb700ae4a0 --- /dev/null +++ b/webrtc/p2p/base/rawtransport.h @@ -0,0 +1,2 @@ +// TODO(pthatcher): Remove this file once Chrome's build files no +// longer refer to it. diff --git a/webrtc/p2p/base/rawtransportchannel.cc b/webrtc/p2p/base/rawtransportchannel.cc new file mode 100644 index 0000000000..cb700ae4a0 --- /dev/null +++ b/webrtc/p2p/base/rawtransportchannel.cc @@ -0,0 +1,2 @@ +// TODO(pthatcher): Remove this file once Chrome's build files no +// longer refer to it. diff --git a/webrtc/p2p/base/rawtransportchannel.h b/webrtc/p2p/base/rawtransportchannel.h new file mode 100644 index 0000000000..cb700ae4a0 --- /dev/null +++ b/webrtc/p2p/base/rawtransportchannel.h @@ -0,0 +1,2 @@ +// TODO(pthatcher): Remove this file once Chrome's build files no +// longer refer to it. diff --git a/webrtc/p2p/base/relayport.cc b/webrtc/p2p/base/relayport.cc new file mode 100644 index 0000000000..88adcf2f88 --- /dev/null +++ b/webrtc/p2p/base/relayport.cc @@ -0,0 +1,846 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include <algorithm> + +#include "webrtc/p2p/base/relayport.h" +#include "webrtc/base/asyncpacketsocket.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" + +namespace cricket { + +static const uint32_t kMessageConnectTimeout = 1; +static const int kKeepAliveDelay = 10 * 60 * 1000; +static const int kRetryTimeout = 50 * 1000; // ICE says 50 secs +// How long to wait for a socket to connect to remote host in milliseconds +// before trying another connection. +static const int kSoftConnectTimeoutMs = 3 * 1000; + +// Handles a connection to one address/port/protocol combination for a +// particular RelayEntry. +class RelayConnection : public sigslot::has_slots<> { + public: + RelayConnection(const ProtocolAddress* protocol_address, + rtc::AsyncPacketSocket* socket, + rtc::Thread* thread); + ~RelayConnection(); + rtc::AsyncPacketSocket* socket() const { return socket_; } + + const ProtocolAddress* protocol_address() { + return protocol_address_; + } + + rtc::SocketAddress GetAddress() const { + return protocol_address_->address; + } + + ProtocolType GetProtocol() const { + return protocol_address_->proto; + } + + int SetSocketOption(rtc::Socket::Option opt, int value); + + // Validates a response to a STUN allocate request. + bool CheckResponse(StunMessage* msg); + + // Sends data to the relay server. + int Send(const void* pv, size_t cb, const rtc::PacketOptions& options); + + // Sends a STUN allocate request message to the relay server. + void SendAllocateRequest(RelayEntry* entry, int delay); + + // Return the latest error generated by the socket. + int GetError() { return socket_->GetError(); } + + // Called on behalf of a StunRequest to write data to the socket. This is + // already STUN intended for the server, so no wrapping is necessary. + void OnSendPacket(const void* data, size_t size, StunRequest* req); + + private: + rtc::AsyncPacketSocket* socket_; + const ProtocolAddress* protocol_address_; + StunRequestManager *request_manager_; +}; + +// Manages a number of connections to the relayserver, one for each +// available protocol. We aim to use each connection for only a +// specific destination address so that we can avoid wrapping every +// packet in a STUN send / data indication. +class RelayEntry : public rtc::MessageHandler, + public sigslot::has_slots<> { + public: + RelayEntry(RelayPort* port, const rtc::SocketAddress& ext_addr); + ~RelayEntry(); + + RelayPort* port() { return port_; } + + const rtc::SocketAddress& address() const { return ext_addr_; } + void set_address(const rtc::SocketAddress& addr) { ext_addr_ = addr; } + + bool connected() const { return connected_; } + bool locked() const { return locked_; } + + // Returns the last error on the socket of this entry. + int GetError(); + + // Returns the most preferred connection of the given + // ones. Connections are rated based on protocol in the order of: + // UDP, TCP and SSLTCP, where UDP is the most preferred protocol + static RelayConnection* GetBestConnection(RelayConnection* conn1, + RelayConnection* conn2); + + // Sends the STUN requests to the server to initiate this connection. + void Connect(); + + // Called when this entry becomes connected. The address given is the one + // exposed to the outside world on the relay server. + void OnConnect(const rtc::SocketAddress& mapped_addr, + RelayConnection* socket); + + // Sends a packet to the given destination address using the socket of this + // entry. This will wrap the packet in STUN if necessary. + int SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options); + + // Schedules a keep-alive allocate request. + void ScheduleKeepAlive(); + + void SetServerIndex(size_t sindex) { server_index_ = sindex; } + + // Sets this option on the socket of each connection. + int SetSocketOption(rtc::Socket::Option opt, int value); + + size_t ServerIndex() const { return server_index_; } + + // Try a different server address + void HandleConnectFailure(rtc::AsyncPacketSocket* socket); + + // Implementation of the MessageHandler Interface. + virtual void OnMessage(rtc::Message *pmsg); + + private: + RelayPort* port_; + rtc::SocketAddress ext_addr_; + size_t server_index_; + bool connected_; + bool locked_; + RelayConnection* current_connection_; + + // Called when a TCP connection is established or fails + void OnSocketConnect(rtc::AsyncPacketSocket* socket); + void OnSocketClose(rtc::AsyncPacketSocket* socket, int error); + + // Called when a packet is received on this socket. + void OnReadPacket( + rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time); + + void OnSentPacket(rtc::AsyncPacketSocket* socket, + const rtc::SentPacket& sent_packet); + + // Called when the socket is currently able to send. + void OnReadyToSend(rtc::AsyncPacketSocket* socket); + + // Sends the given data on the socket to the server with no wrapping. This + // returns the number of bytes written or -1 if an error occurred. + int SendPacket(const void* data, size_t size, + const rtc::PacketOptions& options); +}; + +// Handles an allocate request for a particular RelayEntry. +class AllocateRequest : public StunRequest { + public: + AllocateRequest(RelayEntry* entry, RelayConnection* connection); + virtual ~AllocateRequest() {} + + void Prepare(StunMessage* request) override; + + void OnSent() override; + int resend_delay() override; + + void OnResponse(StunMessage* response) override; + void OnErrorResponse(StunMessage* response) override; + void OnTimeout() override; + + private: + RelayEntry* entry_; + RelayConnection* connection_; + uint32_t start_time_; +}; + +RelayPort::RelayPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password) + : Port(thread, + RELAY_PORT_TYPE, + factory, + network, + ip, + min_port, + max_port, + username, + password), + ready_(false), + error_(0) { + entries_.push_back( + new RelayEntry(this, rtc::SocketAddress())); + // TODO: set local preference value for TCP based candidates. +} + +RelayPort::~RelayPort() { + for (size_t i = 0; i < entries_.size(); ++i) + delete entries_[i]; + thread()->Clear(this); +} + +void RelayPort::AddServerAddress(const ProtocolAddress& addr) { + // Since HTTP proxies usually only allow 443, + // let's up the priority on PROTO_SSLTCP + if (addr.proto == PROTO_SSLTCP && + (proxy().type == rtc::PROXY_HTTPS || + proxy().type == rtc::PROXY_UNKNOWN)) { + server_addr_.push_front(addr); + } else { + server_addr_.push_back(addr); + } +} + +void RelayPort::AddExternalAddress(const ProtocolAddress& addr) { + std::string proto_name = ProtoToString(addr.proto); + for (std::vector<ProtocolAddress>::iterator it = external_addr_.begin(); + it != external_addr_.end(); ++it) { + if ((it->address == addr.address) && (it->proto == addr.proto)) { + LOG(INFO) << "Redundant relay address: " << proto_name + << " @ " << addr.address.ToSensitiveString(); + return; + } + } + external_addr_.push_back(addr); +} + +void RelayPort::SetReady() { + if (!ready_) { + std::vector<ProtocolAddress>::iterator iter; + for (iter = external_addr_.begin(); + iter != external_addr_.end(); ++iter) { + std::string proto_name = ProtoToString(iter->proto); + // In case of Gturn, related address is set to null socket address. + // This is due to as mapped address stun attribute is used for allocated + // address. + AddAddress(iter->address, iter->address, rtc::SocketAddress(), proto_name, + proto_name, "", RELAY_PORT_TYPE, ICE_TYPE_PREFERENCE_RELAY, 0, + false); + } + ready_ = true; + SignalPortComplete(this); + } +} + +const ProtocolAddress * RelayPort::ServerAddress(size_t index) const { + if (index < server_addr_.size()) + return &server_addr_[index]; + return NULL; +} + +bool RelayPort::HasMagicCookie(const char* data, size_t size) { + if (size < 24 + sizeof(TURN_MAGIC_COOKIE_VALUE)) { + return false; + } else { + return memcmp(data + 24, + TURN_MAGIC_COOKIE_VALUE, + sizeof(TURN_MAGIC_COOKIE_VALUE)) == 0; + } +} + +void RelayPort::PrepareAddress() { + // We initiate a connect on the first entry. If this completes, it will fill + // in the server address as the address of this port. + ASSERT(entries_.size() == 1); + entries_[0]->Connect(); + ready_ = false; +} + +Connection* RelayPort::CreateConnection(const Candidate& address, + CandidateOrigin origin) { + // We only create conns to non-udp sockets if they are incoming on this port + if ((address.protocol() != UDP_PROTOCOL_NAME) && + (origin != ORIGIN_THIS_PORT)) { + return 0; + } + + // We don't support loopback on relays + if (address.type() == Type()) { + return 0; + } + + if (!IsCompatibleAddress(address.address())) { + return 0; + } + + size_t index = 0; + for (size_t i = 0; i < Candidates().size(); ++i) { + const Candidate& local = Candidates()[i]; + if (local.protocol() == address.protocol()) { + index = i; + break; + } + } + + Connection * conn = new ProxyConnection(this, index, address); + AddConnection(conn); + return conn; +} + +int RelayPort::SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options, + bool payload) { + // Try to find an entry for this specific address. Note that the first entry + // created was not given an address initially, so it can be set to the first + // address that comes along. + RelayEntry* entry = 0; + + for (size_t i = 0; i < entries_.size(); ++i) { + if (entries_[i]->address().IsNil() && payload) { + entry = entries_[i]; + entry->set_address(addr); + break; + } else if (entries_[i]->address() == addr) { + entry = entries_[i]; + break; + } + } + + // If we did not find one, then we make a new one. This will not be useable + // until it becomes connected, however. + if (!entry && payload) { + entry = new RelayEntry(this, addr); + if (!entries_.empty()) { + entry->SetServerIndex(entries_[0]->ServerIndex()); + } + entry->Connect(); + entries_.push_back(entry); + } + + // If the entry is connected, then we can send on it (though wrapping may + // still be necessary). Otherwise, we can't yet use this connection, so we + // default to the first one. + if (!entry || !entry->connected()) { + ASSERT(!entries_.empty()); + entry = entries_[0]; + if (!entry->connected()) { + error_ = EWOULDBLOCK; + return SOCKET_ERROR; + } + } + + // Send the actual contents to the server using the usual mechanism. + int sent = entry->SendTo(data, size, addr, options); + if (sent <= 0) { + ASSERT(sent < 0); + error_ = entry->GetError(); + return SOCKET_ERROR; + } + // The caller of the function is expecting the number of user data bytes, + // rather than the size of the packet. + return static_cast<int>(size); +} + +int RelayPort::SetOption(rtc::Socket::Option opt, int value) { + int result = 0; + for (size_t i = 0; i < entries_.size(); ++i) { + if (entries_[i]->SetSocketOption(opt, value) < 0) { + result = -1; + error_ = entries_[i]->GetError(); + } + } + options_.push_back(OptionValue(opt, value)); + return result; +} + +int RelayPort::GetOption(rtc::Socket::Option opt, int* value) { + std::vector<OptionValue>::iterator it; + for (it = options_.begin(); it < options_.end(); ++it) { + if (it->first == opt) { + *value = it->second; + return 0; + } + } + return SOCKET_ERROR; +} + +int RelayPort::GetError() { + return error_; +} + +void RelayPort::OnReadPacket( + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + ProtocolType proto, + const rtc::PacketTime& packet_time) { + if (Connection* conn = GetConnection(remote_addr)) { + conn->OnReadPacket(data, size, packet_time); + } else { + Port::OnReadPacket(data, size, remote_addr, proto); + } +} + +RelayConnection::RelayConnection(const ProtocolAddress* protocol_address, + rtc::AsyncPacketSocket* socket, + rtc::Thread* thread) + : socket_(socket), + protocol_address_(protocol_address) { + request_manager_ = new StunRequestManager(thread); + request_manager_->SignalSendPacket.connect(this, + &RelayConnection::OnSendPacket); +} + +RelayConnection::~RelayConnection() { + delete request_manager_; + delete socket_; +} + +int RelayConnection::SetSocketOption(rtc::Socket::Option opt, + int value) { + if (socket_) { + return socket_->SetOption(opt, value); + } + return 0; +} + +bool RelayConnection::CheckResponse(StunMessage* msg) { + return request_manager_->CheckResponse(msg); +} + +void RelayConnection::OnSendPacket(const void* data, size_t size, + StunRequest* req) { + // TODO(mallinath) Find a way to get DSCP value from Port. + rtc::PacketOptions options; // Default dscp set to NO_CHANGE. + int sent = socket_->SendTo(data, size, GetAddress(), options); + if (sent <= 0) { + LOG(LS_VERBOSE) << "OnSendPacket: failed sending to " << GetAddress() << + strerror(socket_->GetError()); + ASSERT(sent < 0); + } +} + +int RelayConnection::Send(const void* pv, size_t cb, + const rtc::PacketOptions& options) { + return socket_->SendTo(pv, cb, GetAddress(), options); +} + +void RelayConnection::SendAllocateRequest(RelayEntry* entry, int delay) { + request_manager_->SendDelayed(new AllocateRequest(entry, this), delay); +} + +RelayEntry::RelayEntry(RelayPort* port, + const rtc::SocketAddress& ext_addr) + : port_(port), ext_addr_(ext_addr), + server_index_(0), connected_(false), locked_(false), + current_connection_(NULL) { +} + +RelayEntry::~RelayEntry() { + // Remove all RelayConnections and dispose sockets. + delete current_connection_; + current_connection_ = NULL; +} + +void RelayEntry::Connect() { + // If we're already connected, return. + if (connected_) + return; + + // If we've exhausted all options, bail out. + const ProtocolAddress* ra = port()->ServerAddress(server_index_); + if (!ra) { + LOG(LS_WARNING) << "No more relay addresses left to try"; + return; + } + + // Remove any previous connection. + if (current_connection_) { + port()->thread()->Dispose(current_connection_); + current_connection_ = NULL; + } + + // Try to set up our new socket. + LOG(LS_INFO) << "Connecting to relay via " << ProtoToString(ra->proto) << + " @ " << ra->address.ToSensitiveString(); + + rtc::AsyncPacketSocket* socket = NULL; + + if (ra->proto == PROTO_UDP) { + // UDP sockets are simple. + socket = port_->socket_factory()->CreateUdpSocket( + rtc::SocketAddress(port_->ip(), 0), + port_->min_port(), port_->max_port()); + } else if (ra->proto == PROTO_TCP || ra->proto == PROTO_SSLTCP) { + int opts = (ra->proto == PROTO_SSLTCP) ? + rtc::PacketSocketFactory::OPT_SSLTCP : 0; + socket = port_->socket_factory()->CreateClientTcpSocket( + rtc::SocketAddress(port_->ip(), 0), ra->address, + port_->proxy(), port_->user_agent(), opts); + } else { + LOG(LS_WARNING) << "Unknown protocol (" << ra->proto << ")"; + } + + if (!socket) { + LOG(LS_WARNING) << "Socket creation failed"; + } + + // If we failed to get a socket, move on to the next protocol. + if (!socket) { + port()->thread()->Post(this, kMessageConnectTimeout); + return; + } + + // Otherwise, create the new connection and configure any socket options. + socket->SignalReadPacket.connect(this, &RelayEntry::OnReadPacket); + socket->SignalSentPacket.connect(this, &RelayEntry::OnSentPacket); + socket->SignalReadyToSend.connect(this, &RelayEntry::OnReadyToSend); + current_connection_ = new RelayConnection(ra, socket, port()->thread()); + for (size_t i = 0; i < port_->options().size(); ++i) { + current_connection_->SetSocketOption(port_->options()[i].first, + port_->options()[i].second); + } + + // If we're trying UDP, start binding requests. + // If we're trying TCP, wait for connection with a fixed timeout. + if ((ra->proto == PROTO_TCP) || (ra->proto == PROTO_SSLTCP)) { + socket->SignalClose.connect(this, &RelayEntry::OnSocketClose); + socket->SignalConnect.connect(this, &RelayEntry::OnSocketConnect); + port()->thread()->PostDelayed(kSoftConnectTimeoutMs, this, + kMessageConnectTimeout); + } else { + current_connection_->SendAllocateRequest(this, 0); + } +} + +int RelayEntry::GetError() { + if (current_connection_ != NULL) { + return current_connection_->GetError(); + } + return 0; +} + +RelayConnection* RelayEntry::GetBestConnection(RelayConnection* conn1, + RelayConnection* conn2) { + return conn1->GetProtocol() <= conn2->GetProtocol() ? conn1 : conn2; +} + +void RelayEntry::OnConnect(const rtc::SocketAddress& mapped_addr, + RelayConnection* connection) { + // We are connected, notify our parent. + ProtocolType proto = PROTO_UDP; + LOG(INFO) << "Relay allocate succeeded: " << ProtoToString(proto) + << " @ " << mapped_addr.ToSensitiveString(); + connected_ = true; + + port_->AddExternalAddress(ProtocolAddress(mapped_addr, proto)); + port_->SetReady(); +} + +int RelayEntry::SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options) { + // If this connection is locked to the address given, then we can send the + // packet with no wrapper. + if (locked_ && (ext_addr_ == addr)) + return SendPacket(data, size, options); + + // Otherwise, we must wrap the given data in a STUN SEND request so that we + // can communicate the destination address to the server. + // + // Note that we do not use a StunRequest here. This is because there is + // likely no reason to resend this packet. If it is late, we just drop it. + // The next send to this address will try again. + + RelayMessage request; + request.SetType(STUN_SEND_REQUEST); + + StunByteStringAttribute* magic_cookie_attr = + StunAttribute::CreateByteString(STUN_ATTR_MAGIC_COOKIE); + magic_cookie_attr->CopyBytes(TURN_MAGIC_COOKIE_VALUE, + sizeof(TURN_MAGIC_COOKIE_VALUE)); + VERIFY(request.AddAttribute(magic_cookie_attr)); + + StunByteStringAttribute* username_attr = + StunAttribute::CreateByteString(STUN_ATTR_USERNAME); + username_attr->CopyBytes(port_->username_fragment().c_str(), + port_->username_fragment().size()); + VERIFY(request.AddAttribute(username_attr)); + + StunAddressAttribute* addr_attr = + StunAttribute::CreateAddress(STUN_ATTR_DESTINATION_ADDRESS); + addr_attr->SetIP(addr.ipaddr()); + addr_attr->SetPort(addr.port()); + VERIFY(request.AddAttribute(addr_attr)); + + // Attempt to lock + if (ext_addr_ == addr) { + StunUInt32Attribute* options_attr = + StunAttribute::CreateUInt32(STUN_ATTR_OPTIONS); + options_attr->SetValue(0x1); + VERIFY(request.AddAttribute(options_attr)); + } + + StunByteStringAttribute* data_attr = + StunAttribute::CreateByteString(STUN_ATTR_DATA); + data_attr->CopyBytes(data, size); + VERIFY(request.AddAttribute(data_attr)); + + // TODO: compute the HMAC. + + rtc::ByteBuffer buf; + request.Write(&buf); + + return SendPacket(buf.Data(), buf.Length(), options); +} + +void RelayEntry::ScheduleKeepAlive() { + if (current_connection_) { + current_connection_->SendAllocateRequest(this, kKeepAliveDelay); + } +} + +int RelayEntry::SetSocketOption(rtc::Socket::Option opt, int value) { + // Set the option on all available sockets. + int socket_error = 0; + if (current_connection_) { + socket_error = current_connection_->SetSocketOption(opt, value); + } + return socket_error; +} + +void RelayEntry::HandleConnectFailure( + rtc::AsyncPacketSocket* socket) { + // Make sure it's the current connection that has failed, it might + // be an old socked that has not yet been disposed. + if (!socket || + (current_connection_ && socket == current_connection_->socket())) { + if (current_connection_) + port()->SignalConnectFailure(current_connection_->protocol_address()); + + // Try to connect to the next server address. + server_index_ += 1; + Connect(); + } +} + +void RelayEntry::OnMessage(rtc::Message *pmsg) { + ASSERT(pmsg->message_id == kMessageConnectTimeout); + if (current_connection_) { + const ProtocolAddress* ra = current_connection_->protocol_address(); + LOG(LS_WARNING) << "Relay " << ra->proto << " connection to " << + ra->address << " timed out"; + + // Currently we connect to each server address in sequence. If we + // have more addresses to try, treat this is an error and move on to + // the next address, otherwise give this connection more time and + // await the real timeout. + // + // TODO: Connect to servers in parallel to speed up connect time + // and to avoid giving up too early. + port_->SignalSoftTimeout(ra); + HandleConnectFailure(current_connection_->socket()); + } else { + HandleConnectFailure(NULL); + } +} + +void RelayEntry::OnSocketConnect(rtc::AsyncPacketSocket* socket) { + LOG(INFO) << "relay tcp connected to " << + socket->GetRemoteAddress().ToSensitiveString(); + if (current_connection_ != NULL) { + current_connection_->SendAllocateRequest(this, 0); + } +} + +void RelayEntry::OnSocketClose(rtc::AsyncPacketSocket* socket, + int error) { + PLOG(LERROR, error) << "Relay connection failed: socket closed"; + HandleConnectFailure(socket); +} + +void RelayEntry::OnReadPacket( + rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + // ASSERT(remote_addr == port_->server_addr()); + // TODO: are we worried about this? + + if (current_connection_ == NULL || socket != current_connection_->socket()) { + // This packet comes from an unknown address. + LOG(WARNING) << "Dropping packet: unknown address"; + return; + } + + // If the magic cookie is not present, then this is an unwrapped packet sent + // by the server, The actual remote address is the one we recorded. + if (!port_->HasMagicCookie(data, size)) { + if (locked_) { + port_->OnReadPacket(data, size, ext_addr_, PROTO_UDP, packet_time); + } else { + LOG(WARNING) << "Dropping packet: entry not locked"; + } + return; + } + + rtc::ByteBuffer buf(data, size); + RelayMessage msg; + if (!msg.Read(&buf)) { + LOG(INFO) << "Incoming packet was not STUN"; + return; + } + + // The incoming packet should be a STUN ALLOCATE response, SEND response, or + // DATA indication. + if (current_connection_->CheckResponse(&msg)) { + return; + } else if (msg.type() == STUN_SEND_RESPONSE) { + if (const StunUInt32Attribute* options_attr = + msg.GetUInt32(STUN_ATTR_OPTIONS)) { + if (options_attr->value() & 0x1) { + locked_ = true; + } + } + return; + } else if (msg.type() != STUN_DATA_INDICATION) { + LOG(INFO) << "Received BAD stun type from server: " << msg.type(); + return; + } + + // This must be a data indication. + + const StunAddressAttribute* addr_attr = + msg.GetAddress(STUN_ATTR_SOURCE_ADDRESS2); + if (!addr_attr) { + LOG(INFO) << "Data indication has no source address"; + return; + } else if (addr_attr->family() != 1) { + LOG(INFO) << "Source address has bad family"; + return; + } + + rtc::SocketAddress remote_addr2(addr_attr->ipaddr(), addr_attr->port()); + + const StunByteStringAttribute* data_attr = msg.GetByteString(STUN_ATTR_DATA); + if (!data_attr) { + LOG(INFO) << "Data indication has no data"; + return; + } + + // Process the actual data and remote address in the normal manner. + port_->OnReadPacket(data_attr->bytes(), data_attr->length(), remote_addr2, + PROTO_UDP, packet_time); +} + +void RelayEntry::OnSentPacket(rtc::AsyncPacketSocket* socket, + const rtc::SentPacket& sent_packet) { + port_->OnSentPacket(sent_packet); +} + +void RelayEntry::OnReadyToSend(rtc::AsyncPacketSocket* socket) { + if (connected()) { + port_->OnReadyToSend(); + } +} + +int RelayEntry::SendPacket(const void* data, size_t size, + const rtc::PacketOptions& options) { + int sent = 0; + if (current_connection_) { + // We are connected, no need to send packets anywere else than to + // the current connection. + sent = current_connection_->Send(data, size, options); + } + return sent; +} + +AllocateRequest::AllocateRequest(RelayEntry* entry, + RelayConnection* connection) + : StunRequest(new RelayMessage()), + entry_(entry), + connection_(connection) { + start_time_ = rtc::Time(); +} + +void AllocateRequest::Prepare(StunMessage* request) { + request->SetType(STUN_ALLOCATE_REQUEST); + + StunByteStringAttribute* username_attr = + StunAttribute::CreateByteString(STUN_ATTR_USERNAME); + username_attr->CopyBytes( + entry_->port()->username_fragment().c_str(), + entry_->port()->username_fragment().size()); + VERIFY(request->AddAttribute(username_attr)); +} + +void AllocateRequest::OnSent() { + count_ += 1; + if (count_ == 5) + timeout_ = true; +} + +int AllocateRequest::resend_delay() { + if (count_ == 0) { + return 0; + } + return 100 * std::max(1 << (count_-1), 2); +} + + +void AllocateRequest::OnResponse(StunMessage* response) { + const StunAddressAttribute* addr_attr = + response->GetAddress(STUN_ATTR_MAPPED_ADDRESS); + if (!addr_attr) { + LOG(INFO) << "Allocate response missing mapped address."; + } else if (addr_attr->family() != 1) { + LOG(INFO) << "Mapped address has bad family"; + } else { + rtc::SocketAddress addr(addr_attr->ipaddr(), addr_attr->port()); + entry_->OnConnect(addr, connection_); + } + + // We will do a keep-alive regardless of whether this request suceeds. + // This should have almost no impact on network usage. + entry_->ScheduleKeepAlive(); +} + +void AllocateRequest::OnErrorResponse(StunMessage* response) { + const StunErrorCodeAttribute* attr = response->GetErrorCode(); + if (!attr) { + LOG(INFO) << "Bad allocate response error code"; + } else { + LOG(INFO) << "Allocate error response:" + << " code=" << attr->code() + << " reason='" << attr->reason() << "'"; + } + + if (rtc::TimeSince(start_time_) <= kRetryTimeout) + entry_->ScheduleKeepAlive(); +} + +void AllocateRequest::OnTimeout() { + LOG(INFO) << "Allocate request timed out"; + entry_->HandleConnectFailure(connection_->socket()); +} + +} // namespace cricket diff --git a/webrtc/p2p/base/relayport.h b/webrtc/p2p/base/relayport.h new file mode 100644 index 0000000000..8452b5b430 --- /dev/null +++ b/webrtc/p2p/base/relayport.h @@ -0,0 +1,108 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_RELAYPORT_H_ +#define WEBRTC_P2P_BASE_RELAYPORT_H_ + +#include <deque> +#include <string> +#include <utility> +#include <vector> + +#include "webrtc/p2p/base/port.h" +#include "webrtc/p2p/base/stunrequest.h" + +namespace cricket { + +class RelayEntry; +class RelayConnection; + +// Communicates using an allocated port on the relay server. For each +// remote candidate that we try to send data to a RelayEntry instance +// is created. The RelayEntry will try to reach the remote destination +// by connecting to all available server addresses in a pre defined +// order with a small delay in between. When a connection is +// successful all other connection attemts are aborted. +class RelayPort : public Port { + public: + typedef std::pair<rtc::Socket::Option, int> OptionValue; + + // RelayPort doesn't yet do anything fancy in the ctor. + static RelayPort* Create(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password) { + return new RelayPort(thread, factory, network, ip, min_port, max_port, + username, password); + } + virtual ~RelayPort(); + + void AddServerAddress(const ProtocolAddress& addr); + void AddExternalAddress(const ProtocolAddress& addr); + + const std::vector<OptionValue>& options() const { return options_; } + bool HasMagicCookie(const char* data, size_t size); + + virtual void PrepareAddress(); + virtual Connection* CreateConnection(const Candidate& address, + CandidateOrigin origin); + virtual int SetOption(rtc::Socket::Option opt, int value); + virtual int GetOption(rtc::Socket::Option opt, int* value); + virtual int GetError(); + + const ProtocolAddress * ServerAddress(size_t index) const; + bool IsReady() { return ready_; } + + // Used for testing. + sigslot::signal1<const ProtocolAddress*> SignalConnectFailure; + sigslot::signal1<const ProtocolAddress*> SignalSoftTimeout; + + protected: + RelayPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network*, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password); + bool Init(); + + void SetReady(); + + virtual int SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options, + bool payload); + + // Dispatches the given packet to the port or connection as appropriate. + void OnReadPacket(const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + ProtocolType proto, + const rtc::PacketTime& packet_time); + + private: + friend class RelayEntry; + + std::deque<ProtocolAddress> server_addr_; + std::vector<ProtocolAddress> external_addr_; + bool ready_; + std::vector<RelayEntry*> entries_; + std::vector<OptionValue> options_; + int error_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_RELAYPORT_H_ diff --git a/webrtc/p2p/base/relayport_unittest.cc b/webrtc/p2p/base/relayport_unittest.cc new file mode 100644 index 0000000000..d644d67c4f --- /dev/null +++ b/webrtc/p2p/base/relayport_unittest.cc @@ -0,0 +1,272 @@ +/* + * Copyright 2009 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/basicpacketsocketfactory.h" +#include "webrtc/p2p/base/relayport.h" +#include "webrtc/p2p/base/relayserver.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/physicalsocketserver.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/socketadapters.h" +#include "webrtc/base/socketaddress.h" +#include "webrtc/base/ssladapter.h" +#include "webrtc/base/thread.h" +#include "webrtc/base/virtualsocketserver.h" + +using rtc::SocketAddress; + +static const SocketAddress kLocalAddress = SocketAddress("192.168.1.2", 0); +static const SocketAddress kRelayUdpAddr = SocketAddress("99.99.99.1", 5000); +static const SocketAddress kRelayTcpAddr = SocketAddress("99.99.99.2", 5001); +static const SocketAddress kRelaySslAddr = SocketAddress("99.99.99.3", 443); +static const SocketAddress kRelayExtAddr = SocketAddress("99.99.99.3", 5002); + +static const int kTimeoutMs = 1000; +static const int kMaxTimeoutMs = 5000; + +// Tests connecting a RelayPort to a fake relay server +// (cricket::RelayServer) using all currently available protocols. The +// network layer is faked out by using a VirtualSocketServer for +// creating sockets. The test will monitor the current state of the +// RelayPort and created sockets by listening for signals such as, +// SignalConnectFailure, SignalConnectTimeout, SignalSocketClosed and +// SignalReadPacket. +class RelayPortTest : public testing::Test, + public sigslot::has_slots<> { + public: + RelayPortTest() + : main_(rtc::Thread::Current()), + physical_socket_server_(new rtc::PhysicalSocketServer), + virtual_socket_server_(new rtc::VirtualSocketServer( + physical_socket_server_.get())), + ss_scope_(virtual_socket_server_.get()), + network_("unittest", "unittest", rtc::IPAddress(INADDR_ANY), 32), + socket_factory_(rtc::Thread::Current()), + username_(rtc::CreateRandomString(16)), + password_(rtc::CreateRandomString(16)), + relay_port_(cricket::RelayPort::Create(main_, &socket_factory_, + &network_, + kLocalAddress.ipaddr(), + 0, 0, username_, password_)), + relay_server_(new cricket::RelayServer(main_)) { + } + + void OnReadPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + received_packet_count_[socket]++; + } + + void OnConnectFailure(const cricket::ProtocolAddress* addr) { + failed_connections_.push_back(*addr); + } + + void OnSoftTimeout(const cricket::ProtocolAddress* addr) { + soft_timedout_connections_.push_back(*addr); + } + + protected: + virtual void SetUp() { + // The relay server needs an external socket to work properly. + rtc::AsyncUDPSocket* ext_socket = + CreateAsyncUdpSocket(kRelayExtAddr); + relay_server_->AddExternalSocket(ext_socket); + + // Listen for failures. + relay_port_->SignalConnectFailure. + connect(this, &RelayPortTest::OnConnectFailure); + + // Listen for soft timeouts. + relay_port_->SignalSoftTimeout. + connect(this, &RelayPortTest::OnSoftTimeout); + } + + // Udp has the highest 'goodness' value of the three different + // protocols used for connecting to the relay server. As soon as + // PrepareAddress is called, the RelayPort will start trying to + // connect to the given UDP address. As soon as a response to the + // sent STUN allocate request message has been received, the + // RelayPort will consider the connection to be complete and will + // abort any other connection attempts. + void TestConnectUdp() { + // Add a UDP socket to the relay server. + rtc::AsyncUDPSocket* internal_udp_socket = + CreateAsyncUdpSocket(kRelayUdpAddr); + rtc::AsyncSocket* server_socket = CreateServerSocket(kRelayTcpAddr); + + relay_server_->AddInternalSocket(internal_udp_socket); + relay_server_->AddInternalServerSocket(server_socket, cricket::PROTO_TCP); + + // Now add our relay addresses to the relay port and let it start. + relay_port_->AddServerAddress( + cricket::ProtocolAddress(kRelayUdpAddr, cricket::PROTO_UDP)); + relay_port_->AddServerAddress( + cricket::ProtocolAddress(kRelayTcpAddr, cricket::PROTO_TCP)); + relay_port_->PrepareAddress(); + + // Should be connected. + EXPECT_TRUE_WAIT(relay_port_->IsReady(), kTimeoutMs); + + // Make sure that we are happy with UDP, ie. not continuing with + // TCP, SSLTCP, etc. + WAIT(relay_server_->HasConnection(kRelayTcpAddr), kTimeoutMs); + + // Should have only one connection. + EXPECT_EQ(1, relay_server_->GetConnectionCount()); + + // Should be the UDP address. + EXPECT_TRUE(relay_server_->HasConnection(kRelayUdpAddr)); + } + + // TCP has the second best 'goodness' value, and as soon as UDP + // connection has failed, the RelayPort will attempt to connect via + // TCP. Here we add a fake UDP address together with a real TCP + // address to simulate an UDP failure. As soon as UDP has failed the + // RelayPort will try the TCP adress and succed. + void TestConnectTcp() { + // Create a fake UDP address for relay port to simulate a failure. + cricket::ProtocolAddress fake_protocol_address = + cricket::ProtocolAddress(kRelayUdpAddr, cricket::PROTO_UDP); + + // Create a server socket for the RelayServer. + rtc::AsyncSocket* server_socket = CreateServerSocket(kRelayTcpAddr); + relay_server_->AddInternalServerSocket(server_socket, cricket::PROTO_TCP); + + // Add server addresses to the relay port and let it start. + relay_port_->AddServerAddress( + cricket::ProtocolAddress(fake_protocol_address)); + relay_port_->AddServerAddress( + cricket::ProtocolAddress(kRelayTcpAddr, cricket::PROTO_TCP)); + relay_port_->PrepareAddress(); + + EXPECT_FALSE(relay_port_->IsReady()); + + // Should have timed out in 200 + 200 + 400 + 800 + 1600 ms. + EXPECT_TRUE_WAIT(HasFailed(&fake_protocol_address), 3600); + + // Wait until relayport is ready. + EXPECT_TRUE_WAIT(relay_port_->IsReady(), kMaxTimeoutMs); + + // Should have only one connection. + EXPECT_EQ(1, relay_server_->GetConnectionCount()); + + // Should be the TCP address. + EXPECT_TRUE(relay_server_->HasConnection(kRelayTcpAddr)); + } + + void TestConnectSslTcp() { + // Create a fake TCP address for relay port to simulate a failure. + // We skip UDP here since transition from UDP to TCP has been + // tested above. + cricket::ProtocolAddress fake_protocol_address = + cricket::ProtocolAddress(kRelayTcpAddr, cricket::PROTO_TCP); + + // Create a ssl server socket for the RelayServer. + rtc::AsyncSocket* ssl_server_socket = + CreateServerSocket(kRelaySslAddr); + relay_server_->AddInternalServerSocket(ssl_server_socket, + cricket::PROTO_SSLTCP); + + // Create a tcp server socket that listens on the fake address so + // the relay port can attempt to connect to it. + rtc::scoped_ptr<rtc::AsyncSocket> tcp_server_socket( + CreateServerSocket(kRelayTcpAddr)); + + // Add server addresses to the relay port and let it start. + relay_port_->AddServerAddress(fake_protocol_address); + relay_port_->AddServerAddress( + cricket::ProtocolAddress(kRelaySslAddr, cricket::PROTO_SSLTCP)); + relay_port_->PrepareAddress(); + EXPECT_FALSE(relay_port_->IsReady()); + + // Should have timed out in 3000 ms(relayport.cc, kSoftConnectTimeoutMs). + EXPECT_TRUE_WAIT_MARGIN(HasTimedOut(&fake_protocol_address), 3000, 100); + + // Wait until relayport is ready. + EXPECT_TRUE_WAIT(relay_port_->IsReady(), kMaxTimeoutMs); + + // Should have only one connection. + EXPECT_EQ(1, relay_server_->GetConnectionCount()); + + // Should be the SSLTCP address. + EXPECT_TRUE(relay_server_->HasConnection(kRelaySslAddr)); + } + + private: + rtc::AsyncUDPSocket* CreateAsyncUdpSocket(const SocketAddress addr) { + rtc::AsyncSocket* socket = + virtual_socket_server_->CreateAsyncSocket(SOCK_DGRAM); + rtc::AsyncUDPSocket* packet_socket = + rtc::AsyncUDPSocket::Create(socket, addr); + EXPECT_TRUE(packet_socket != NULL); + packet_socket->SignalReadPacket.connect(this, &RelayPortTest::OnReadPacket); + return packet_socket; + } + + rtc::AsyncSocket* CreateServerSocket(const SocketAddress addr) { + rtc::AsyncSocket* socket = + virtual_socket_server_->CreateAsyncSocket(SOCK_STREAM); + EXPECT_GE(socket->Bind(addr), 0); + EXPECT_GE(socket->Listen(5), 0); + return socket; + } + + bool HasFailed(cricket::ProtocolAddress* addr) { + for (size_t i = 0; i < failed_connections_.size(); i++) { + if (failed_connections_[i].address == addr->address && + failed_connections_[i].proto == addr->proto) { + return true; + } + } + return false; + } + + bool HasTimedOut(cricket::ProtocolAddress* addr) { + for (size_t i = 0; i < soft_timedout_connections_.size(); i++) { + if (soft_timedout_connections_[i].address == addr->address && + soft_timedout_connections_[i].proto == addr->proto) { + return true; + } + } + return false; + } + + typedef std::map<rtc::AsyncPacketSocket*, int> PacketMap; + + rtc::Thread* main_; + rtc::scoped_ptr<rtc::PhysicalSocketServer> + physical_socket_server_; + rtc::scoped_ptr<rtc::VirtualSocketServer> virtual_socket_server_; + rtc::SocketServerScope ss_scope_; + rtc::Network network_; + rtc::BasicPacketSocketFactory socket_factory_; + std::string username_; + std::string password_; + rtc::scoped_ptr<cricket::RelayPort> relay_port_; + rtc::scoped_ptr<cricket::RelayServer> relay_server_; + std::vector<cricket::ProtocolAddress> failed_connections_; + std::vector<cricket::ProtocolAddress> soft_timedout_connections_; + PacketMap received_packet_count_; +}; + +TEST_F(RelayPortTest, ConnectUdp) { + TestConnectUdp(); +} + +TEST_F(RelayPortTest, ConnectTcp) { + TestConnectTcp(); +} + +TEST_F(RelayPortTest, ConnectSslTcp) { + TestConnectSslTcp(); +} diff --git a/webrtc/p2p/base/relayserver.cc b/webrtc/p2p/base/relayserver.cc new file mode 100644 index 0000000000..e208d70d0f --- /dev/null +++ b/webrtc/p2p/base/relayserver.cc @@ -0,0 +1,749 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/relayserver.h" + +#ifdef WEBRTC_POSIX +#include <errno.h> +#endif // WEBRTC_POSIX + +#include <algorithm> + +#include "webrtc/base/asynctcpsocket.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/socketadapters.h" + +namespace cricket { + +// By default, we require a ping every 90 seconds. +const int MAX_LIFETIME = 15 * 60 * 1000; + +// The number of bytes in each of the usernames we use. +const uint32_t USERNAME_LENGTH = 16; + +// Calls SendTo on the given socket and logs any bad results. +void Send(rtc::AsyncPacketSocket* socket, const char* bytes, size_t size, + const rtc::SocketAddress& addr) { + rtc::PacketOptions options; + int result = socket->SendTo(bytes, size, addr, options); + if (result < static_cast<int>(size)) { + LOG(LS_ERROR) << "SendTo wrote only " << result << " of " << size + << " bytes"; + } else if (result < 0) { + LOG_ERR(LS_ERROR) << "SendTo"; + } +} + +// Sends the given STUN message on the given socket. +void SendStun(const StunMessage& msg, + rtc::AsyncPacketSocket* socket, + const rtc::SocketAddress& addr) { + rtc::ByteBuffer buf; + msg.Write(&buf); + Send(socket, buf.Data(), buf.Length(), addr); +} + +// Constructs a STUN error response and sends it on the given socket. +void SendStunError(const StunMessage& msg, rtc::AsyncPacketSocket* socket, + const rtc::SocketAddress& remote_addr, int error_code, + const char* error_desc, const std::string& magic_cookie) { + RelayMessage err_msg; + err_msg.SetType(GetStunErrorResponseType(msg.type())); + err_msg.SetTransactionID(msg.transaction_id()); + + StunByteStringAttribute* magic_cookie_attr = + StunAttribute::CreateByteString(cricket::STUN_ATTR_MAGIC_COOKIE); + if (magic_cookie.size() == 0) { + magic_cookie_attr->CopyBytes(cricket::TURN_MAGIC_COOKIE_VALUE, + sizeof(cricket::TURN_MAGIC_COOKIE_VALUE)); + } else { + magic_cookie_attr->CopyBytes(magic_cookie.c_str(), magic_cookie.size()); + } + err_msg.AddAttribute(magic_cookie_attr); + + StunErrorCodeAttribute* err_code = StunAttribute::CreateErrorCode(); + err_code->SetClass(error_code / 100); + err_code->SetNumber(error_code % 100); + err_code->SetReason(error_desc); + err_msg.AddAttribute(err_code); + + SendStun(err_msg, socket, remote_addr); +} + +RelayServer::RelayServer(rtc::Thread* thread) + : thread_(thread), log_bindings_(true) { +} + +RelayServer::~RelayServer() { + // Deleting the binding will cause it to be removed from the map. + while (!bindings_.empty()) + delete bindings_.begin()->second; + for (size_t i = 0; i < internal_sockets_.size(); ++i) + delete internal_sockets_[i]; + for (size_t i = 0; i < external_sockets_.size(); ++i) + delete external_sockets_[i]; + for (size_t i = 0; i < removed_sockets_.size(); ++i) + delete removed_sockets_[i]; + while (!server_sockets_.empty()) { + rtc::AsyncSocket* socket = server_sockets_.begin()->first; + server_sockets_.erase(server_sockets_.begin()->first); + delete socket; + } +} + +void RelayServer::AddInternalSocket(rtc::AsyncPacketSocket* socket) { + ASSERT(internal_sockets_.end() == + std::find(internal_sockets_.begin(), internal_sockets_.end(), socket)); + internal_sockets_.push_back(socket); + socket->SignalReadPacket.connect(this, &RelayServer::OnInternalPacket); +} + +void RelayServer::RemoveInternalSocket(rtc::AsyncPacketSocket* socket) { + SocketList::iterator iter = + std::find(internal_sockets_.begin(), internal_sockets_.end(), socket); + ASSERT(iter != internal_sockets_.end()); + internal_sockets_.erase(iter); + removed_sockets_.push_back(socket); + socket->SignalReadPacket.disconnect(this); +} + +void RelayServer::AddExternalSocket(rtc::AsyncPacketSocket* socket) { + ASSERT(external_sockets_.end() == + std::find(external_sockets_.begin(), external_sockets_.end(), socket)); + external_sockets_.push_back(socket); + socket->SignalReadPacket.connect(this, &RelayServer::OnExternalPacket); +} + +void RelayServer::RemoveExternalSocket(rtc::AsyncPacketSocket* socket) { + SocketList::iterator iter = + std::find(external_sockets_.begin(), external_sockets_.end(), socket); + ASSERT(iter != external_sockets_.end()); + external_sockets_.erase(iter); + removed_sockets_.push_back(socket); + socket->SignalReadPacket.disconnect(this); +} + +void RelayServer::AddInternalServerSocket(rtc::AsyncSocket* socket, + cricket::ProtocolType proto) { + ASSERT(server_sockets_.end() == + server_sockets_.find(socket)); + server_sockets_[socket] = proto; + socket->SignalReadEvent.connect(this, &RelayServer::OnReadEvent); +} + +void RelayServer::RemoveInternalServerSocket( + rtc::AsyncSocket* socket) { + ServerSocketMap::iterator iter = server_sockets_.find(socket); + ASSERT(iter != server_sockets_.end()); + server_sockets_.erase(iter); + socket->SignalReadEvent.disconnect(this); +} + +int RelayServer::GetConnectionCount() const { + return static_cast<int>(connections_.size()); +} + +rtc::SocketAddressPair RelayServer::GetConnection(int connection) const { + int i = 0; + for (ConnectionMap::const_iterator it = connections_.begin(); + it != connections_.end(); ++it) { + if (i == connection) { + return it->second->addr_pair(); + } + ++i; + } + return rtc::SocketAddressPair(); +} + +bool RelayServer::HasConnection(const rtc::SocketAddress& address) const { + for (ConnectionMap::const_iterator it = connections_.begin(); + it != connections_.end(); ++it) { + if (it->second->addr_pair().destination() == address) { + return true; + } + } + return false; +} + +void RelayServer::OnReadEvent(rtc::AsyncSocket* socket) { + ASSERT(server_sockets_.find(socket) != server_sockets_.end()); + AcceptConnection(socket); +} + +void RelayServer::OnInternalPacket( + rtc::AsyncPacketSocket* socket, const char* bytes, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + + // Get the address of the connection we just received on. + rtc::SocketAddressPair ap(remote_addr, socket->GetLocalAddress()); + ASSERT(!ap.destination().IsNil()); + + // If this did not come from an existing connection, it should be a STUN + // allocate request. + ConnectionMap::iterator piter = connections_.find(ap); + if (piter == connections_.end()) { + HandleStunAllocate(bytes, size, ap, socket); + return; + } + + RelayServerConnection* int_conn = piter->second; + + // Handle STUN requests to the server itself. + if (int_conn->binding()->HasMagicCookie(bytes, size)) { + HandleStun(int_conn, bytes, size); + return; + } + + // Otherwise, this is a non-wrapped packet that we are to forward. Make sure + // that this connection has been locked. (Otherwise, we would not know what + // address to forward to.) + if (!int_conn->locked()) { + LOG(LS_WARNING) << "Dropping packet: connection not locked"; + return; + } + + // Forward this to the destination address into the connection. + RelayServerConnection* ext_conn = int_conn->binding()->GetExternalConnection( + int_conn->default_destination()); + if (ext_conn && ext_conn->locked()) { + // TODO: Check the HMAC. + ext_conn->Send(bytes, size); + } else { + // This happens very often and is not an error. + LOG(LS_INFO) << "Dropping packet: no external connection"; + } +} + +void RelayServer::OnExternalPacket( + rtc::AsyncPacketSocket* socket, const char* bytes, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + + // Get the address of the connection we just received on. + rtc::SocketAddressPair ap(remote_addr, socket->GetLocalAddress()); + ASSERT(!ap.destination().IsNil()); + + // If this connection already exists, then forward the traffic. + ConnectionMap::iterator piter = connections_.find(ap); + if (piter != connections_.end()) { + // TODO: Check the HMAC. + RelayServerConnection* ext_conn = piter->second; + RelayServerConnection* int_conn = + ext_conn->binding()->GetInternalConnection( + ext_conn->addr_pair().source()); + ASSERT(int_conn != NULL); + int_conn->Send(bytes, size, ext_conn->addr_pair().source()); + ext_conn->Lock(); // allow outgoing packets + return; + } + + // The first packet should always be a STUN / TURN packet. If it isn't, then + // we should just ignore this packet. + RelayMessage msg; + rtc::ByteBuffer buf(bytes, size); + if (!msg.Read(&buf)) { + LOG(LS_WARNING) << "Dropping packet: first packet not STUN"; + return; + } + + // The initial packet should have a username (which identifies the binding). + const StunByteStringAttribute* username_attr = + msg.GetByteString(STUN_ATTR_USERNAME); + if (!username_attr) { + LOG(LS_WARNING) << "Dropping packet: no username"; + return; + } + + uint32_t length = + std::min(static_cast<uint32_t>(username_attr->length()), USERNAME_LENGTH); + std::string username(username_attr->bytes(), length); + // TODO: Check the HMAC. + + // The binding should already be present. + BindingMap::iterator biter = bindings_.find(username); + if (biter == bindings_.end()) { + LOG(LS_WARNING) << "Dropping packet: no binding with username"; + return; + } + + // Add this authenticted connection to the binding. + RelayServerConnection* ext_conn = + new RelayServerConnection(biter->second, ap, socket); + ext_conn->binding()->AddExternalConnection(ext_conn); + AddConnection(ext_conn); + + // We always know where external packets should be forwarded, so we can lock + // them from the beginning. + ext_conn->Lock(); + + // Send this message on the appropriate internal connection. + RelayServerConnection* int_conn = ext_conn->binding()->GetInternalConnection( + ext_conn->addr_pair().source()); + ASSERT(int_conn != NULL); + int_conn->Send(bytes, size, ext_conn->addr_pair().source()); +} + +bool RelayServer::HandleStun( + const char* bytes, size_t size, const rtc::SocketAddress& remote_addr, + rtc::AsyncPacketSocket* socket, std::string* username, + StunMessage* msg) { + + // Parse this into a stun message. Eat the message if this fails. + rtc::ByteBuffer buf(bytes, size); + if (!msg->Read(&buf)) { + return false; + } + + // The initial packet should have a username (which identifies the binding). + const StunByteStringAttribute* username_attr = + msg->GetByteString(STUN_ATTR_USERNAME); + if (!username_attr) { + SendStunError(*msg, socket, remote_addr, 432, "Missing Username", ""); + return false; + } + + // Record the username if requested. + if (username) + username->append(username_attr->bytes(), username_attr->length()); + + // TODO: Check for unknown attributes (<= 0x7fff) + + return true; +} + +void RelayServer::HandleStunAllocate( + const char* bytes, size_t size, const rtc::SocketAddressPair& ap, + rtc::AsyncPacketSocket* socket) { + + // Make sure this is a valid STUN request. + RelayMessage request; + std::string username; + if (!HandleStun(bytes, size, ap.source(), socket, &username, &request)) + return; + + // Make sure this is a an allocate request. + if (request.type() != STUN_ALLOCATE_REQUEST) { + SendStunError(request, + socket, + ap.source(), + 600, + "Operation Not Supported", + ""); + return; + } + + // TODO: Check the HMAC. + + // Find or create the binding for this username. + + RelayServerBinding* binding; + + BindingMap::iterator biter = bindings_.find(username); + if (biter != bindings_.end()) { + binding = biter->second; + } else { + // NOTE: In the future, bindings will be created by the bot only. This + // else-branch will then disappear. + + // Compute the appropriate lifetime for this binding. + uint32_t lifetime = MAX_LIFETIME; + const StunUInt32Attribute* lifetime_attr = + request.GetUInt32(STUN_ATTR_LIFETIME); + if (lifetime_attr) + lifetime = std::min(lifetime, lifetime_attr->value() * 1000); + + binding = new RelayServerBinding(this, username, "0", lifetime); + binding->SignalTimeout.connect(this, &RelayServer::OnTimeout); + bindings_[username] = binding; + + if (log_bindings_) { + LOG(LS_INFO) << "Added new binding " << username << ", " + << bindings_.size() << " total"; + } + } + + // Add this connection to the binding. It starts out unlocked. + RelayServerConnection* int_conn = + new RelayServerConnection(binding, ap, socket); + binding->AddInternalConnection(int_conn); + AddConnection(int_conn); + + // Now that we have a connection, this other method takes over. + HandleStunAllocate(int_conn, request); +} + +void RelayServer::HandleStun( + RelayServerConnection* int_conn, const char* bytes, size_t size) { + + // Make sure this is a valid STUN request. + RelayMessage request; + std::string username; + if (!HandleStun(bytes, size, int_conn->addr_pair().source(), + int_conn->socket(), &username, &request)) + return; + + // Make sure the username is the one were were expecting. + if (username != int_conn->binding()->username()) { + int_conn->SendStunError(request, 430, "Stale Credentials"); + return; + } + + // TODO: Check the HMAC. + + // Send this request to the appropriate handler. + if (request.type() == STUN_SEND_REQUEST) + HandleStunSend(int_conn, request); + else if (request.type() == STUN_ALLOCATE_REQUEST) + HandleStunAllocate(int_conn, request); + else + int_conn->SendStunError(request, 600, "Operation Not Supported"); +} + +void RelayServer::HandleStunAllocate( + RelayServerConnection* int_conn, const StunMessage& request) { + + // Create a response message that includes an address with which external + // clients can communicate. + + RelayMessage response; + response.SetType(STUN_ALLOCATE_RESPONSE); + response.SetTransactionID(request.transaction_id()); + + StunByteStringAttribute* magic_cookie_attr = + StunAttribute::CreateByteString(cricket::STUN_ATTR_MAGIC_COOKIE); + magic_cookie_attr->CopyBytes(int_conn->binding()->magic_cookie().c_str(), + int_conn->binding()->magic_cookie().size()); + response.AddAttribute(magic_cookie_attr); + + size_t index = rand() % external_sockets_.size(); + rtc::SocketAddress ext_addr = + external_sockets_[index]->GetLocalAddress(); + + StunAddressAttribute* addr_attr = + StunAttribute::CreateAddress(STUN_ATTR_MAPPED_ADDRESS); + addr_attr->SetIP(ext_addr.ipaddr()); + addr_attr->SetPort(ext_addr.port()); + response.AddAttribute(addr_attr); + + StunUInt32Attribute* res_lifetime_attr = + StunAttribute::CreateUInt32(STUN_ATTR_LIFETIME); + res_lifetime_attr->SetValue(int_conn->binding()->lifetime() / 1000); + response.AddAttribute(res_lifetime_attr); + + // TODO: Support transport-prefs (preallocate RTCP port). + // TODO: Support bandwidth restrictions. + // TODO: Add message integrity check. + + // Send a response to the caller. + int_conn->SendStun(response); +} + +void RelayServer::HandleStunSend( + RelayServerConnection* int_conn, const StunMessage& request) { + + const StunAddressAttribute* addr_attr = + request.GetAddress(STUN_ATTR_DESTINATION_ADDRESS); + if (!addr_attr) { + int_conn->SendStunError(request, 400, "Bad Request"); + return; + } + + const StunByteStringAttribute* data_attr = + request.GetByteString(STUN_ATTR_DATA); + if (!data_attr) { + int_conn->SendStunError(request, 400, "Bad Request"); + return; + } + + rtc::SocketAddress ext_addr(addr_attr->ipaddr(), addr_attr->port()); + RelayServerConnection* ext_conn = + int_conn->binding()->GetExternalConnection(ext_addr); + if (!ext_conn) { + // Create a new connection to establish the relationship with this binding. + ASSERT(external_sockets_.size() == 1); + rtc::AsyncPacketSocket* socket = external_sockets_[0]; + rtc::SocketAddressPair ap(ext_addr, socket->GetLocalAddress()); + ext_conn = new RelayServerConnection(int_conn->binding(), ap, socket); + ext_conn->binding()->AddExternalConnection(ext_conn); + AddConnection(ext_conn); + } + + // If this connection has pinged us, then allow outgoing traffic. + if (ext_conn->locked()) + ext_conn->Send(data_attr->bytes(), data_attr->length()); + + const StunUInt32Attribute* options_attr = + request.GetUInt32(STUN_ATTR_OPTIONS); + if (options_attr && (options_attr->value() & 0x01)) { + int_conn->set_default_destination(ext_addr); + int_conn->Lock(); + + RelayMessage response; + response.SetType(STUN_SEND_RESPONSE); + response.SetTransactionID(request.transaction_id()); + + StunByteStringAttribute* magic_cookie_attr = + StunAttribute::CreateByteString(cricket::STUN_ATTR_MAGIC_COOKIE); + magic_cookie_attr->CopyBytes(int_conn->binding()->magic_cookie().c_str(), + int_conn->binding()->magic_cookie().size()); + response.AddAttribute(magic_cookie_attr); + + StunUInt32Attribute* options2_attr = + StunAttribute::CreateUInt32(cricket::STUN_ATTR_OPTIONS); + options2_attr->SetValue(0x01); + response.AddAttribute(options2_attr); + + int_conn->SendStun(response); + } +} + +void RelayServer::AddConnection(RelayServerConnection* conn) { + ASSERT(connections_.find(conn->addr_pair()) == connections_.end()); + connections_[conn->addr_pair()] = conn; +} + +void RelayServer::RemoveConnection(RelayServerConnection* conn) { + ConnectionMap::iterator iter = connections_.find(conn->addr_pair()); + ASSERT(iter != connections_.end()); + connections_.erase(iter); +} + +void RelayServer::RemoveBinding(RelayServerBinding* binding) { + BindingMap::iterator iter = bindings_.find(binding->username()); + ASSERT(iter != bindings_.end()); + bindings_.erase(iter); + + if (log_bindings_) { + LOG(LS_INFO) << "Removed binding " << binding->username() << ", " + << bindings_.size() << " remaining"; + } +} + +void RelayServer::OnMessage(rtc::Message *pmsg) { +#if ENABLE_DEBUG + static const uint32_t kMessageAcceptConnection = 1; + ASSERT(pmsg->message_id == kMessageAcceptConnection); +#endif + rtc::MessageData* data = pmsg->pdata; + rtc::AsyncSocket* socket = + static_cast <rtc::TypedMessageData<rtc::AsyncSocket*>*> + (data)->data(); + AcceptConnection(socket); + delete data; +} + +void RelayServer::OnTimeout(RelayServerBinding* binding) { + // This call will result in all of the necessary clean-up. We can't call + // delete here, because you can't delete an object that is signaling you. + thread_->Dispose(binding); +} + +void RelayServer::AcceptConnection(rtc::AsyncSocket* server_socket) { + // Check if someone is trying to connect to us. + rtc::SocketAddress accept_addr; + rtc::AsyncSocket* accepted_socket = + server_socket->Accept(&accept_addr); + if (accepted_socket != NULL) { + // We had someone trying to connect, now check which protocol to + // use and create a packet socket. + ASSERT(server_sockets_[server_socket] == cricket::PROTO_TCP || + server_sockets_[server_socket] == cricket::PROTO_SSLTCP); + if (server_sockets_[server_socket] == cricket::PROTO_SSLTCP) { + accepted_socket = new rtc::AsyncSSLServerSocket(accepted_socket); + } + rtc::AsyncTCPSocket* tcp_socket = + new rtc::AsyncTCPSocket(accepted_socket, false); + + // Finally add the socket so it can start communicating with the client. + AddInternalSocket(tcp_socket); + } +} + +RelayServerConnection::RelayServerConnection( + RelayServerBinding* binding, const rtc::SocketAddressPair& addrs, + rtc::AsyncPacketSocket* socket) + : binding_(binding), addr_pair_(addrs), socket_(socket), locked_(false) { + // The creation of a new connection constitutes a use of the binding. + binding_->NoteUsed(); +} + +RelayServerConnection::~RelayServerConnection() { + // Remove this connection from the server's map (if it exists there). + binding_->server()->RemoveConnection(this); +} + +void RelayServerConnection::Send(const char* data, size_t size) { + // Note that the binding has been used again. + binding_->NoteUsed(); + + cricket::Send(socket_, data, size, addr_pair_.source()); +} + +void RelayServerConnection::Send( + const char* data, size_t size, const rtc::SocketAddress& from_addr) { + // If the from address is known to the client, we don't need to send it. + if (locked() && (from_addr == default_dest_)) { + Send(data, size); + return; + } + + // Wrap the given data in a data-indication packet. + + RelayMessage msg; + msg.SetType(STUN_DATA_INDICATION); + + StunByteStringAttribute* magic_cookie_attr = + StunAttribute::CreateByteString(cricket::STUN_ATTR_MAGIC_COOKIE); + magic_cookie_attr->CopyBytes(binding_->magic_cookie().c_str(), + binding_->magic_cookie().size()); + msg.AddAttribute(magic_cookie_attr); + + StunAddressAttribute* addr_attr = + StunAttribute::CreateAddress(STUN_ATTR_SOURCE_ADDRESS2); + addr_attr->SetIP(from_addr.ipaddr()); + addr_attr->SetPort(from_addr.port()); + msg.AddAttribute(addr_attr); + + StunByteStringAttribute* data_attr = + StunAttribute::CreateByteString(STUN_ATTR_DATA); + ASSERT(size <= 65536); + data_attr->CopyBytes(data, uint16_t(size)); + msg.AddAttribute(data_attr); + + SendStun(msg); +} + +void RelayServerConnection::SendStun(const StunMessage& msg) { + // Note that the binding has been used again. + binding_->NoteUsed(); + + cricket::SendStun(msg, socket_, addr_pair_.source()); +} + +void RelayServerConnection::SendStunError( + const StunMessage& request, int error_code, const char* error_desc) { + // An error does not indicate use. If no legitimate use off the binding + // occurs, we want it to be cleaned up even if errors are still occuring. + + cricket::SendStunError( + request, socket_, addr_pair_.source(), error_code, error_desc, + binding_->magic_cookie()); +} + +void RelayServerConnection::Lock() { + locked_ = true; +} + +void RelayServerConnection::Unlock() { + locked_ = false; +} + +// IDs used for posted messages: +const uint32_t MSG_LIFETIME_TIMER = 1; + +RelayServerBinding::RelayServerBinding(RelayServer* server, + const std::string& username, + const std::string& password, + uint32_t lifetime) + : server_(server), + username_(username), + password_(password), + lifetime_(lifetime) { + // For now, every connection uses the standard magic cookie value. + magic_cookie_.append( + reinterpret_cast<const char*>(TURN_MAGIC_COOKIE_VALUE), + sizeof(TURN_MAGIC_COOKIE_VALUE)); + + // Initialize the last-used time to now. + NoteUsed(); + + // Set the first timeout check. + server_->thread()->PostDelayed(lifetime_, this, MSG_LIFETIME_TIMER); +} + +RelayServerBinding::~RelayServerBinding() { + // Clear the outstanding timeout check. + server_->thread()->Clear(this); + + // Clean up all of the connections. + for (size_t i = 0; i < internal_connections_.size(); ++i) + delete internal_connections_[i]; + for (size_t i = 0; i < external_connections_.size(); ++i) + delete external_connections_[i]; + + // Remove this binding from the server's map. + server_->RemoveBinding(this); +} + +void RelayServerBinding::AddInternalConnection(RelayServerConnection* conn) { + internal_connections_.push_back(conn); +} + +void RelayServerBinding::AddExternalConnection(RelayServerConnection* conn) { + external_connections_.push_back(conn); +} + +void RelayServerBinding::NoteUsed() { + last_used_ = rtc::Time(); +} + +bool RelayServerBinding::HasMagicCookie(const char* bytes, size_t size) const { + if (size < 24 + magic_cookie_.size()) { + return false; + } else { + return memcmp(bytes + 24, magic_cookie_.c_str(), magic_cookie_.size()) == 0; + } +} + +RelayServerConnection* RelayServerBinding::GetInternalConnection( + const rtc::SocketAddress& ext_addr) { + + // Look for an internal connection that is locked to this address. + for (size_t i = 0; i < internal_connections_.size(); ++i) { + if (internal_connections_[i]->locked() && + (ext_addr == internal_connections_[i]->default_destination())) + return internal_connections_[i]; + } + + // If one was not found, we send to the first connection. + ASSERT(internal_connections_.size() > 0); + return internal_connections_[0]; +} + +RelayServerConnection* RelayServerBinding::GetExternalConnection( + const rtc::SocketAddress& ext_addr) { + for (size_t i = 0; i < external_connections_.size(); ++i) { + if (ext_addr == external_connections_[i]->addr_pair().source()) + return external_connections_[i]; + } + return 0; +} + +void RelayServerBinding::OnMessage(rtc::Message *pmsg) { + if (pmsg->message_id == MSG_LIFETIME_TIMER) { + ASSERT(!pmsg->pdata); + + // If the lifetime timeout has been exceeded, then send a signal. + // Otherwise, just keep waiting. + if (rtc::Time() >= last_used_ + lifetime_) { + LOG(LS_INFO) << "Expiring binding " << username_; + SignalTimeout(this); + } else { + server_->thread()->PostDelayed(lifetime_, this, MSG_LIFETIME_TIMER); + } + + } else { + ASSERT(false); + } +} + +} // namespace cricket diff --git a/webrtc/p2p/base/relayserver.h b/webrtc/p2p/base/relayserver.h new file mode 100644 index 0000000000..f1109f1ce4 --- /dev/null +++ b/webrtc/p2p/base/relayserver.h @@ -0,0 +1,236 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_RELAYSERVER_H_ +#define WEBRTC_P2P_BASE_RELAYSERVER_H_ + +#include <map> +#include <string> +#include <vector> + +#include "webrtc/p2p/base/port.h" +#include "webrtc/p2p/base/stun.h" +#include "webrtc/base/asyncudpsocket.h" +#include "webrtc/base/socketaddresspair.h" +#include "webrtc/base/thread.h" +#include "webrtc/base/timeutils.h" + +namespace cricket { + +class RelayServerBinding; +class RelayServerConnection; + +// Relays traffic between connections to the server that are "bound" together. +// All connections created with the same username/password are bound together. +class RelayServer : public rtc::MessageHandler, + public sigslot::has_slots<> { + public: + // Creates a server, which will use this thread to post messages to itself. + explicit RelayServer(rtc::Thread* thread); + ~RelayServer(); + + rtc::Thread* thread() { return thread_; } + + // Indicates whether we will print updates of the number of bindings. + bool log_bindings() const { return log_bindings_; } + void set_log_bindings(bool log_bindings) { log_bindings_ = log_bindings; } + + // Updates the set of sockets that the server uses to talk to "internal" + // clients. These are clients that do the "port allocations". + void AddInternalSocket(rtc::AsyncPacketSocket* socket); + void RemoveInternalSocket(rtc::AsyncPacketSocket* socket); + + // Updates the set of sockets that the server uses to talk to "external" + // clients. These are the clients that do not do allocations. They do not + // know that these addresses represent a relay server. + void AddExternalSocket(rtc::AsyncPacketSocket* socket); + void RemoveExternalSocket(rtc::AsyncPacketSocket* socket); + + // Starts listening for connections on this sockets. When someone + // tries to connect, the connection will be accepted and a new + // internal socket will be added. + void AddInternalServerSocket(rtc::AsyncSocket* socket, + cricket::ProtocolType proto); + + // Removes this server socket from the list. + void RemoveInternalServerSocket(rtc::AsyncSocket* socket); + + // Methods for testing and debuging. + int GetConnectionCount() const; + rtc::SocketAddressPair GetConnection(int connection) const; + bool HasConnection(const rtc::SocketAddress& address) const; + + private: + typedef std::vector<rtc::AsyncPacketSocket*> SocketList; + typedef std::map<rtc::AsyncSocket*, + cricket::ProtocolType> ServerSocketMap; + typedef std::map<std::string, RelayServerBinding*> BindingMap; + typedef std::map<rtc::SocketAddressPair, + RelayServerConnection*> ConnectionMap; + + rtc::Thread* thread_; + bool log_bindings_; + SocketList internal_sockets_; + SocketList external_sockets_; + SocketList removed_sockets_; + ServerSocketMap server_sockets_; + BindingMap bindings_; + ConnectionMap connections_; + + // Called when a packet is received by the server on one of its sockets. + void OnInternalPacket(rtc::AsyncPacketSocket* socket, + const char* bytes, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time); + void OnExternalPacket(rtc::AsyncPacketSocket* socket, + const char* bytes, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time); + + void OnReadEvent(rtc::AsyncSocket* socket); + + // Processes the relevant STUN request types from the client. + bool HandleStun(const char* bytes, size_t size, + const rtc::SocketAddress& remote_addr, + rtc::AsyncPacketSocket* socket, + std::string* username, StunMessage* msg); + void HandleStunAllocate(const char* bytes, size_t size, + const rtc::SocketAddressPair& ap, + rtc::AsyncPacketSocket* socket); + void HandleStun(RelayServerConnection* int_conn, const char* bytes, + size_t size); + void HandleStunAllocate(RelayServerConnection* int_conn, + const StunMessage& msg); + void HandleStunSend(RelayServerConnection* int_conn, const StunMessage& msg); + + // Adds/Removes the a connection or binding. + void AddConnection(RelayServerConnection* conn); + void RemoveConnection(RelayServerConnection* conn); + void RemoveBinding(RelayServerBinding* binding); + + // Handle messages in our worker thread. + void OnMessage(rtc::Message *pmsg); + + // Called when the timer for checking lifetime times out. + void OnTimeout(RelayServerBinding* binding); + + // Accept connections on this server socket. + void AcceptConnection(rtc::AsyncSocket* server_socket); + + friend class RelayServerConnection; + friend class RelayServerBinding; +}; + +// Maintains information about a connection to the server. Each connection is +// part of one and only one binding. +class RelayServerConnection { + public: + RelayServerConnection(RelayServerBinding* binding, + const rtc::SocketAddressPair& addrs, + rtc::AsyncPacketSocket* socket); + ~RelayServerConnection(); + + RelayServerBinding* binding() { return binding_; } + rtc::AsyncPacketSocket* socket() { return socket_; } + + // Returns a pair where the source is the remote address and the destination + // is the local address. + const rtc::SocketAddressPair& addr_pair() { return addr_pair_; } + + // Sends a packet to the connected client. If an address is provided, then + // we make sure the internal client receives it, wrapping if necessary. + void Send(const char* data, size_t size); + void Send(const char* data, size_t size, + const rtc::SocketAddress& ext_addr); + + // Sends a STUN message to the connected client with no wrapping. + void SendStun(const StunMessage& msg); + void SendStunError(const StunMessage& request, int code, const char* desc); + + // A locked connection is one for which we know the intended destination of + // any raw packet received. + bool locked() const { return locked_; } + void Lock(); + void Unlock(); + + // Records the address that raw packets should be forwarded to (for internal + // packets only; for external, we already know where they go). + const rtc::SocketAddress& default_destination() const { + return default_dest_; + } + void set_default_destination(const rtc::SocketAddress& addr) { + default_dest_ = addr; + } + + private: + RelayServerBinding* binding_; + rtc::SocketAddressPair addr_pair_; + rtc::AsyncPacketSocket* socket_; + bool locked_; + rtc::SocketAddress default_dest_; +}; + +// Records a set of internal and external connections that we relay between, +// or in other words, that are "bound" together. +class RelayServerBinding : public rtc::MessageHandler { + public: + RelayServerBinding(RelayServer* server, + const std::string& username, + const std::string& password, + uint32_t lifetime); + virtual ~RelayServerBinding(); + + RelayServer* server() { return server_; } + uint32_t lifetime() { return lifetime_; } + const std::string& username() { return username_; } + const std::string& password() { return password_; } + const std::string& magic_cookie() { return magic_cookie_; } + + // Adds/Removes a connection into the binding. + void AddInternalConnection(RelayServerConnection* conn); + void AddExternalConnection(RelayServerConnection* conn); + + // We keep track of the use of each binding. If we detect that it was not + // used for longer than the lifetime, then we send a signal. + void NoteUsed(); + sigslot::signal1<RelayServerBinding*> SignalTimeout; + + // Determines whether the given packet has the magic cookie present (in the + // right place). + bool HasMagicCookie(const char* bytes, size_t size) const; + + // Determines the connection to use to send packets to or from the given + // external address. + RelayServerConnection* GetInternalConnection( + const rtc::SocketAddress& ext_addr); + RelayServerConnection* GetExternalConnection( + const rtc::SocketAddress& ext_addr); + + // MessageHandler: + void OnMessage(rtc::Message *pmsg); + + private: + RelayServer* server_; + + std::string username_; + std::string password_; + std::string magic_cookie_; + + std::vector<RelayServerConnection*> internal_connections_; + std::vector<RelayServerConnection*> external_connections_; + + uint32_t lifetime_; + uint32_t last_used_; + // TODO: bandwidth +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_RELAYSERVER_H_ diff --git a/webrtc/p2p/base/relayserver_unittest.cc b/webrtc/p2p/base/relayserver_unittest.cc new file mode 100644 index 0000000000..83e5353fc9 --- /dev/null +++ b/webrtc/p2p/base/relayserver_unittest.cc @@ -0,0 +1,529 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <string> + +#include "webrtc/p2p/base/relayserver.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/physicalsocketserver.h" +#include "webrtc/base/socketaddress.h" +#include "webrtc/base/ssladapter.h" +#include "webrtc/base/testclient.h" +#include "webrtc/base/thread.h" +#include "webrtc/base/virtualsocketserver.h" + +using rtc::SocketAddress; +using namespace cricket; + +static const uint32_t LIFETIME = 4; // seconds +static const SocketAddress server_int_addr("127.0.0.1", 5000); +static const SocketAddress server_ext_addr("127.0.0.1", 5001); +static const SocketAddress client1_addr("127.0.0.1", 6000 + (rand() % 1000)); +static const SocketAddress client2_addr("127.0.0.1", 7000 + (rand() % 1000)); +static const char* bad = "this is a completely nonsensical message whose only " + "purpose is to make the parser go 'ack'. it doesn't " + "look anything like a normal stun message"; +static const char* msg1 = "spamspamspamspamspamspamspambakedbeansspam"; +static const char* msg2 = "Lobster Thermidor a Crevette with a mornay sauce..."; + +class RelayServerTest : public testing::Test { + public: + RelayServerTest() + : pss_(new rtc::PhysicalSocketServer), + ss_(new rtc::VirtualSocketServer(pss_.get())), + ss_scope_(ss_.get()), + username_(rtc::CreateRandomString(12)), + password_(rtc::CreateRandomString(12)) {} + + protected: + virtual void SetUp() { + server_.reset(new RelayServer(rtc::Thread::Current())); + + server_->AddInternalSocket( + rtc::AsyncUDPSocket::Create(ss_.get(), server_int_addr)); + server_->AddExternalSocket( + rtc::AsyncUDPSocket::Create(ss_.get(), server_ext_addr)); + + client1_.reset(new rtc::TestClient( + rtc::AsyncUDPSocket::Create(ss_.get(), client1_addr))); + client2_.reset(new rtc::TestClient( + rtc::AsyncUDPSocket::Create(ss_.get(), client2_addr))); + } + + void Allocate() { + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_ALLOCATE_REQUEST)); + AddUsernameAttr(req.get(), username_); + AddLifetimeAttr(req.get(), LIFETIME); + Send1(req.get()); + delete Receive1(); + } + void Bind() { + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_BINDING_REQUEST)); + AddUsernameAttr(req.get(), username_); + Send2(req.get()); + delete Receive1(); + } + + void Send1(const StunMessage* msg) { + rtc::ByteBuffer buf; + msg->Write(&buf); + SendRaw1(buf.Data(), static_cast<int>(buf.Length())); + } + void Send2(const StunMessage* msg) { + rtc::ByteBuffer buf; + msg->Write(&buf); + SendRaw2(buf.Data(), static_cast<int>(buf.Length())); + } + void SendRaw1(const char* data, int len) { + return Send(client1_.get(), data, len, server_int_addr); + } + void SendRaw2(const char* data, int len) { + return Send(client2_.get(), data, len, server_ext_addr); + } + void Send(rtc::TestClient* client, const char* data, + int len, const SocketAddress& addr) { + client->SendTo(data, len, addr); + } + + bool Receive1Fails() { + return client1_.get()->CheckNoPacket(); + } + bool Receive2Fails() { + return client2_.get()->CheckNoPacket(); + } + + StunMessage* Receive1() { + return Receive(client1_.get()); + } + StunMessage* Receive2() { + return Receive(client2_.get()); + } + std::string ReceiveRaw1() { + return ReceiveRaw(client1_.get()); + } + std::string ReceiveRaw2() { + return ReceiveRaw(client2_.get()); + } + StunMessage* Receive(rtc::TestClient* client) { + StunMessage* msg = NULL; + rtc::TestClient::Packet* packet = + client->NextPacket(rtc::TestClient::kTimeoutMs); + if (packet) { + rtc::ByteBuffer buf(packet->buf, packet->size); + msg = new RelayMessage(); + msg->Read(&buf); + delete packet; + } + return msg; + } + std::string ReceiveRaw(rtc::TestClient* client) { + std::string raw; + rtc::TestClient::Packet* packet = + client->NextPacket(rtc::TestClient::kTimeoutMs); + if (packet) { + raw = std::string(packet->buf, packet->size); + delete packet; + } + return raw; + } + + static StunMessage* CreateStunMessage(int type) { + StunMessage* msg = new RelayMessage(); + msg->SetType(type); + msg->SetTransactionID( + rtc::CreateRandomString(kStunTransactionIdLength)); + return msg; + } + static void AddMagicCookieAttr(StunMessage* msg) { + StunByteStringAttribute* attr = + StunAttribute::CreateByteString(STUN_ATTR_MAGIC_COOKIE); + attr->CopyBytes(TURN_MAGIC_COOKIE_VALUE, sizeof(TURN_MAGIC_COOKIE_VALUE)); + msg->AddAttribute(attr); + } + static void AddUsernameAttr(StunMessage* msg, const std::string& val) { + StunByteStringAttribute* attr = + StunAttribute::CreateByteString(STUN_ATTR_USERNAME); + attr->CopyBytes(val.c_str(), val.size()); + msg->AddAttribute(attr); + } + static void AddLifetimeAttr(StunMessage* msg, int val) { + StunUInt32Attribute* attr = + StunAttribute::CreateUInt32(STUN_ATTR_LIFETIME); + attr->SetValue(val); + msg->AddAttribute(attr); + } + static void AddDestinationAttr(StunMessage* msg, const SocketAddress& addr) { + StunAddressAttribute* attr = + StunAttribute::CreateAddress(STUN_ATTR_DESTINATION_ADDRESS); + attr->SetIP(addr.ipaddr()); + attr->SetPort(addr.port()); + msg->AddAttribute(attr); + } + + rtc::scoped_ptr<rtc::PhysicalSocketServer> pss_; + rtc::scoped_ptr<rtc::VirtualSocketServer> ss_; + rtc::SocketServerScope ss_scope_; + rtc::scoped_ptr<RelayServer> server_; + rtc::scoped_ptr<rtc::TestClient> client1_; + rtc::scoped_ptr<rtc::TestClient> client2_; + std::string username_; + std::string password_; +}; + +// Send a complete nonsense message and verify that it is eaten. +TEST_F(RelayServerTest, TestBadRequest) { + SendRaw1(bad, static_cast<int>(strlen(bad))); + ASSERT_TRUE(Receive1Fails()); +} + +// Send an allocate request without a username and verify it is rejected. +TEST_F(RelayServerTest, TestAllocateNoUsername) { + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_ALLOCATE_REQUEST)), res; + + Send1(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_ALLOCATE_ERROR_RESPONSE, res->type()); + EXPECT_EQ(req->transaction_id(), res->transaction_id()); + + const StunErrorCodeAttribute* err = res->GetErrorCode(); + ASSERT_TRUE(err != NULL); + EXPECT_EQ(4, err->eclass()); + EXPECT_EQ(32, err->number()); + EXPECT_EQ("Missing Username", err->reason()); +} + +// Send a binding request and verify that it is rejected. +TEST_F(RelayServerTest, TestBindingRequest) { + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_BINDING_REQUEST)), res; + AddUsernameAttr(req.get(), username_); + + Send1(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_BINDING_ERROR_RESPONSE, res->type()); + EXPECT_EQ(req->transaction_id(), res->transaction_id()); + + const StunErrorCodeAttribute* err = res->GetErrorCode(); + ASSERT_TRUE(err != NULL); + EXPECT_EQ(6, err->eclass()); + EXPECT_EQ(0, err->number()); + EXPECT_EQ("Operation Not Supported", err->reason()); +} + +// Send an allocate request and verify that it is accepted. +TEST_F(RelayServerTest, TestAllocate) { + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_ALLOCATE_REQUEST)), res; + AddUsernameAttr(req.get(), username_); + AddLifetimeAttr(req.get(), LIFETIME); + + Send1(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_ALLOCATE_RESPONSE, res->type()); + EXPECT_EQ(req->transaction_id(), res->transaction_id()); + + const StunAddressAttribute* mapped_addr = + res->GetAddress(STUN_ATTR_MAPPED_ADDRESS); + ASSERT_TRUE(mapped_addr != NULL); + EXPECT_EQ(1, mapped_addr->family()); + EXPECT_EQ(server_ext_addr.port(), mapped_addr->port()); + EXPECT_EQ(server_ext_addr.ipaddr(), mapped_addr->ipaddr()); + + const StunUInt32Attribute* res_lifetime_attr = + res->GetUInt32(STUN_ATTR_LIFETIME); + ASSERT_TRUE(res_lifetime_attr != NULL); + EXPECT_EQ(LIFETIME, res_lifetime_attr->value()); +} + +// Send a second allocate request and verify that it is also accepted, though +// the lifetime should be ignored. +TEST_F(RelayServerTest, TestReallocate) { + Allocate(); + + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_ALLOCATE_REQUEST)), res; + AddMagicCookieAttr(req.get()); + AddUsernameAttr(req.get(), username_); + + Send1(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_ALLOCATE_RESPONSE, res->type()); + EXPECT_EQ(req->transaction_id(), res->transaction_id()); + + const StunAddressAttribute* mapped_addr = + res->GetAddress(STUN_ATTR_MAPPED_ADDRESS); + ASSERT_TRUE(mapped_addr != NULL); + EXPECT_EQ(1, mapped_addr->family()); + EXPECT_EQ(server_ext_addr.port(), mapped_addr->port()); + EXPECT_EQ(server_ext_addr.ipaddr(), mapped_addr->ipaddr()); + + const StunUInt32Attribute* lifetime_attr = + res->GetUInt32(STUN_ATTR_LIFETIME); + ASSERT_TRUE(lifetime_attr != NULL); + EXPECT_EQ(LIFETIME, lifetime_attr->value()); +} + +// Send a request from another client and see that it arrives at the first +// client in the binding. +TEST_F(RelayServerTest, TestRemoteBind) { + Allocate(); + + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_BINDING_REQUEST)), res; + AddUsernameAttr(req.get(), username_); + + Send2(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_DATA_INDICATION, res->type()); + + const StunByteStringAttribute* recv_data = + res->GetByteString(STUN_ATTR_DATA); + ASSERT_TRUE(recv_data != NULL); + + rtc::ByteBuffer buf(recv_data->bytes(), recv_data->length()); + rtc::scoped_ptr<StunMessage> res2(new StunMessage()); + EXPECT_TRUE(res2->Read(&buf)); + EXPECT_EQ(STUN_BINDING_REQUEST, res2->type()); + EXPECT_EQ(req->transaction_id(), res2->transaction_id()); + + const StunAddressAttribute* src_addr = + res->GetAddress(STUN_ATTR_SOURCE_ADDRESS2); + ASSERT_TRUE(src_addr != NULL); + EXPECT_EQ(1, src_addr->family()); + EXPECT_EQ(client2_addr.ipaddr(), src_addr->ipaddr()); + EXPECT_EQ(client2_addr.port(), src_addr->port()); + + EXPECT_TRUE(Receive2Fails()); +} + +// Send a complete nonsense message to the established connection and verify +// that it is dropped by the server. +TEST_F(RelayServerTest, TestRemoteBadRequest) { + Allocate(); + Bind(); + + SendRaw1(bad, static_cast<int>(strlen(bad))); + EXPECT_TRUE(Receive1Fails()); + EXPECT_TRUE(Receive2Fails()); +} + +// Send a send request without a username and verify it is rejected. +TEST_F(RelayServerTest, TestSendRequestMissingUsername) { + Allocate(); + Bind(); + + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_SEND_REQUEST)), res; + AddMagicCookieAttr(req.get()); + + Send1(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_SEND_ERROR_RESPONSE, res->type()); + EXPECT_EQ(req->transaction_id(), res->transaction_id()); + + const StunErrorCodeAttribute* err = res->GetErrorCode(); + ASSERT_TRUE(err != NULL); + EXPECT_EQ(4, err->eclass()); + EXPECT_EQ(32, err->number()); + EXPECT_EQ("Missing Username", err->reason()); +} + +// Send a send request with the wrong username and verify it is rejected. +TEST_F(RelayServerTest, TestSendRequestBadUsername) { + Allocate(); + Bind(); + + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_SEND_REQUEST)), res; + AddMagicCookieAttr(req.get()); + AddUsernameAttr(req.get(), "foobarbizbaz"); + + Send1(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_SEND_ERROR_RESPONSE, res->type()); + EXPECT_EQ(req->transaction_id(), res->transaction_id()); + + const StunErrorCodeAttribute* err = res->GetErrorCode(); + ASSERT_TRUE(err != NULL); + EXPECT_EQ(4, err->eclass()); + EXPECT_EQ(30, err->number()); + EXPECT_EQ("Stale Credentials", err->reason()); +} + +// Send a send request without a destination address and verify that it is +// rejected. +TEST_F(RelayServerTest, TestSendRequestNoDestinationAddress) { + Allocate(); + Bind(); + + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_SEND_REQUEST)), res; + AddMagicCookieAttr(req.get()); + AddUsernameAttr(req.get(), username_); + + Send1(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_SEND_ERROR_RESPONSE, res->type()); + EXPECT_EQ(req->transaction_id(), res->transaction_id()); + + const StunErrorCodeAttribute* err = res->GetErrorCode(); + ASSERT_TRUE(err != NULL); + EXPECT_EQ(4, err->eclass()); + EXPECT_EQ(0, err->number()); + EXPECT_EQ("Bad Request", err->reason()); +} + +// Send a send request without data and verify that it is rejected. +TEST_F(RelayServerTest, TestSendRequestNoData) { + Allocate(); + Bind(); + + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_SEND_REQUEST)), res; + AddMagicCookieAttr(req.get()); + AddUsernameAttr(req.get(), username_); + AddDestinationAttr(req.get(), client2_addr); + + Send1(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_SEND_ERROR_RESPONSE, res->type()); + EXPECT_EQ(req->transaction_id(), res->transaction_id()); + + const StunErrorCodeAttribute* err = res->GetErrorCode(); + ASSERT_TRUE(err != NULL); + EXPECT_EQ(4, err->eclass()); + EXPECT_EQ(00, err->number()); + EXPECT_EQ("Bad Request", err->reason()); +} + +// Send a binding request after an allocate and verify that it is rejected. +TEST_F(RelayServerTest, TestSendRequestWrongType) { + Allocate(); + Bind(); + + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_BINDING_REQUEST)), res; + AddMagicCookieAttr(req.get()); + AddUsernameAttr(req.get(), username_); + + Send1(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_BINDING_ERROR_RESPONSE, res->type()); + EXPECT_EQ(req->transaction_id(), res->transaction_id()); + + const StunErrorCodeAttribute* err = res->GetErrorCode(); + ASSERT_TRUE(err != NULL); + EXPECT_EQ(6, err->eclass()); + EXPECT_EQ(0, err->number()); + EXPECT_EQ("Operation Not Supported", err->reason()); +} + +// Verify that we can send traffic back and forth between the clients after a +// successful allocate and bind. +TEST_F(RelayServerTest, TestSendRaw) { + Allocate(); + Bind(); + + for (int i = 0; i < 10; i++) { + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_SEND_REQUEST)), res; + AddMagicCookieAttr(req.get()); + AddUsernameAttr(req.get(), username_); + AddDestinationAttr(req.get(), client2_addr); + + StunByteStringAttribute* send_data = + StunAttribute::CreateByteString(STUN_ATTR_DATA); + send_data->CopyBytes(msg1); + req->AddAttribute(send_data); + + Send1(req.get()); + EXPECT_EQ(msg1, ReceiveRaw2()); + SendRaw2(msg2, static_cast<int>(strlen(msg2))); + res.reset(Receive1()); + + ASSERT_TRUE(res); + EXPECT_EQ(STUN_DATA_INDICATION, res->type()); + + const StunAddressAttribute* src_addr = + res->GetAddress(STUN_ATTR_SOURCE_ADDRESS2); + ASSERT_TRUE(src_addr != NULL); + EXPECT_EQ(1, src_addr->family()); + EXPECT_EQ(client2_addr.ipaddr(), src_addr->ipaddr()); + EXPECT_EQ(client2_addr.port(), src_addr->port()); + + const StunByteStringAttribute* recv_data = + res->GetByteString(STUN_ATTR_DATA); + ASSERT_TRUE(recv_data != NULL); + EXPECT_EQ(strlen(msg2), recv_data->length()); + EXPECT_EQ(0, memcmp(msg2, recv_data->bytes(), recv_data->length())); + } +} + +// Verify that a binding expires properly, and rejects send requests. +// Flaky, see https://code.google.com/p/webrtc/issues/detail?id=4134 +TEST_F(RelayServerTest, DISABLED_TestExpiration) { + Allocate(); + Bind(); + + // Wait twice the lifetime to make sure the server has expired the binding. + rtc::Thread::Current()->ProcessMessages((LIFETIME * 2) * 1000); + + rtc::scoped_ptr<StunMessage> req( + CreateStunMessage(STUN_SEND_REQUEST)), res; + AddMagicCookieAttr(req.get()); + AddUsernameAttr(req.get(), username_); + AddDestinationAttr(req.get(), client2_addr); + + StunByteStringAttribute* data_attr = + StunAttribute::CreateByteString(STUN_ATTR_DATA); + data_attr->CopyBytes(msg1); + req->AddAttribute(data_attr); + + Send1(req.get()); + res.reset(Receive1()); + + ASSERT_TRUE(res.get() != NULL); + EXPECT_EQ(STUN_SEND_ERROR_RESPONSE, res->type()); + + const StunErrorCodeAttribute* err = res->GetErrorCode(); + ASSERT_TRUE(err != NULL); + EXPECT_EQ(6, err->eclass()); + EXPECT_EQ(0, err->number()); + EXPECT_EQ("Operation Not Supported", err->reason()); + + // Also verify that traffic from the external client is ignored. + SendRaw2(msg2, static_cast<int>(strlen(msg2))); + EXPECT_TRUE(ReceiveRaw1().empty()); +} diff --git a/webrtc/p2p/base/session.cc b/webrtc/p2p/base/session.cc new file mode 100644 index 0000000000..1a23f8363f --- /dev/null +++ b/webrtc/p2p/base/session.cc @@ -0,0 +1,12 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// TODO(deadbeef): Remove this file when Chrome build files no longer reference +// it. diff --git a/webrtc/p2p/base/session.h b/webrtc/p2p/base/session.h new file mode 100644 index 0000000000..a98a5efe13 --- /dev/null +++ b/webrtc/p2p/base/session.h @@ -0,0 +1,13 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// TODO(deadbeef): Remove this file when Chrome build files no longer reference +// it. +#error "DONT INCLUDE THIS" diff --git a/webrtc/p2p/base/sessiondescription.cc b/webrtc/p2p/base/sessiondescription.cc new file mode 100644 index 0000000000..5320b0596a --- /dev/null +++ b/webrtc/p2p/base/sessiondescription.cc @@ -0,0 +1,220 @@ +/* + * Copyright 2010 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/sessiondescription.h" + +namespace cricket { + +ContentInfo* FindContentInfoByName( + ContentInfos& contents, const std::string& name) { + for (ContentInfos::iterator content = contents.begin(); + content != contents.end(); ++content) { + if (content->name == name) { + return &(*content); + } + } + return NULL; +} + +const ContentInfo* FindContentInfoByName( + const ContentInfos& contents, const std::string& name) { + for (ContentInfos::const_iterator content = contents.begin(); + content != contents.end(); ++content) { + if (content->name == name) { + return &(*content); + } + } + return NULL; +} + +const ContentInfo* FindContentInfoByType( + const ContentInfos& contents, const std::string& type) { + for (ContentInfos::const_iterator content = contents.begin(); + content != contents.end(); ++content) { + if (content->type == type) { + return &(*content); + } + } + return NULL; +} + +const std::string* ContentGroup::FirstContentName() const { + return (!content_names_.empty()) ? &(*content_names_.begin()) : NULL; +} + +bool ContentGroup::HasContentName(const std::string& content_name) const { + return (std::find(content_names_.begin(), content_names_.end(), + content_name) != content_names_.end()); +} + +void ContentGroup::AddContentName(const std::string& content_name) { + if (!HasContentName(content_name)) { + content_names_.push_back(content_name); + } +} + +bool ContentGroup::RemoveContentName(const std::string& content_name) { + ContentNames::iterator iter = std::find( + content_names_.begin(), content_names_.end(), content_name); + if (iter == content_names_.end()) { + return false; + } + content_names_.erase(iter); + return true; +} + +SessionDescription* SessionDescription::Copy() const { + SessionDescription* copy = new SessionDescription(*this); + // Copy all ContentDescriptions. + for (ContentInfos::iterator content = copy->contents_.begin(); + content != copy->contents().end(); ++content) { + content->description = content->description->Copy(); + } + return copy; +} + +const ContentInfo* SessionDescription::GetContentByName( + const std::string& name) const { + return FindContentInfoByName(contents_, name); +} + +ContentInfo* SessionDescription::GetContentByName( + const std::string& name) { + return FindContentInfoByName(contents_, name); +} + +const ContentDescription* SessionDescription::GetContentDescriptionByName( + const std::string& name) const { + const ContentInfo* cinfo = FindContentInfoByName(contents_, name); + if (cinfo == NULL) { + return NULL; + } + + return cinfo->description; +} + +ContentDescription* SessionDescription::GetContentDescriptionByName( + const std::string& name) { + ContentInfo* cinfo = FindContentInfoByName(contents_, name); + if (cinfo == NULL) { + return NULL; + } + + return cinfo->description; +} + +const ContentInfo* SessionDescription::FirstContentByType( + const std::string& type) const { + return FindContentInfoByType(contents_, type); +} + +const ContentInfo* SessionDescription::FirstContent() const { + return (contents_.empty()) ? NULL : &(*contents_.begin()); +} + +void SessionDescription::AddContent(const std::string& name, + const std::string& type, + ContentDescription* description) { + contents_.push_back(ContentInfo(name, type, description)); +} + +void SessionDescription::AddContent(const std::string& name, + const std::string& type, + bool rejected, + ContentDescription* description) { + contents_.push_back(ContentInfo(name, type, rejected, description)); +} + +bool SessionDescription::RemoveContentByName(const std::string& name) { + for (ContentInfos::iterator content = contents_.begin(); + content != contents_.end(); ++content) { + if (content->name == name) { + delete content->description; + contents_.erase(content); + return true; + } + } + + return false; +} + +bool SessionDescription::AddTransportInfo(const TransportInfo& transport_info) { + if (GetTransportInfoByName(transport_info.content_name) != NULL) { + return false; + } + transport_infos_.push_back(transport_info); + return true; +} + +bool SessionDescription::RemoveTransportInfoByName(const std::string& name) { + for (TransportInfos::iterator transport_info = transport_infos_.begin(); + transport_info != transport_infos_.end(); ++transport_info) { + if (transport_info->content_name == name) { + transport_infos_.erase(transport_info); + return true; + } + } + return false; +} + +const TransportInfo* SessionDescription::GetTransportInfoByName( + const std::string& name) const { + for (TransportInfos::const_iterator iter = transport_infos_.begin(); + iter != transport_infos_.end(); ++iter) { + if (iter->content_name == name) { + return &(*iter); + } + } + return NULL; +} + +TransportInfo* SessionDescription::GetTransportInfoByName( + const std::string& name) { + for (TransportInfos::iterator iter = transport_infos_.begin(); + iter != transport_infos_.end(); ++iter) { + if (iter->content_name == name) { + return &(*iter); + } + } + return NULL; +} + +void SessionDescription::RemoveGroupByName(const std::string& name) { + for (ContentGroups::iterator iter = content_groups_.begin(); + iter != content_groups_.end(); ++iter) { + if (iter->semantics() == name) { + content_groups_.erase(iter); + break; + } + } +} + +bool SessionDescription::HasGroup(const std::string& name) const { + for (ContentGroups::const_iterator iter = content_groups_.begin(); + iter != content_groups_.end(); ++iter) { + if (iter->semantics() == name) { + return true; + } + } + return false; +} + +const ContentGroup* SessionDescription::GetGroupByName( + const std::string& name) const { + for (ContentGroups::const_iterator iter = content_groups_.begin(); + iter != content_groups_.end(); ++iter) { + if (iter->semantics() == name) { + return &(*iter); + } + } + return NULL; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/sessiondescription.h b/webrtc/p2p/base/sessiondescription.h new file mode 100644 index 0000000000..7880167569 --- /dev/null +++ b/webrtc/p2p/base/sessiondescription.h @@ -0,0 +1,190 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_SESSIONDESCRIPTION_H_ +#define WEBRTC_P2P_BASE_SESSIONDESCRIPTION_H_ + +#include <string> +#include <vector> + +#include "webrtc/p2p/base/transportinfo.h" +#include "webrtc/base/constructormagic.h" + +namespace cricket { + +// Describes a session content. Individual content types inherit from +// this class. Analagous to a <jingle><content><description> or +// <session><description>. +class ContentDescription { + public: + virtual ~ContentDescription() {} + virtual ContentDescription* Copy() const = 0; +}; + +// Analagous to a <jingle><content> or <session><description>. +// name = name of <content name="..."> +// type = xmlns of <content> +struct ContentInfo { + ContentInfo() : description(NULL) {} + ContentInfo(const std::string& name, + const std::string& type, + ContentDescription* description) : + name(name), type(type), rejected(false), description(description) {} + ContentInfo(const std::string& name, + const std::string& type, + bool rejected, + ContentDescription* description) : + name(name), type(type), rejected(rejected), description(description) {} + std::string name; + std::string type; + bool rejected; + ContentDescription* description; +}; + +typedef std::vector<std::string> ContentNames; + +// This class provides a mechanism to aggregate different media contents into a +// group. This group can also be shared with the peers in a pre-defined format. +// GroupInfo should be populated only with the |content_name| of the +// MediaDescription. +class ContentGroup { + public: + explicit ContentGroup(const std::string& semantics) : + semantics_(semantics) {} + + const std::string& semantics() const { return semantics_; } + const ContentNames& content_names() const { return content_names_; } + + const std::string* FirstContentName() const; + bool HasContentName(const std::string& content_name) const; + void AddContentName(const std::string& content_name); + bool RemoveContentName(const std::string& content_name); + + private: + std::string semantics_; + ContentNames content_names_; +}; + +typedef std::vector<ContentInfo> ContentInfos; +typedef std::vector<ContentGroup> ContentGroups; + +const ContentInfo* FindContentInfoByName( + const ContentInfos& contents, const std::string& name); +const ContentInfo* FindContentInfoByType( + const ContentInfos& contents, const std::string& type); + +// Describes a collection of contents, each with its own name and +// type. Analogous to a <jingle> or <session> stanza. Assumes that +// contents are unique be name, but doesn't enforce that. +class SessionDescription { + public: + SessionDescription() {} + explicit SessionDescription(const ContentInfos& contents) : + contents_(contents) {} + SessionDescription(const ContentInfos& contents, + const ContentGroups& groups) : + contents_(contents), + content_groups_(groups) {} + SessionDescription(const ContentInfos& contents, + const TransportInfos& transports, + const ContentGroups& groups) : + contents_(contents), + transport_infos_(transports), + content_groups_(groups) {} + ~SessionDescription() { + for (ContentInfos::iterator content = contents_.begin(); + content != contents_.end(); ++content) { + delete content->description; + } + } + + SessionDescription* Copy() const; + + // Content accessors. + const ContentInfos& contents() const { return contents_; } + ContentInfos& contents() { return contents_; } + const ContentInfo* GetContentByName(const std::string& name) const; + ContentInfo* GetContentByName(const std::string& name); + const ContentDescription* GetContentDescriptionByName( + const std::string& name) const; + ContentDescription* GetContentDescriptionByName(const std::string& name); + const ContentInfo* FirstContentByType(const std::string& type) const; + const ContentInfo* FirstContent() const; + + // Content mutators. + // Adds a content to this description. Takes ownership of ContentDescription*. + void AddContent(const std::string& name, + const std::string& type, + ContentDescription* description); + void AddContent(const std::string& name, + const std::string& type, + bool rejected, + ContentDescription* description); + bool RemoveContentByName(const std::string& name); + + // Transport accessors. + const TransportInfos& transport_infos() const { return transport_infos_; } + TransportInfos& transport_infos() { return transport_infos_; } + const TransportInfo* GetTransportInfoByName( + const std::string& name) const; + TransportInfo* GetTransportInfoByName(const std::string& name); + const TransportDescription* GetTransportDescriptionByName( + const std::string& name) const { + const TransportInfo* tinfo = GetTransportInfoByName(name); + return tinfo ? &tinfo->description : NULL; + } + + // Transport mutators. + void set_transport_infos(const TransportInfos& transport_infos) { + transport_infos_ = transport_infos; + } + // Adds a TransportInfo to this description. + // Returns false if a TransportInfo with the same name already exists. + bool AddTransportInfo(const TransportInfo& transport_info); + bool RemoveTransportInfoByName(const std::string& name); + + // Group accessors. + const ContentGroups& groups() const { return content_groups_; } + const ContentGroup* GetGroupByName(const std::string& name) const; + bool HasGroup(const std::string& name) const; + + // Group mutators. + void AddGroup(const ContentGroup& group) { content_groups_.push_back(group); } + // Remove the first group with the same semantics specified by |name|. + void RemoveGroupByName(const std::string& name); + + // Global attributes. + void set_msid_supported(bool supported) { msid_supported_ = supported; } + bool msid_supported() const { return msid_supported_; } + + private: + ContentInfos contents_; + TransportInfos transport_infos_; + ContentGroups content_groups_; + bool msid_supported_ = true; +}; + +// Indicates whether a ContentDescription was an offer or an answer, as +// described in http://www.ietf.org/rfc/rfc3264.txt. CA_UPDATE +// indicates a jingle update message which contains a subset of a full +// session description +enum ContentAction { + CA_OFFER, CA_PRANSWER, CA_ANSWER, CA_UPDATE +}; + +// Indicates whether a ContentDescription was sent by the local client +// or received from the remote client. +enum ContentSource { + CS_LOCAL, CS_REMOTE +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_SESSIONDESCRIPTION_H_ diff --git a/webrtc/p2p/base/sessionid.h b/webrtc/p2p/base/sessionid.h new file mode 100644 index 0000000000..f69570039b --- /dev/null +++ b/webrtc/p2p/base/sessionid.h @@ -0,0 +1,20 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_SESSIONID_H_ +#define WEBRTC_P2P_BASE_SESSIONID_H_ + +// TODO: Remove this file. + +namespace cricket { + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_SESSIONID_H_ diff --git a/webrtc/p2p/base/stun.cc b/webrtc/p2p/base/stun.cc new file mode 100644 index 0000000000..9c22995755 --- /dev/null +++ b/webrtc/p2p/base/stun.cc @@ -0,0 +1,918 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/stun.h" + +#include <string.h> + +#include "webrtc/base/byteorder.h" +#include "webrtc/base/common.h" +#include "webrtc/base/crc32.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/messagedigest.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/stringencode.h" + +using rtc::ByteBuffer; + +namespace cricket { + +const char STUN_ERROR_REASON_TRY_ALTERNATE_SERVER[] = "Try Alternate Server"; +const char STUN_ERROR_REASON_BAD_REQUEST[] = "Bad Request"; +const char STUN_ERROR_REASON_UNAUTHORIZED[] = "Unauthorized"; +const char STUN_ERROR_REASON_FORBIDDEN[] = "Forbidden"; +const char STUN_ERROR_REASON_STALE_CREDENTIALS[] = "Stale Credentials"; +const char STUN_ERROR_REASON_ALLOCATION_MISMATCH[] = "Allocation Mismatch"; +const char STUN_ERROR_REASON_STALE_NONCE[] = "Stale Nonce"; +const char STUN_ERROR_REASON_WRONG_CREDENTIALS[] = "Wrong Credentials"; +const char STUN_ERROR_REASON_UNSUPPORTED_PROTOCOL[] = "Unsupported Protocol"; +const char STUN_ERROR_REASON_ROLE_CONFLICT[] = "Role Conflict"; +const char STUN_ERROR_REASON_SERVER_ERROR[] = "Server Error"; + +const char TURN_MAGIC_COOKIE_VALUE[] = { '\x72', '\xC6', '\x4B', '\xC6' }; +const char EMPTY_TRANSACTION_ID[] = "0000000000000000"; +const uint32_t STUN_FINGERPRINT_XOR_VALUE = 0x5354554E; + +// StunMessage + +StunMessage::StunMessage() + : type_(0), + length_(0), + transaction_id_(EMPTY_TRANSACTION_ID) { + ASSERT(IsValidTransactionId(transaction_id_)); + attrs_ = new std::vector<StunAttribute*>(); +} + +StunMessage::~StunMessage() { + for (size_t i = 0; i < attrs_->size(); i++) + delete (*attrs_)[i]; + delete attrs_; +} + +bool StunMessage::IsLegacy() const { + if (transaction_id_.size() == kStunLegacyTransactionIdLength) + return true; + ASSERT(transaction_id_.size() == kStunTransactionIdLength); + return false; +} + +bool StunMessage::SetTransactionID(const std::string& str) { + if (!IsValidTransactionId(str)) { + return false; + } + transaction_id_ = str; + return true; +} + +bool StunMessage::AddAttribute(StunAttribute* attr) { + // Fail any attributes that aren't valid for this type of message. + if (attr->value_type() != GetAttributeValueType(attr->type())) { + return false; + } + attrs_->push_back(attr); + attr->SetOwner(this); + size_t attr_length = attr->length(); + if (attr_length % 4 != 0) { + attr_length += (4 - (attr_length % 4)); + } + length_ += static_cast<uint16_t>(attr_length + 4); + return true; +} + +const StunAddressAttribute* StunMessage::GetAddress(int type) const { + switch (type) { + case STUN_ATTR_MAPPED_ADDRESS: { + // Return XOR-MAPPED-ADDRESS when MAPPED-ADDRESS attribute is + // missing. + const StunAttribute* mapped_address = + GetAttribute(STUN_ATTR_MAPPED_ADDRESS); + if (!mapped_address) + mapped_address = GetAttribute(STUN_ATTR_XOR_MAPPED_ADDRESS); + return reinterpret_cast<const StunAddressAttribute*>(mapped_address); + } + + default: + return static_cast<const StunAddressAttribute*>(GetAttribute(type)); + } +} + +const StunUInt32Attribute* StunMessage::GetUInt32(int type) const { + return static_cast<const StunUInt32Attribute*>(GetAttribute(type)); +} + +const StunUInt64Attribute* StunMessage::GetUInt64(int type) const { + return static_cast<const StunUInt64Attribute*>(GetAttribute(type)); +} + +const StunByteStringAttribute* StunMessage::GetByteString(int type) const { + return static_cast<const StunByteStringAttribute*>(GetAttribute(type)); +} + +const StunErrorCodeAttribute* StunMessage::GetErrorCode() const { + return static_cast<const StunErrorCodeAttribute*>( + GetAttribute(STUN_ATTR_ERROR_CODE)); +} + +const StunUInt16ListAttribute* StunMessage::GetUnknownAttributes() const { + return static_cast<const StunUInt16ListAttribute*>( + GetAttribute(STUN_ATTR_UNKNOWN_ATTRIBUTES)); +} + +// Verifies a STUN message has a valid MESSAGE-INTEGRITY attribute, using the +// procedure outlined in RFC 5389, section 15.4. +bool StunMessage::ValidateMessageIntegrity(const char* data, size_t size, + const std::string& password) { + // Verifying the size of the message. + if ((size % 4) != 0) { + return false; + } + + // Getting the message length from the STUN header. + uint16_t msg_length = rtc::GetBE16(&data[2]); + if (size != (msg_length + kStunHeaderSize)) { + return false; + } + + // Finding Message Integrity attribute in stun message. + size_t current_pos = kStunHeaderSize; + bool has_message_integrity_attr = false; + while (current_pos < size) { + uint16_t attr_type, attr_length; + // Getting attribute type and length. + attr_type = rtc::GetBE16(&data[current_pos]); + attr_length = rtc::GetBE16(&data[current_pos + sizeof(attr_type)]); + + // If M-I, sanity check it, and break out. + if (attr_type == STUN_ATTR_MESSAGE_INTEGRITY) { + if (attr_length != kStunMessageIntegritySize || + current_pos + attr_length > size) { + return false; + } + has_message_integrity_attr = true; + break; + } + + // Otherwise, skip to the next attribute. + current_pos += sizeof(attr_type) + sizeof(attr_length) + attr_length; + if ((attr_length % 4) != 0) { + current_pos += (4 - (attr_length % 4)); + } + } + + if (!has_message_integrity_attr) { + return false; + } + + // Getting length of the message to calculate Message Integrity. + size_t mi_pos = current_pos; + rtc::scoped_ptr<char[]> temp_data(new char[current_pos]); + memcpy(temp_data.get(), data, current_pos); + if (size > mi_pos + kStunAttributeHeaderSize + kStunMessageIntegritySize) { + // Stun message has other attributes after message integrity. + // Adjust the length parameter in stun message to calculate HMAC. + size_t extra_offset = size - + (mi_pos + kStunAttributeHeaderSize + kStunMessageIntegritySize); + size_t new_adjusted_len = size - extra_offset - kStunHeaderSize; + + // Writing new length of the STUN message @ Message Length in temp buffer. + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // |0 0| STUN Message Type | Message Length | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + rtc::SetBE16(temp_data.get() + 2, static_cast<uint16_t>(new_adjusted_len)); + } + + char hmac[kStunMessageIntegritySize]; + size_t ret = rtc::ComputeHmac(rtc::DIGEST_SHA_1, + password.c_str(), password.size(), + temp_data.get(), mi_pos, + hmac, sizeof(hmac)); + ASSERT(ret == sizeof(hmac)); + if (ret != sizeof(hmac)) + return false; + + // Comparing the calculated HMAC with the one present in the message. + return memcmp(data + current_pos + kStunAttributeHeaderSize, + hmac, + sizeof(hmac)) == 0; +} + +bool StunMessage::AddMessageIntegrity(const std::string& password) { + return AddMessageIntegrity(password.c_str(), password.size()); +} + +bool StunMessage::AddMessageIntegrity(const char* key, + size_t keylen) { + // Add the attribute with a dummy value. Since this is a known attribute, it + // can't fail. + StunByteStringAttribute* msg_integrity_attr = + new StunByteStringAttribute(STUN_ATTR_MESSAGE_INTEGRITY, + std::string(kStunMessageIntegritySize, '0')); + VERIFY(AddAttribute(msg_integrity_attr)); + + // Calculate the HMAC for the message. + rtc::ByteBuffer buf; + if (!Write(&buf)) + return false; + + int msg_len_for_hmac = static_cast<int>( + buf.Length() - kStunAttributeHeaderSize - msg_integrity_attr->length()); + char hmac[kStunMessageIntegritySize]; + size_t ret = rtc::ComputeHmac(rtc::DIGEST_SHA_1, + key, keylen, + buf.Data(), msg_len_for_hmac, + hmac, sizeof(hmac)); + ASSERT(ret == sizeof(hmac)); + if (ret != sizeof(hmac)) { + LOG(LS_ERROR) << "HMAC computation failed. Message-Integrity " + << "has dummy value."; + return false; + } + + // Insert correct HMAC into the attribute. + msg_integrity_attr->CopyBytes(hmac, sizeof(hmac)); + return true; +} + +// Verifies a message is in fact a STUN message, by performing the checks +// outlined in RFC 5389, section 7.3, including the FINGERPRINT check detailed +// in section 15.5. +bool StunMessage::ValidateFingerprint(const char* data, size_t size) { + // Check the message length. + size_t fingerprint_attr_size = + kStunAttributeHeaderSize + StunUInt32Attribute::SIZE; + if (size % 4 != 0 || size < kStunHeaderSize + fingerprint_attr_size) + return false; + + // Skip the rest if the magic cookie isn't present. + const char* magic_cookie = + data + kStunTransactionIdOffset - kStunMagicCookieLength; + if (rtc::GetBE32(magic_cookie) != kStunMagicCookie) + return false; + + // Check the fingerprint type and length. + const char* fingerprint_attr_data = data + size - fingerprint_attr_size; + if (rtc::GetBE16(fingerprint_attr_data) != STUN_ATTR_FINGERPRINT || + rtc::GetBE16(fingerprint_attr_data + sizeof(uint16_t)) != + StunUInt32Attribute::SIZE) + return false; + + // Check the fingerprint value. + uint32_t fingerprint = + rtc::GetBE32(fingerprint_attr_data + kStunAttributeHeaderSize); + return ((fingerprint ^ STUN_FINGERPRINT_XOR_VALUE) == + rtc::ComputeCrc32(data, size - fingerprint_attr_size)); +} + +bool StunMessage::AddFingerprint() { + // Add the attribute with a dummy value. Since this is a known attribute, + // it can't fail. + StunUInt32Attribute* fingerprint_attr = + new StunUInt32Attribute(STUN_ATTR_FINGERPRINT, 0); + VERIFY(AddAttribute(fingerprint_attr)); + + // Calculate the CRC-32 for the message and insert it. + rtc::ByteBuffer buf; + if (!Write(&buf)) + return false; + + int msg_len_for_crc32 = static_cast<int>( + buf.Length() - kStunAttributeHeaderSize - fingerprint_attr->length()); + uint32_t c = rtc::ComputeCrc32(buf.Data(), msg_len_for_crc32); + + // Insert the correct CRC-32, XORed with a constant, into the attribute. + fingerprint_attr->SetValue(c ^ STUN_FINGERPRINT_XOR_VALUE); + return true; +} + +bool StunMessage::Read(ByteBuffer* buf) { + if (!buf->ReadUInt16(&type_)) + return false; + + if (type_ & 0x8000) { + // RTP and RTCP set the MSB of first byte, since first two bits are version, + // and version is always 2 (10). If set, this is not a STUN packet. + return false; + } + + if (!buf->ReadUInt16(&length_)) + return false; + + std::string magic_cookie; + if (!buf->ReadString(&magic_cookie, kStunMagicCookieLength)) + return false; + + std::string transaction_id; + if (!buf->ReadString(&transaction_id, kStunTransactionIdLength)) + return false; + + uint32_t magic_cookie_int = + *reinterpret_cast<const uint32_t*>(magic_cookie.data()); + if (rtc::NetworkToHost32(magic_cookie_int) != kStunMagicCookie) { + // If magic cookie is invalid it means that the peer implements + // RFC3489 instead of RFC5389. + transaction_id.insert(0, magic_cookie); + } + ASSERT(IsValidTransactionId(transaction_id)); + transaction_id_ = transaction_id; + + if (length_ != buf->Length()) + return false; + + attrs_->resize(0); + + size_t rest = buf->Length() - length_; + while (buf->Length() > rest) { + uint16_t attr_type, attr_length; + if (!buf->ReadUInt16(&attr_type)) + return false; + if (!buf->ReadUInt16(&attr_length)) + return false; + + StunAttribute* attr = CreateAttribute(attr_type, attr_length); + if (!attr) { + // Skip any unknown or malformed attributes. + if ((attr_length % 4) != 0) { + attr_length += (4 - (attr_length % 4)); + } + if (!buf->Consume(attr_length)) + return false; + } else { + if (!attr->Read(buf)) + return false; + attrs_->push_back(attr); + } + } + + ASSERT(buf->Length() == rest); + return true; +} + +bool StunMessage::Write(ByteBuffer* buf) const { + buf->WriteUInt16(type_); + buf->WriteUInt16(length_); + if (!IsLegacy()) + buf->WriteUInt32(kStunMagicCookie); + buf->WriteString(transaction_id_); + + for (size_t i = 0; i < attrs_->size(); ++i) { + buf->WriteUInt16((*attrs_)[i]->type()); + buf->WriteUInt16(static_cast<uint16_t>((*attrs_)[i]->length())); + if (!(*attrs_)[i]->Write(buf)) + return false; + } + + return true; +} + +StunAttributeValueType StunMessage::GetAttributeValueType(int type) const { + switch (type) { + case STUN_ATTR_MAPPED_ADDRESS: return STUN_VALUE_ADDRESS; + case STUN_ATTR_USERNAME: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_MESSAGE_INTEGRITY: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_ERROR_CODE: return STUN_VALUE_ERROR_CODE; + case STUN_ATTR_UNKNOWN_ATTRIBUTES: return STUN_VALUE_UINT16_LIST; + case STUN_ATTR_REALM: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_NONCE: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_XOR_MAPPED_ADDRESS: return STUN_VALUE_XOR_ADDRESS; + case STUN_ATTR_SOFTWARE: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_ALTERNATE_SERVER: return STUN_VALUE_ADDRESS; + case STUN_ATTR_FINGERPRINT: return STUN_VALUE_UINT32; + case STUN_ATTR_ORIGIN: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_RETRANSMIT_COUNT: return STUN_VALUE_UINT32; + default: return STUN_VALUE_UNKNOWN; + } +} + +StunAttribute* StunMessage::CreateAttribute(int type, size_t length) /*const*/ { + StunAttributeValueType value_type = GetAttributeValueType(type); + return StunAttribute::Create(value_type, type, static_cast<uint16_t>(length), + this); +} + +const StunAttribute* StunMessage::GetAttribute(int type) const { + for (size_t i = 0; i < attrs_->size(); ++i) { + if ((*attrs_)[i]->type() == type) + return (*attrs_)[i]; + } + return NULL; +} + +bool StunMessage::IsValidTransactionId(const std::string& transaction_id) { + return transaction_id.size() == kStunTransactionIdLength || + transaction_id.size() == kStunLegacyTransactionIdLength; +} + +// StunAttribute + +StunAttribute::StunAttribute(uint16_t type, uint16_t length) + : type_(type), length_(length) { +} + +void StunAttribute::ConsumePadding(rtc::ByteBuffer* buf) const { + int remainder = length_ % 4; + if (remainder > 0) { + buf->Consume(4 - remainder); + } +} + +void StunAttribute::WritePadding(rtc::ByteBuffer* buf) const { + int remainder = length_ % 4; + if (remainder > 0) { + char zeroes[4] = {0}; + buf->WriteBytes(zeroes, 4 - remainder); + } +} + +StunAttribute* StunAttribute::Create(StunAttributeValueType value_type, + uint16_t type, + uint16_t length, + StunMessage* owner) { + switch (value_type) { + case STUN_VALUE_ADDRESS: + return new StunAddressAttribute(type, length); + case STUN_VALUE_XOR_ADDRESS: + return new StunXorAddressAttribute(type, length, owner); + case STUN_VALUE_UINT32: + return new StunUInt32Attribute(type); + case STUN_VALUE_UINT64: + return new StunUInt64Attribute(type); + case STUN_VALUE_BYTE_STRING: + return new StunByteStringAttribute(type, length); + case STUN_VALUE_ERROR_CODE: + return new StunErrorCodeAttribute(type, length); + case STUN_VALUE_UINT16_LIST: + return new StunUInt16ListAttribute(type, length); + default: + return NULL; + } +} + +StunAddressAttribute* StunAttribute::CreateAddress(uint16_t type) { + return new StunAddressAttribute(type, 0); +} + +StunXorAddressAttribute* StunAttribute::CreateXorAddress(uint16_t type) { + return new StunXorAddressAttribute(type, 0, NULL); +} + +StunUInt64Attribute* StunAttribute::CreateUInt64(uint16_t type) { + return new StunUInt64Attribute(type); +} + +StunUInt32Attribute* StunAttribute::CreateUInt32(uint16_t type) { + return new StunUInt32Attribute(type); +} + +StunByteStringAttribute* StunAttribute::CreateByteString(uint16_t type) { + return new StunByteStringAttribute(type, 0); +} + +StunErrorCodeAttribute* StunAttribute::CreateErrorCode() { + return new StunErrorCodeAttribute( + STUN_ATTR_ERROR_CODE, StunErrorCodeAttribute::MIN_SIZE); +} + +StunUInt16ListAttribute* StunAttribute::CreateUnknownAttributes() { + return new StunUInt16ListAttribute(STUN_ATTR_UNKNOWN_ATTRIBUTES, 0); +} + +StunAddressAttribute::StunAddressAttribute(uint16_t type, + const rtc::SocketAddress& addr) + : StunAttribute(type, 0) { + SetAddress(addr); +} + +StunAddressAttribute::StunAddressAttribute(uint16_t type, uint16_t length) + : StunAttribute(type, length) { +} + +bool StunAddressAttribute::Read(ByteBuffer* buf) { + uint8_t dummy; + if (!buf->ReadUInt8(&dummy)) + return false; + + uint8_t stun_family; + if (!buf->ReadUInt8(&stun_family)) { + return false; + } + uint16_t port; + if (!buf->ReadUInt16(&port)) + return false; + if (stun_family == STUN_ADDRESS_IPV4) { + in_addr v4addr; + if (length() != SIZE_IP4) { + return false; + } + if (!buf->ReadBytes(reinterpret_cast<char*>(&v4addr), sizeof(v4addr))) { + return false; + } + rtc::IPAddress ipaddr(v4addr); + SetAddress(rtc::SocketAddress(ipaddr, port)); + } else if (stun_family == STUN_ADDRESS_IPV6) { + in6_addr v6addr; + if (length() != SIZE_IP6) { + return false; + } + if (!buf->ReadBytes(reinterpret_cast<char*>(&v6addr), sizeof(v6addr))) { + return false; + } + rtc::IPAddress ipaddr(v6addr); + SetAddress(rtc::SocketAddress(ipaddr, port)); + } else { + return false; + } + return true; +} + +bool StunAddressAttribute::Write(ByteBuffer* buf) const { + StunAddressFamily address_family = family(); + if (address_family == STUN_ADDRESS_UNDEF) { + LOG(LS_ERROR) << "Error writing address attribute: unknown family."; + return false; + } + buf->WriteUInt8(0); + buf->WriteUInt8(address_family); + buf->WriteUInt16(address_.port()); + switch (address_.family()) { + case AF_INET: { + in_addr v4addr = address_.ipaddr().ipv4_address(); + buf->WriteBytes(reinterpret_cast<char*>(&v4addr), sizeof(v4addr)); + break; + } + case AF_INET6: { + in6_addr v6addr = address_.ipaddr().ipv6_address(); + buf->WriteBytes(reinterpret_cast<char*>(&v6addr), sizeof(v6addr)); + break; + } + } + return true; +} + +StunXorAddressAttribute::StunXorAddressAttribute(uint16_t type, + const rtc::SocketAddress& addr) + : StunAddressAttribute(type, addr), owner_(NULL) { +} + +StunXorAddressAttribute::StunXorAddressAttribute(uint16_t type, + uint16_t length, + StunMessage* owner) + : StunAddressAttribute(type, length), owner_(owner) { +} + +rtc::IPAddress StunXorAddressAttribute::GetXoredIP() const { + if (owner_) { + rtc::IPAddress ip = ipaddr(); + switch (ip.family()) { + case AF_INET: { + in_addr v4addr = ip.ipv4_address(); + v4addr.s_addr = + (v4addr.s_addr ^ rtc::HostToNetwork32(kStunMagicCookie)); + return rtc::IPAddress(v4addr); + } + case AF_INET6: { + in6_addr v6addr = ip.ipv6_address(); + const std::string& transaction_id = owner_->transaction_id(); + if (transaction_id.length() == kStunTransactionIdLength) { + uint32_t transactionid_as_ints[3]; + memcpy(&transactionid_as_ints[0], transaction_id.c_str(), + transaction_id.length()); + uint32_t* ip_as_ints = reinterpret_cast<uint32_t*>(&v6addr.s6_addr); + // Transaction ID is in network byte order, but magic cookie + // is stored in host byte order. + ip_as_ints[0] = + (ip_as_ints[0] ^ rtc::HostToNetwork32(kStunMagicCookie)); + ip_as_ints[1] = (ip_as_ints[1] ^ transactionid_as_ints[0]); + ip_as_ints[2] = (ip_as_ints[2] ^ transactionid_as_ints[1]); + ip_as_ints[3] = (ip_as_ints[3] ^ transactionid_as_ints[2]); + return rtc::IPAddress(v6addr); + } + break; + } + } + } + // Invalid ip family or transaction ID, or missing owner. + // Return an AF_UNSPEC address. + return rtc::IPAddress(); +} + +bool StunXorAddressAttribute::Read(ByteBuffer* buf) { + if (!StunAddressAttribute::Read(buf)) + return false; + uint16_t xoredport = port() ^ (kStunMagicCookie >> 16); + rtc::IPAddress xored_ip = GetXoredIP(); + SetAddress(rtc::SocketAddress(xored_ip, xoredport)); + return true; +} + +bool StunXorAddressAttribute::Write(ByteBuffer* buf) const { + StunAddressFamily address_family = family(); + if (address_family == STUN_ADDRESS_UNDEF) { + LOG(LS_ERROR) << "Error writing xor-address attribute: unknown family."; + return false; + } + rtc::IPAddress xored_ip = GetXoredIP(); + if (xored_ip.family() == AF_UNSPEC) { + return false; + } + buf->WriteUInt8(0); + buf->WriteUInt8(family()); + buf->WriteUInt16(port() ^ (kStunMagicCookie >> 16)); + switch (xored_ip.family()) { + case AF_INET: { + in_addr v4addr = xored_ip.ipv4_address(); + buf->WriteBytes(reinterpret_cast<const char*>(&v4addr), sizeof(v4addr)); + break; + } + case AF_INET6: { + in6_addr v6addr = xored_ip.ipv6_address(); + buf->WriteBytes(reinterpret_cast<const char*>(&v6addr), sizeof(v6addr)); + break; + } + } + return true; +} + +StunUInt32Attribute::StunUInt32Attribute(uint16_t type, uint32_t value) + : StunAttribute(type, SIZE), bits_(value) { +} + +StunUInt32Attribute::StunUInt32Attribute(uint16_t type) + : StunAttribute(type, SIZE), bits_(0) { +} + +bool StunUInt32Attribute::GetBit(size_t index) const { + ASSERT(index < 32); + return static_cast<bool>((bits_ >> index) & 0x1); +} + +void StunUInt32Attribute::SetBit(size_t index, bool value) { + ASSERT(index < 32); + bits_ &= ~(1 << index); + bits_ |= value ? (1 << index) : 0; +} + +bool StunUInt32Attribute::Read(ByteBuffer* buf) { + if (length() != SIZE || !buf->ReadUInt32(&bits_)) + return false; + return true; +} + +bool StunUInt32Attribute::Write(ByteBuffer* buf) const { + buf->WriteUInt32(bits_); + return true; +} + +StunUInt64Attribute::StunUInt64Attribute(uint16_t type, uint64_t value) + : StunAttribute(type, SIZE), bits_(value) { +} + +StunUInt64Attribute::StunUInt64Attribute(uint16_t type) + : StunAttribute(type, SIZE), bits_(0) { +} + +bool StunUInt64Attribute::Read(ByteBuffer* buf) { + if (length() != SIZE || !buf->ReadUInt64(&bits_)) + return false; + return true; +} + +bool StunUInt64Attribute::Write(ByteBuffer* buf) const { + buf->WriteUInt64(bits_); + return true; +} + +StunByteStringAttribute::StunByteStringAttribute(uint16_t type) + : StunAttribute(type, 0), bytes_(NULL) { +} + +StunByteStringAttribute::StunByteStringAttribute(uint16_t type, + const std::string& str) + : StunAttribute(type, 0), bytes_(NULL) { + CopyBytes(str.c_str(), str.size()); +} + +StunByteStringAttribute::StunByteStringAttribute(uint16_t type, + const void* bytes, + size_t length) + : StunAttribute(type, 0), bytes_(NULL) { + CopyBytes(bytes, length); +} + +StunByteStringAttribute::StunByteStringAttribute(uint16_t type, uint16_t length) + : StunAttribute(type, length), bytes_(NULL) { +} + +StunByteStringAttribute::~StunByteStringAttribute() { + delete [] bytes_; +} + +void StunByteStringAttribute::CopyBytes(const char* bytes) { + CopyBytes(bytes, strlen(bytes)); +} + +void StunByteStringAttribute::CopyBytes(const void* bytes, size_t length) { + char* new_bytes = new char[length]; + memcpy(new_bytes, bytes, length); + SetBytes(new_bytes, length); +} + +uint8_t StunByteStringAttribute::GetByte(size_t index) const { + ASSERT(bytes_ != NULL); + ASSERT(index < length()); + return static_cast<uint8_t>(bytes_[index]); +} + +void StunByteStringAttribute::SetByte(size_t index, uint8_t value) { + ASSERT(bytes_ != NULL); + ASSERT(index < length()); + bytes_[index] = value; +} + +bool StunByteStringAttribute::Read(ByteBuffer* buf) { + bytes_ = new char[length()]; + if (!buf->ReadBytes(bytes_, length())) { + return false; + } + + ConsumePadding(buf); + return true; +} + +bool StunByteStringAttribute::Write(ByteBuffer* buf) const { + buf->WriteBytes(bytes_, length()); + WritePadding(buf); + return true; +} + +void StunByteStringAttribute::SetBytes(char* bytes, size_t length) { + delete [] bytes_; + bytes_ = bytes; + SetLength(static_cast<uint16_t>(length)); +} + +StunErrorCodeAttribute::StunErrorCodeAttribute(uint16_t type, + int code, + const std::string& reason) + : StunAttribute(type, 0) { + SetCode(code); + SetReason(reason); +} + +StunErrorCodeAttribute::StunErrorCodeAttribute(uint16_t type, uint16_t length) + : StunAttribute(type, length), class_(0), number_(0) { +} + +StunErrorCodeAttribute::~StunErrorCodeAttribute() { +} + +int StunErrorCodeAttribute::code() const { + return class_ * 100 + number_; +} + +void StunErrorCodeAttribute::SetCode(int code) { + class_ = static_cast<uint8_t>(code / 100); + number_ = static_cast<uint8_t>(code % 100); +} + +void StunErrorCodeAttribute::SetReason(const std::string& reason) { + SetLength(MIN_SIZE + static_cast<uint16_t>(reason.size())); + reason_ = reason; +} + +bool StunErrorCodeAttribute::Read(ByteBuffer* buf) { + uint32_t val; + if (length() < MIN_SIZE || !buf->ReadUInt32(&val)) + return false; + + if ((val >> 11) != 0) + LOG(LS_ERROR) << "error-code bits not zero"; + + class_ = ((val >> 8) & 0x7); + number_ = (val & 0xff); + + if (!buf->ReadString(&reason_, length() - 4)) + return false; + + ConsumePadding(buf); + return true; +} + +bool StunErrorCodeAttribute::Write(ByteBuffer* buf) const { + buf->WriteUInt32(class_ << 8 | number_); + buf->WriteString(reason_); + WritePadding(buf); + return true; +} + +StunUInt16ListAttribute::StunUInt16ListAttribute(uint16_t type, uint16_t length) + : StunAttribute(type, length) { + attr_types_ = new std::vector<uint16_t>(); +} + +StunUInt16ListAttribute::~StunUInt16ListAttribute() { + delete attr_types_; +} + +size_t StunUInt16ListAttribute::Size() const { + return attr_types_->size(); +} + +uint16_t StunUInt16ListAttribute::GetType(int index) const { + return (*attr_types_)[index]; +} + +void StunUInt16ListAttribute::SetType(int index, uint16_t value) { + (*attr_types_)[index] = value; +} + +void StunUInt16ListAttribute::AddType(uint16_t value) { + attr_types_->push_back(value); + SetLength(static_cast<uint16_t>(attr_types_->size() * 2)); +} + +bool StunUInt16ListAttribute::Read(ByteBuffer* buf) { + if (length() % 2) + return false; + + for (size_t i = 0; i < length() / 2; i++) { + uint16_t attr; + if (!buf->ReadUInt16(&attr)) + return false; + attr_types_->push_back(attr); + } + // Padding of these attributes is done in RFC 5389 style. This is + // slightly different from RFC3489, but it shouldn't be important. + // RFC3489 pads out to a 32 bit boundary by duplicating one of the + // entries in the list (not necessarily the last one - it's unspecified). + // RFC5389 pads on the end, and the bytes are always ignored. + ConsumePadding(buf); + return true; +} + +bool StunUInt16ListAttribute::Write(ByteBuffer* buf) const { + for (size_t i = 0; i < attr_types_->size(); ++i) { + buf->WriteUInt16((*attr_types_)[i]); + } + WritePadding(buf); + return true; +} + +int GetStunSuccessResponseType(int req_type) { + return IsStunRequestType(req_type) ? (req_type | 0x100) : -1; +} + +int GetStunErrorResponseType(int req_type) { + return IsStunRequestType(req_type) ? (req_type | 0x110) : -1; +} + +bool IsStunRequestType(int msg_type) { + return ((msg_type & kStunTypeMask) == 0x000); +} + +bool IsStunIndicationType(int msg_type) { + return ((msg_type & kStunTypeMask) == 0x010); +} + +bool IsStunSuccessResponseType(int msg_type) { + return ((msg_type & kStunTypeMask) == 0x100); +} + +bool IsStunErrorResponseType(int msg_type) { + return ((msg_type & kStunTypeMask) == 0x110); +} + +bool ComputeStunCredentialHash(const std::string& username, + const std::string& realm, + const std::string& password, + std::string* hash) { + // http://tools.ietf.org/html/rfc5389#section-15.4 + // long-term credentials will be calculated using the key and key is + // key = MD5(username ":" realm ":" SASLprep(password)) + std::string input = username; + input += ':'; + input += realm; + input += ':'; + input += password; + + char digest[rtc::MessageDigest::kMaxSize]; + size_t size = rtc::ComputeDigest( + rtc::DIGEST_MD5, input.c_str(), input.size(), + digest, sizeof(digest)); + if (size == 0) { + return false; + } + + *hash = std::string(digest, size); + return true; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/stun.h b/webrtc/p2p/base/stun.h new file mode 100644 index 0000000000..75b89afb8a --- /dev/null +++ b/webrtc/p2p/base/stun.h @@ -0,0 +1,634 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_STUN_H_ +#define WEBRTC_P2P_BASE_STUN_H_ + +// This file contains classes for dealing with the STUN protocol, as specified +// in RFC 5389, and its descendants. + +#include <string> +#include <vector> + +#include "webrtc/base/basictypes.h" +#include "webrtc/base/bytebuffer.h" +#include "webrtc/base/socketaddress.h" + +namespace cricket { + +// These are the types of STUN messages defined in RFC 5389. +enum StunMessageType { + STUN_BINDING_REQUEST = 0x0001, + STUN_BINDING_INDICATION = 0x0011, + STUN_BINDING_RESPONSE = 0x0101, + STUN_BINDING_ERROR_RESPONSE = 0x0111, +}; + +// These are all known STUN attributes, defined in RFC 5389 and elsewhere. +// Next to each is the name of the class (T is StunTAttribute) that implements +// that type. +// RETRANSMIT_COUNT is the number of outstanding pings without a response at +// the time the packet is generated. +enum StunAttributeType { + STUN_ATTR_MAPPED_ADDRESS = 0x0001, // Address + STUN_ATTR_USERNAME = 0x0006, // ByteString + STUN_ATTR_MESSAGE_INTEGRITY = 0x0008, // ByteString, 20 bytes + STUN_ATTR_ERROR_CODE = 0x0009, // ErrorCode + STUN_ATTR_UNKNOWN_ATTRIBUTES = 0x000a, // UInt16List + STUN_ATTR_REALM = 0x0014, // ByteString + STUN_ATTR_NONCE = 0x0015, // ByteString + STUN_ATTR_XOR_MAPPED_ADDRESS = 0x0020, // XorAddress + STUN_ATTR_SOFTWARE = 0x8022, // ByteString + STUN_ATTR_ALTERNATE_SERVER = 0x8023, // Address + STUN_ATTR_FINGERPRINT = 0x8028, // UInt32 + STUN_ATTR_ORIGIN = 0x802F, // ByteString + STUN_ATTR_RETRANSMIT_COUNT = 0xFF00 // UInt32 +}; + +// These are the types of the values associated with the attributes above. +// This allows us to perform some basic validation when reading or adding +// attributes. Note that these values are for our own use, and not defined in +// RFC 5389. +enum StunAttributeValueType { + STUN_VALUE_UNKNOWN = 0, + STUN_VALUE_ADDRESS = 1, + STUN_VALUE_XOR_ADDRESS = 2, + STUN_VALUE_UINT32 = 3, + STUN_VALUE_UINT64 = 4, + STUN_VALUE_BYTE_STRING = 5, + STUN_VALUE_ERROR_CODE = 6, + STUN_VALUE_UINT16_LIST = 7 +}; + +// These are the types of STUN addresses defined in RFC 5389. +enum StunAddressFamily { + // NB: UNDEF is not part of the STUN spec. + STUN_ADDRESS_UNDEF = 0, + STUN_ADDRESS_IPV4 = 1, + STUN_ADDRESS_IPV6 = 2 +}; + +// These are the types of STUN error codes defined in RFC 5389. +enum StunErrorCode { + STUN_ERROR_TRY_ALTERNATE = 300, + STUN_ERROR_BAD_REQUEST = 400, + STUN_ERROR_UNAUTHORIZED = 401, + STUN_ERROR_UNKNOWN_ATTRIBUTE = 420, + STUN_ERROR_STALE_CREDENTIALS = 430, // GICE only + STUN_ERROR_STALE_NONCE = 438, + STUN_ERROR_SERVER_ERROR = 500, + STUN_ERROR_GLOBAL_FAILURE = 600 +}; + +// Strings for the error codes above. +extern const char STUN_ERROR_REASON_TRY_ALTERNATE_SERVER[]; +extern const char STUN_ERROR_REASON_BAD_REQUEST[]; +extern const char STUN_ERROR_REASON_UNAUTHORIZED[]; +extern const char STUN_ERROR_REASON_UNKNOWN_ATTRIBUTE[]; +extern const char STUN_ERROR_REASON_STALE_CREDENTIALS[]; +extern const char STUN_ERROR_REASON_STALE_NONCE[]; +extern const char STUN_ERROR_REASON_SERVER_ERROR[]; + +// The mask used to determine whether a STUN message is a request/response etc. +const uint32_t kStunTypeMask = 0x0110; + +// STUN Attribute header length. +const size_t kStunAttributeHeaderSize = 4; + +// Following values correspond to RFC5389. +const size_t kStunHeaderSize = 20; +const size_t kStunTransactionIdOffset = 8; +const size_t kStunTransactionIdLength = 12; +const uint32_t kStunMagicCookie = 0x2112A442; +const size_t kStunMagicCookieLength = sizeof(kStunMagicCookie); + +// Following value corresponds to an earlier version of STUN from +// RFC3489. +const size_t kStunLegacyTransactionIdLength = 16; + +// STUN Message Integrity HMAC length. +const size_t kStunMessageIntegritySize = 20; + +class StunAttribute; +class StunAddressAttribute; +class StunXorAddressAttribute; +class StunUInt32Attribute; +class StunUInt64Attribute; +class StunByteStringAttribute; +class StunErrorCodeAttribute; +class StunUInt16ListAttribute; + +// Records a complete STUN/TURN message. Each message consists of a type and +// any number of attributes. Each attribute is parsed into an instance of an +// appropriate class (see above). The Get* methods will return instances of +// that attribute class. +class StunMessage { + public: + StunMessage(); + virtual ~StunMessage(); + + int type() const { return type_; } + size_t length() const { return length_; } + const std::string& transaction_id() const { return transaction_id_; } + + // Returns true if the message confirms to RFC3489 rather than + // RFC5389. The main difference between two version of the STUN + // protocol is the presence of the magic cookie and different length + // of transaction ID. For outgoing packets version of the protocol + // is determined by the lengths of the transaction ID. + bool IsLegacy() const; + + void SetType(int type) { type_ = static_cast<uint16_t>(type); } + bool SetTransactionID(const std::string& str); + + // Gets the desired attribute value, or NULL if no such attribute type exists. + const StunAddressAttribute* GetAddress(int type) const; + const StunUInt32Attribute* GetUInt32(int type) const; + const StunUInt64Attribute* GetUInt64(int type) const; + const StunByteStringAttribute* GetByteString(int type) const; + + // Gets these specific attribute values. + const StunErrorCodeAttribute* GetErrorCode() const; + const StunUInt16ListAttribute* GetUnknownAttributes() const; + + // Takes ownership of the specified attribute, verifies it is of the correct + // type, and adds it to the message. The return value indicates whether this + // was successful. + bool AddAttribute(StunAttribute* attr); + + // Validates that a raw STUN message has a correct MESSAGE-INTEGRITY value. + // This can't currently be done on a StunMessage, since it is affected by + // padding data (which we discard when reading a StunMessage). + static bool ValidateMessageIntegrity(const char* data, size_t size, + const std::string& password); + // Adds a MESSAGE-INTEGRITY attribute that is valid for the current message. + bool AddMessageIntegrity(const std::string& password); + bool AddMessageIntegrity(const char* key, size_t keylen); + + // Verifies that a given buffer is STUN by checking for a correct FINGERPRINT. + static bool ValidateFingerprint(const char* data, size_t size); + + // Adds a FINGERPRINT attribute that is valid for the current message. + bool AddFingerprint(); + + // Parses the STUN packet in the given buffer and records it here. The + // return value indicates whether this was successful. + bool Read(rtc::ByteBuffer* buf); + + // Writes this object into a STUN packet. The return value indicates whether + // this was successful. + bool Write(rtc::ByteBuffer* buf) const; + + // Creates an empty message. Overridable by derived classes. + virtual StunMessage* CreateNew() const { return new StunMessage(); } + + protected: + // Verifies that the given attribute is allowed for this message. + virtual StunAttributeValueType GetAttributeValueType(int type) const; + + private: + StunAttribute* CreateAttribute(int type, size_t length) /* const*/; + const StunAttribute* GetAttribute(int type) const; + static bool IsValidTransactionId(const std::string& transaction_id); + + uint16_t type_; + uint16_t length_; + std::string transaction_id_; + std::vector<StunAttribute*>* attrs_; +}; + +// Base class for all STUN/TURN attributes. +class StunAttribute { + public: + virtual ~StunAttribute() { + } + + int type() const { return type_; } + size_t length() const { return length_; } + + // Return the type of this attribute. + virtual StunAttributeValueType value_type() const = 0; + + // Only XorAddressAttribute needs this so far. + virtual void SetOwner(StunMessage* owner) {} + + // Reads the body (not the type or length) for this type of attribute from + // the given buffer. Return value is true if successful. + virtual bool Read(rtc::ByteBuffer* buf) = 0; + + // Writes the body (not the type or length) to the given buffer. Return + // value is true if successful. + virtual bool Write(rtc::ByteBuffer* buf) const = 0; + + // Creates an attribute object with the given type and smallest length. + static StunAttribute* Create(StunAttributeValueType value_type, + uint16_t type, + uint16_t length, + StunMessage* owner); + // TODO: Allow these create functions to take parameters, to reduce + // the amount of work callers need to do to initialize attributes. + static StunAddressAttribute* CreateAddress(uint16_t type); + static StunXorAddressAttribute* CreateXorAddress(uint16_t type); + static StunUInt32Attribute* CreateUInt32(uint16_t type); + static StunUInt64Attribute* CreateUInt64(uint16_t type); + static StunByteStringAttribute* CreateByteString(uint16_t type); + static StunErrorCodeAttribute* CreateErrorCode(); + static StunUInt16ListAttribute* CreateUnknownAttributes(); + + protected: + StunAttribute(uint16_t type, uint16_t length); + void SetLength(uint16_t length) { length_ = length; } + void WritePadding(rtc::ByteBuffer* buf) const; + void ConsumePadding(rtc::ByteBuffer* buf) const; + + private: + uint16_t type_; + uint16_t length_; +}; + +// Implements STUN attributes that record an Internet address. +class StunAddressAttribute : public StunAttribute { + public: + static const uint16_t SIZE_UNDEF = 0; + static const uint16_t SIZE_IP4 = 8; + static const uint16_t SIZE_IP6 = 20; + StunAddressAttribute(uint16_t type, const rtc::SocketAddress& addr); + StunAddressAttribute(uint16_t type, uint16_t length); + + virtual StunAttributeValueType value_type() const { + return STUN_VALUE_ADDRESS; + } + + StunAddressFamily family() const { + switch (address_.ipaddr().family()) { + case AF_INET: + return STUN_ADDRESS_IPV4; + case AF_INET6: + return STUN_ADDRESS_IPV6; + } + return STUN_ADDRESS_UNDEF; + } + + const rtc::SocketAddress& GetAddress() const { return address_; } + const rtc::IPAddress& ipaddr() const { return address_.ipaddr(); } + uint16_t port() const { return address_.port(); } + + void SetAddress(const rtc::SocketAddress& addr) { + address_ = addr; + EnsureAddressLength(); + } + void SetIP(const rtc::IPAddress& ip) { + address_.SetIP(ip); + EnsureAddressLength(); + } + void SetPort(uint16_t port) { address_.SetPort(port); } + + virtual bool Read(rtc::ByteBuffer* buf); + virtual bool Write(rtc::ByteBuffer* buf) const; + + private: + void EnsureAddressLength() { + switch (family()) { + case STUN_ADDRESS_IPV4: { + SetLength(SIZE_IP4); + break; + } + case STUN_ADDRESS_IPV6: { + SetLength(SIZE_IP6); + break; + } + default: { + SetLength(SIZE_UNDEF); + break; + } + } + } + rtc::SocketAddress address_; +}; + +// Implements STUN attributes that record an Internet address. When encoded +// in a STUN message, the address contained in this attribute is XORed with the +// transaction ID of the message. +class StunXorAddressAttribute : public StunAddressAttribute { + public: + StunXorAddressAttribute(uint16_t type, const rtc::SocketAddress& addr); + StunXorAddressAttribute(uint16_t type, uint16_t length, StunMessage* owner); + + virtual StunAttributeValueType value_type() const { + return STUN_VALUE_XOR_ADDRESS; + } + virtual void SetOwner(StunMessage* owner) { + owner_ = owner; + } + virtual bool Read(rtc::ByteBuffer* buf); + virtual bool Write(rtc::ByteBuffer* buf) const; + + private: + rtc::IPAddress GetXoredIP() const; + StunMessage* owner_; +}; + +// Implements STUN attributes that record a 32-bit integer. +class StunUInt32Attribute : public StunAttribute { + public: + static const uint16_t SIZE = 4; + StunUInt32Attribute(uint16_t type, uint32_t value); + explicit StunUInt32Attribute(uint16_t type); + + virtual StunAttributeValueType value_type() const { + return STUN_VALUE_UINT32; + } + + uint32_t value() const { return bits_; } + void SetValue(uint32_t bits) { bits_ = bits; } + + bool GetBit(size_t index) const; + void SetBit(size_t index, bool value); + + virtual bool Read(rtc::ByteBuffer* buf); + virtual bool Write(rtc::ByteBuffer* buf) const; + + private: + uint32_t bits_; +}; + +class StunUInt64Attribute : public StunAttribute { + public: + static const uint16_t SIZE = 8; + StunUInt64Attribute(uint16_t type, uint64_t value); + explicit StunUInt64Attribute(uint16_t type); + + virtual StunAttributeValueType value_type() const { + return STUN_VALUE_UINT64; + } + + uint64_t value() const { return bits_; } + void SetValue(uint64_t bits) { bits_ = bits; } + + virtual bool Read(rtc::ByteBuffer* buf); + virtual bool Write(rtc::ByteBuffer* buf) const; + + private: + uint64_t bits_; +}; + +// Implements STUN attributes that record an arbitrary byte string. +class StunByteStringAttribute : public StunAttribute { + public: + explicit StunByteStringAttribute(uint16_t type); + StunByteStringAttribute(uint16_t type, const std::string& str); + StunByteStringAttribute(uint16_t type, const void* bytes, size_t length); + StunByteStringAttribute(uint16_t type, uint16_t length); + ~StunByteStringAttribute(); + + virtual StunAttributeValueType value_type() const { + return STUN_VALUE_BYTE_STRING; + } + + const char* bytes() const { return bytes_; } + std::string GetString() const { return std::string(bytes_, length()); } + + void CopyBytes(const char* bytes); // uses strlen + void CopyBytes(const void* bytes, size_t length); + + uint8_t GetByte(size_t index) const; + void SetByte(size_t index, uint8_t value); + + virtual bool Read(rtc::ByteBuffer* buf); + virtual bool Write(rtc::ByteBuffer* buf) const; + + private: + void SetBytes(char* bytes, size_t length); + + char* bytes_; +}; + +// Implements STUN attributes that record an error code. +class StunErrorCodeAttribute : public StunAttribute { + public: + static const uint16_t MIN_SIZE = 4; + StunErrorCodeAttribute(uint16_t type, int code, const std::string& reason); + StunErrorCodeAttribute(uint16_t type, uint16_t length); + ~StunErrorCodeAttribute(); + + virtual StunAttributeValueType value_type() const { + return STUN_VALUE_ERROR_CODE; + } + + // The combined error and class, e.g. 0x400. + int code() const; + void SetCode(int code); + + // The individual error components. + int eclass() const { return class_; } + int number() const { return number_; } + const std::string& reason() const { return reason_; } + void SetClass(uint8_t eclass) { class_ = eclass; } + void SetNumber(uint8_t number) { number_ = number; } + void SetReason(const std::string& reason); + + bool Read(rtc::ByteBuffer* buf); + bool Write(rtc::ByteBuffer* buf) const; + + private: + uint8_t class_; + uint8_t number_; + std::string reason_; +}; + +// Implements STUN attributes that record a list of attribute names. +class StunUInt16ListAttribute : public StunAttribute { + public: + StunUInt16ListAttribute(uint16_t type, uint16_t length); + ~StunUInt16ListAttribute(); + + virtual StunAttributeValueType value_type() const { + return STUN_VALUE_UINT16_LIST; + } + + size_t Size() const; + uint16_t GetType(int index) const; + void SetType(int index, uint16_t value); + void AddType(uint16_t value); + + bool Read(rtc::ByteBuffer* buf); + bool Write(rtc::ByteBuffer* buf) const; + + private: + std::vector<uint16_t>* attr_types_; +}; + +// Returns the (successful) response type for the given request type. +// Returns -1 if |request_type| is not a valid request type. +int GetStunSuccessResponseType(int request_type); + +// Returns the error response type for the given request type. +// Returns -1 if |request_type| is not a valid request type. +int GetStunErrorResponseType(int request_type); + +// Returns whether a given message is a request type. +bool IsStunRequestType(int msg_type); + +// Returns whether a given message is an indication type. +bool IsStunIndicationType(int msg_type); + +// Returns whether a given response is a success type. +bool IsStunSuccessResponseType(int msg_type); + +// Returns whether a given response is an error type. +bool IsStunErrorResponseType(int msg_type); + +// Computes the STUN long-term credential hash. +bool ComputeStunCredentialHash(const std::string& username, + const std::string& realm, const std::string& password, std::string* hash); + +// TODO: Move the TURN/ICE stuff below out to separate files. +extern const char TURN_MAGIC_COOKIE_VALUE[4]; + +// "GTURN" STUN methods. +// TODO: Rename these methods to GTURN_ to make it clear they aren't +// part of standard STUN/TURN. +enum RelayMessageType { + // For now, using the same defs from TurnMessageType below. + // STUN_ALLOCATE_REQUEST = 0x0003, + // STUN_ALLOCATE_RESPONSE = 0x0103, + // STUN_ALLOCATE_ERROR_RESPONSE = 0x0113, + STUN_SEND_REQUEST = 0x0004, + STUN_SEND_RESPONSE = 0x0104, + STUN_SEND_ERROR_RESPONSE = 0x0114, + STUN_DATA_INDICATION = 0x0115, +}; + +// "GTURN"-specific STUN attributes. +// TODO: Rename these attributes to GTURN_ to avoid conflicts. +enum RelayAttributeType { + STUN_ATTR_LIFETIME = 0x000d, // UInt32 + STUN_ATTR_MAGIC_COOKIE = 0x000f, // ByteString, 4 bytes + STUN_ATTR_BANDWIDTH = 0x0010, // UInt32 + STUN_ATTR_DESTINATION_ADDRESS = 0x0011, // Address + STUN_ATTR_SOURCE_ADDRESS2 = 0x0012, // Address + STUN_ATTR_DATA = 0x0013, // ByteString + STUN_ATTR_OPTIONS = 0x8001, // UInt32 +}; + +// A "GTURN" STUN message. +class RelayMessage : public StunMessage { + protected: + virtual StunAttributeValueType GetAttributeValueType(int type) const { + switch (type) { + case STUN_ATTR_LIFETIME: return STUN_VALUE_UINT32; + case STUN_ATTR_MAGIC_COOKIE: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_BANDWIDTH: return STUN_VALUE_UINT32; + case STUN_ATTR_DESTINATION_ADDRESS: return STUN_VALUE_ADDRESS; + case STUN_ATTR_SOURCE_ADDRESS2: return STUN_VALUE_ADDRESS; + case STUN_ATTR_DATA: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_OPTIONS: return STUN_VALUE_UINT32; + default: return StunMessage::GetAttributeValueType(type); + } + } + virtual StunMessage* CreateNew() const { return new RelayMessage(); } +}; + +// Defined in TURN RFC 5766. +enum TurnMessageType { + STUN_ALLOCATE_REQUEST = 0x0003, + STUN_ALLOCATE_RESPONSE = 0x0103, + STUN_ALLOCATE_ERROR_RESPONSE = 0x0113, + TURN_REFRESH_REQUEST = 0x0004, + TURN_REFRESH_RESPONSE = 0x0104, + TURN_REFRESH_ERROR_RESPONSE = 0x0114, + TURN_SEND_INDICATION = 0x0016, + TURN_DATA_INDICATION = 0x0017, + TURN_CREATE_PERMISSION_REQUEST = 0x0008, + TURN_CREATE_PERMISSION_RESPONSE = 0x0108, + TURN_CREATE_PERMISSION_ERROR_RESPONSE = 0x0118, + TURN_CHANNEL_BIND_REQUEST = 0x0009, + TURN_CHANNEL_BIND_RESPONSE = 0x0109, + TURN_CHANNEL_BIND_ERROR_RESPONSE = 0x0119, +}; + +enum TurnAttributeType { + STUN_ATTR_CHANNEL_NUMBER = 0x000C, // UInt32 + STUN_ATTR_TURN_LIFETIME = 0x000d, // UInt32 + STUN_ATTR_XOR_PEER_ADDRESS = 0x0012, // XorAddress + // TODO(mallinath) - Uncomment after RelayAttributes are renamed. + // STUN_ATTR_DATA = 0x0013, // ByteString + STUN_ATTR_XOR_RELAYED_ADDRESS = 0x0016, // XorAddress + STUN_ATTR_EVEN_PORT = 0x0018, // ByteString, 1 byte. + STUN_ATTR_REQUESTED_TRANSPORT = 0x0019, // UInt32 + STUN_ATTR_DONT_FRAGMENT = 0x001A, // No content, Length = 0 + STUN_ATTR_RESERVATION_TOKEN = 0x0022, // ByteString, 8 bytes. + // TODO(mallinath) - Rename STUN_ATTR_TURN_LIFETIME to STUN_ATTR_LIFETIME and + // STUN_ATTR_TURN_DATA to STUN_ATTR_DATA. Also rename RelayMessage attributes + // by appending G to attribute name. +}; + +// RFC 5766-defined errors. +enum TurnErrorType { + STUN_ERROR_FORBIDDEN = 403, + STUN_ERROR_ALLOCATION_MISMATCH = 437, + STUN_ERROR_WRONG_CREDENTIALS = 441, + STUN_ERROR_UNSUPPORTED_PROTOCOL = 442 +}; +extern const char STUN_ERROR_REASON_FORBIDDEN[]; +extern const char STUN_ERROR_REASON_ALLOCATION_MISMATCH[]; +extern const char STUN_ERROR_REASON_WRONG_CREDENTIALS[]; +extern const char STUN_ERROR_REASON_UNSUPPORTED_PROTOCOL[]; +class TurnMessage : public StunMessage { + protected: + virtual StunAttributeValueType GetAttributeValueType(int type) const { + switch (type) { + case STUN_ATTR_CHANNEL_NUMBER: return STUN_VALUE_UINT32; + case STUN_ATTR_TURN_LIFETIME: return STUN_VALUE_UINT32; + case STUN_ATTR_XOR_PEER_ADDRESS: return STUN_VALUE_XOR_ADDRESS; + case STUN_ATTR_DATA: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_XOR_RELAYED_ADDRESS: return STUN_VALUE_XOR_ADDRESS; + case STUN_ATTR_EVEN_PORT: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_REQUESTED_TRANSPORT: return STUN_VALUE_UINT32; + case STUN_ATTR_DONT_FRAGMENT: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_RESERVATION_TOKEN: return STUN_VALUE_BYTE_STRING; + default: return StunMessage::GetAttributeValueType(type); + } + } + virtual StunMessage* CreateNew() const { return new TurnMessage(); } +}; + +// RFC 5245 ICE STUN attributes. +enum IceAttributeType { + STUN_ATTR_PRIORITY = 0x0024, // UInt32 + STUN_ATTR_USE_CANDIDATE = 0x0025, // No content, Length = 0 + STUN_ATTR_ICE_CONTROLLED = 0x8029, // UInt64 + STUN_ATTR_ICE_CONTROLLING = 0x802A // UInt64 +}; + +// RFC 5245-defined errors. +enum IceErrorCode { + STUN_ERROR_ROLE_CONFLICT = 487, +}; +extern const char STUN_ERROR_REASON_ROLE_CONFLICT[]; + +// A RFC 5245 ICE STUN message. +class IceMessage : public StunMessage { + protected: + virtual StunAttributeValueType GetAttributeValueType(int type) const { + switch (type) { + case STUN_ATTR_PRIORITY: return STUN_VALUE_UINT32; + case STUN_ATTR_USE_CANDIDATE: return STUN_VALUE_BYTE_STRING; + case STUN_ATTR_ICE_CONTROLLED: return STUN_VALUE_UINT64; + case STUN_ATTR_ICE_CONTROLLING: return STUN_VALUE_UINT64; + default: return StunMessage::GetAttributeValueType(type); + } + } + virtual StunMessage* CreateNew() const { return new IceMessage(); } +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_STUN_H_ diff --git a/webrtc/p2p/base/stun_unittest.cc b/webrtc/p2p/base/stun_unittest.cc new file mode 100644 index 0000000000..cd4f7e1cbb --- /dev/null +++ b/webrtc/p2p/base/stun_unittest.cc @@ -0,0 +1,1446 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <string> + +#include "webrtc/p2p/base/stun.h" +#include "webrtc/base/bytebuffer.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/messagedigest.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/socketaddress.h" + +namespace cricket { + +class StunTest : public ::testing::Test { + protected: + void CheckStunHeader(const StunMessage& msg, StunMessageType expected_type, + size_t expected_length) { + ASSERT_EQ(expected_type, msg.type()); + ASSERT_EQ(expected_length, msg.length()); + } + + void CheckStunTransactionID(const StunMessage& msg, + const unsigned char* expectedID, size_t length) { + ASSERT_EQ(length, msg.transaction_id().size()); + ASSERT_EQ(length == kStunTransactionIdLength + 4, msg.IsLegacy()); + ASSERT_EQ(length == kStunTransactionIdLength, !msg.IsLegacy()); + ASSERT_EQ(0, memcmp(msg.transaction_id().c_str(), expectedID, length)); + } + + void CheckStunAddressAttribute(const StunAddressAttribute* addr, + StunAddressFamily expected_family, + int expected_port, + rtc::IPAddress expected_address) { + ASSERT_EQ(expected_family, addr->family()); + ASSERT_EQ(expected_port, addr->port()); + + if (addr->family() == STUN_ADDRESS_IPV4) { + in_addr v4_address = expected_address.ipv4_address(); + in_addr stun_address = addr->ipaddr().ipv4_address(); + ASSERT_EQ(0, memcmp(&v4_address, &stun_address, sizeof(stun_address))); + } else if (addr->family() == STUN_ADDRESS_IPV6) { + in6_addr v6_address = expected_address.ipv6_address(); + in6_addr stun_address = addr->ipaddr().ipv6_address(); + ASSERT_EQ(0, memcmp(&v6_address, &stun_address, sizeof(stun_address))); + } else { + ASSERT_TRUE(addr->family() == STUN_ADDRESS_IPV6 || + addr->family() == STUN_ADDRESS_IPV4); + } + } + + size_t ReadStunMessageTestCase(StunMessage* msg, + const unsigned char* testcase, + size_t size) { + const char* input = reinterpret_cast<const char*>(testcase); + rtc::ByteBuffer buf(input, size); + if (msg->Read(&buf)) { + // Returns the size the stun message should report itself as being + return (size - 20); + } else { + return 0; + } + } +}; + + +// Sample STUN packets with various attributes +// Gathered by wiresharking pjproject's pjnath test programs +// pjproject available at www.pjsip.org + +static const unsigned char kStunMessageWithIPv6MappedAddress[] = { + 0x00, 0x01, 0x00, 0x18, // message header + 0x21, 0x12, 0xa4, 0x42, // transaction id + 0x29, 0x1f, 0xcd, 0x7c, + 0xba, 0x58, 0xab, 0xd7, + 0xf2, 0x41, 0x01, 0x00, + 0x00, 0x01, 0x00, 0x14, // Address type (mapped), length + 0x00, 0x02, 0xb8, 0x81, // family (IPv6), port + 0x24, 0x01, 0xfa, 0x00, // an IPv6 address + 0x00, 0x04, 0x10, 0x00, + 0xbe, 0x30, 0x5b, 0xff, + 0xfe, 0xe5, 0x00, 0xc3 +}; + +static const unsigned char kStunMessageWithIPv4MappedAddress[] = { + 0x01, 0x01, 0x00, 0x0c, // binding response, length 12 + 0x21, 0x12, 0xa4, 0x42, // magic cookie + 0x29, 0x1f, 0xcd, 0x7c, // transaction ID + 0xba, 0x58, 0xab, 0xd7, + 0xf2, 0x41, 0x01, 0x00, + 0x00, 0x01, 0x00, 0x08, // Mapped, 8 byte length + 0x00, 0x01, 0x9d, 0xfc, // AF_INET, unxor-ed port + 0xac, 0x17, 0x44, 0xe6 // IPv4 address +}; + +// Test XOR-mapped IP addresses: +static const unsigned char kStunMessageWithIPv6XorMappedAddress[] = { + 0x01, 0x01, 0x00, 0x18, // message header (binding response) + 0x21, 0x12, 0xa4, 0x42, // magic cookie (rfc5389) + 0xe3, 0xa9, 0x46, 0xe1, // transaction ID + 0x7c, 0x00, 0xc2, 0x62, + 0x54, 0x08, 0x01, 0x00, + 0x00, 0x20, 0x00, 0x14, // Address Type (XOR), length + 0x00, 0x02, 0xcb, 0x5b, // family, XOR-ed port + 0x05, 0x13, 0x5e, 0x42, // XOR-ed IPv6 address + 0xe3, 0xad, 0x56, 0xe1, + 0xc2, 0x30, 0x99, 0x9d, + 0xaa, 0xed, 0x01, 0xc3 +}; + +static const unsigned char kStunMessageWithIPv4XorMappedAddress[] = { + 0x01, 0x01, 0x00, 0x0c, // message header (binding response) + 0x21, 0x12, 0xa4, 0x42, // magic cookie + 0x29, 0x1f, 0xcd, 0x7c, // transaction ID + 0xba, 0x58, 0xab, 0xd7, + 0xf2, 0x41, 0x01, 0x00, + 0x00, 0x20, 0x00, 0x08, // address type (xor), length + 0x00, 0x01, 0xfc, 0xb5, // family (AF_INET), XOR-ed port + 0x8d, 0x05, 0xe0, 0xa4 // IPv4 address +}; + +// ByteString Attribute (username) +static const unsigned char kStunMessageWithByteStringAttribute[] = { + 0x00, 0x01, 0x00, 0x0c, + 0x21, 0x12, 0xa4, 0x42, + 0xe3, 0xa9, 0x46, 0xe1, + 0x7c, 0x00, 0xc2, 0x62, + 0x54, 0x08, 0x01, 0x00, + 0x00, 0x06, 0x00, 0x08, // username attribute (length 8) + 0x61, 0x62, 0x63, 0x64, // abcdefgh + 0x65, 0x66, 0x67, 0x68 +}; + +// Message with an unknown but comprehensible optional attribute. +// Parsing should succeed despite this unknown attribute. +static const unsigned char kStunMessageWithUnknownAttribute[] = { + 0x00, 0x01, 0x00, 0x14, + 0x21, 0x12, 0xa4, 0x42, + 0xe3, 0xa9, 0x46, 0xe1, + 0x7c, 0x00, 0xc2, 0x62, + 0x54, 0x08, 0x01, 0x00, + 0x00, 0xaa, 0x00, 0x07, // Unknown attribute, length 7 (needs padding!) + 0x61, 0x62, 0x63, 0x64, // abcdefg + padding + 0x65, 0x66, 0x67, 0x00, + 0x00, 0x06, 0x00, 0x03, // Followed by a known attribute we can + 0x61, 0x62, 0x63, 0x00 // check for (username of length 3) +}; + +// ByteString Attribute (username) with padding byte +static const unsigned char kStunMessageWithPaddedByteStringAttribute[] = { + 0x00, 0x01, 0x00, 0x08, + 0x21, 0x12, 0xa4, 0x42, + 0xe3, 0xa9, 0x46, 0xe1, + 0x7c, 0x00, 0xc2, 0x62, + 0x54, 0x08, 0x01, 0x00, + 0x00, 0x06, 0x00, 0x03, // username attribute (length 3) + 0x61, 0x62, 0x63, 0xcc // abc +}; + +// Message with an Unknown Attributes (uint16_t list) attribute. +static const unsigned char kStunMessageWithUInt16ListAttribute[] = { + 0x00, 0x01, 0x00, 0x0c, + 0x21, 0x12, 0xa4, 0x42, + 0xe3, 0xa9, 0x46, 0xe1, + 0x7c, 0x00, 0xc2, 0x62, + 0x54, 0x08, 0x01, 0x00, + 0x00, 0x0a, 0x00, 0x06, // username attribute (length 6) + 0x00, 0x01, 0x10, 0x00, // three attributes plus padding + 0xAB, 0xCU, 0xBE, 0xEF +}; + +// Error response message (unauthorized) +static const unsigned char kStunMessageWithErrorAttribute[] = { + 0x01, 0x11, 0x00, 0x14, + 0x21, 0x12, 0xa4, 0x42, + 0x29, 0x1f, 0xcd, 0x7c, + 0xba, 0x58, 0xab, 0xd7, + 0xf2, 0x41, 0x01, 0x00, + 0x00, 0x09, 0x00, 0x10, + 0x00, 0x00, 0x04, 0x01, + 0x55, 0x6e, 0x61, 0x75, + 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x65, 0x64 +}; + +static const unsigned char kStunMessageWithOriginAttribute[] = { + 0x00, 0x01, 0x00, 0x18, // message header (binding request), length 24 + 0x21, 0x12, 0xA4, 0x42, // magic cookie + 0x29, 0x1f, 0xcd, 0x7c, // transaction id + 0xba, 0x58, 0xab, 0xd7, + 0xf2, 0x41, 0x01, 0x00, + 0x80, 0x2f, 0x00, 0x12, // origin attribute (length 18) + 0x68, 0x74, 0x74, 0x70, // http://example.com + 0x3A, 0x2F, 0x2F, 0x65, + 0x78, 0x61, 0x6d, 0x70, + 0x6c, 0x65, 0x2e, 0x63, + 0x6f, 0x6d, 0x00, 0x00, +}; + +// Sample messages with an invalid length Field + +// The actual length in bytes of the invalid messages (including STUN header) +static const int kRealLengthOfInvalidLengthTestCases = 32; + +static const unsigned char kStunMessageWithZeroLength[] = { + 0x00, 0x01, 0x00, 0x00, // length of 0 (last 2 bytes) + 0x21, 0x12, 0xA4, 0x42, // magic cookie + '0', '1', '2', '3', // transaction id + '4', '5', '6', '7', + '8', '9', 'a', 'b', + 0x00, 0x20, 0x00, 0x08, // xor mapped address + 0x00, 0x01, 0x21, 0x1F, + 0x21, 0x12, 0xA4, 0x53, +}; + +static const unsigned char kStunMessageWithExcessLength[] = { + 0x00, 0x01, 0x00, 0x55, // length of 85 + 0x21, 0x12, 0xA4, 0x42, // magic cookie + '0', '1', '2', '3', // transaction id + '4', '5', '6', '7', + '8', '9', 'a', 'b', + 0x00, 0x20, 0x00, 0x08, // xor mapped address + 0x00, 0x01, 0x21, 0x1F, + 0x21, 0x12, 0xA4, 0x53, +}; + +static const unsigned char kStunMessageWithSmallLength[] = { + 0x00, 0x01, 0x00, 0x03, // length of 3 + 0x21, 0x12, 0xA4, 0x42, // magic cookie + '0', '1', '2', '3', // transaction id + '4', '5', '6', '7', + '8', '9', 'a', 'b', + 0x00, 0x20, 0x00, 0x08, // xor mapped address + 0x00, 0x01, 0x21, 0x1F, + 0x21, 0x12, 0xA4, 0x53, +}; + +// RTCP packet, for testing we correctly ignore non stun packet types. +// V=2, P=false, RC=0, Type=200, Len=6, Sender-SSRC=85, etc +static const unsigned char kRtcpPacket[] = { + 0x80, 0xc8, 0x00, 0x06, 0x00, 0x00, 0x00, 0x55, + 0xce, 0xa5, 0x18, 0x3a, 0x39, 0xcc, 0x7d, 0x09, + 0x23, 0xed, 0x19, 0x07, 0x00, 0x00, 0x01, 0x56, + 0x00, 0x03, 0x73, 0x50, +}; + +// RFC5769 Test Vectors +// Software name (request): "STUN test client" (without quotes) +// Software name (response): "test vector" (without quotes) +// Username: "evtj:h6vY" (without quotes) +// Password: "VOkJxbRl1RmTxUk/WvJxBt" (without quotes) +static const unsigned char kRfc5769SampleMsgTransactionId[] = { + 0xb7, 0xe7, 0xa7, 0x01, 0xbc, 0x34, 0xd6, 0x86, 0xfa, 0x87, 0xdf, 0xae +}; +static const char kRfc5769SampleMsgClientSoftware[] = "STUN test client"; +static const char kRfc5769SampleMsgServerSoftware[] = "test vector"; +static const char kRfc5769SampleMsgUsername[] = "evtj:h6vY"; +static const char kRfc5769SampleMsgPassword[] = "VOkJxbRl1RmTxUk/WvJxBt"; +static const rtc::SocketAddress kRfc5769SampleMsgMappedAddress( + "192.0.2.1", 32853); +static const rtc::SocketAddress kRfc5769SampleMsgIPv6MappedAddress( + "2001:db8:1234:5678:11:2233:4455:6677", 32853); + +static const unsigned char kRfc5769SampleMsgWithAuthTransactionId[] = { + 0x78, 0xad, 0x34, 0x33, 0xc6, 0xad, 0x72, 0xc0, 0x29, 0xda, 0x41, 0x2e +}; +static const char kRfc5769SampleMsgWithAuthUsername[] = + "\xe3\x83\x9e\xe3\x83\x88\xe3\x83\xaa\xe3\x83\x83\xe3\x82\xaf\xe3\x82\xb9"; +static const char kRfc5769SampleMsgWithAuthPassword[] = "TheMatrIX"; +static const char kRfc5769SampleMsgWithAuthNonce[] = + "f//499k954d6OL34oL9FSTvy64sA"; +static const char kRfc5769SampleMsgWithAuthRealm[] = "example.org"; + +// 2.1. Sample Request +static const unsigned char kRfc5769SampleRequest[] = { + 0x00, 0x01, 0x00, 0x58, // Request type and message length + 0x21, 0x12, 0xa4, 0x42, // Magic cookie + 0xb7, 0xe7, 0xa7, 0x01, // } + 0xbc, 0x34, 0xd6, 0x86, // } Transaction ID + 0xfa, 0x87, 0xdf, 0xae, // } + 0x80, 0x22, 0x00, 0x10, // SOFTWARE attribute header + 0x53, 0x54, 0x55, 0x4e, // } + 0x20, 0x74, 0x65, 0x73, // } User-agent... + 0x74, 0x20, 0x63, 0x6c, // } ...name + 0x69, 0x65, 0x6e, 0x74, // } + 0x00, 0x24, 0x00, 0x04, // PRIORITY attribute header + 0x6e, 0x00, 0x01, 0xff, // ICE priority value + 0x80, 0x29, 0x00, 0x08, // ICE-CONTROLLED attribute header + 0x93, 0x2f, 0xf9, 0xb1, // } Pseudo-random tie breaker... + 0x51, 0x26, 0x3b, 0x36, // } ...for ICE control + 0x00, 0x06, 0x00, 0x09, // USERNAME attribute header + 0x65, 0x76, 0x74, 0x6a, // } + 0x3a, 0x68, 0x36, 0x76, // } Username (9 bytes) and padding (3 bytes) + 0x59, 0x20, 0x20, 0x20, // } + 0x00, 0x08, 0x00, 0x14, // MESSAGE-INTEGRITY attribute header + 0x9a, 0xea, 0xa7, 0x0c, // } + 0xbf, 0xd8, 0xcb, 0x56, // } + 0x78, 0x1e, 0xf2, 0xb5, // } HMAC-SHA1 fingerprint + 0xb2, 0xd3, 0xf2, 0x49, // } + 0xc1, 0xb5, 0x71, 0xa2, // } + 0x80, 0x28, 0x00, 0x04, // FINGERPRINT attribute header + 0xe5, 0x7a, 0x3b, 0xcf // CRC32 fingerprint +}; + +// 2.2. Sample IPv4 Response +static const unsigned char kRfc5769SampleResponse[] = { + 0x01, 0x01, 0x00, 0x3c, // Response type and message length + 0x21, 0x12, 0xa4, 0x42, // Magic cookie + 0xb7, 0xe7, 0xa7, 0x01, // } + 0xbc, 0x34, 0xd6, 0x86, // } Transaction ID + 0xfa, 0x87, 0xdf, 0xae, // } + 0x80, 0x22, 0x00, 0x0b, // SOFTWARE attribute header + 0x74, 0x65, 0x73, 0x74, // } + 0x20, 0x76, 0x65, 0x63, // } UTF-8 server name + 0x74, 0x6f, 0x72, 0x20, // } + 0x00, 0x20, 0x00, 0x08, // XOR-MAPPED-ADDRESS attribute header + 0x00, 0x01, 0xa1, 0x47, // Address family (IPv4) and xor'd mapped port + 0xe1, 0x12, 0xa6, 0x43, // Xor'd mapped IPv4 address + 0x00, 0x08, 0x00, 0x14, // MESSAGE-INTEGRITY attribute header + 0x2b, 0x91, 0xf5, 0x99, // } + 0xfd, 0x9e, 0x90, 0xc3, // } + 0x8c, 0x74, 0x89, 0xf9, // } HMAC-SHA1 fingerprint + 0x2a, 0xf9, 0xba, 0x53, // } + 0xf0, 0x6b, 0xe7, 0xd7, // } + 0x80, 0x28, 0x00, 0x04, // FINGERPRINT attribute header + 0xc0, 0x7d, 0x4c, 0x96 // CRC32 fingerprint +}; + +// 2.3. Sample IPv6 Response +static const unsigned char kRfc5769SampleResponseIPv6[] = { + 0x01, 0x01, 0x00, 0x48, // Response type and message length + 0x21, 0x12, 0xa4, 0x42, // Magic cookie + 0xb7, 0xe7, 0xa7, 0x01, // } + 0xbc, 0x34, 0xd6, 0x86, // } Transaction ID + 0xfa, 0x87, 0xdf, 0xae, // } + 0x80, 0x22, 0x00, 0x0b, // SOFTWARE attribute header + 0x74, 0x65, 0x73, 0x74, // } + 0x20, 0x76, 0x65, 0x63, // } UTF-8 server name + 0x74, 0x6f, 0x72, 0x20, // } + 0x00, 0x20, 0x00, 0x14, // XOR-MAPPED-ADDRESS attribute header + 0x00, 0x02, 0xa1, 0x47, // Address family (IPv6) and xor'd mapped port. + 0x01, 0x13, 0xa9, 0xfa, // } + 0xa5, 0xd3, 0xf1, 0x79, // } Xor'd mapped IPv6 address + 0xbc, 0x25, 0xf4, 0xb5, // } + 0xbe, 0xd2, 0xb9, 0xd9, // } + 0x00, 0x08, 0x00, 0x14, // MESSAGE-INTEGRITY attribute header + 0xa3, 0x82, 0x95, 0x4e, // } + 0x4b, 0xe6, 0x7b, 0xf1, // } + 0x17, 0x84, 0xc9, 0x7c, // } HMAC-SHA1 fingerprint + 0x82, 0x92, 0xc2, 0x75, // } + 0xbf, 0xe3, 0xed, 0x41, // } + 0x80, 0x28, 0x00, 0x04, // FINGERPRINT attribute header + 0xc8, 0xfb, 0x0b, 0x4c // CRC32 fingerprint +}; + +// 2.4. Sample Request with Long-Term Authentication +static const unsigned char kRfc5769SampleRequestLongTermAuth[] = { + 0x00, 0x01, 0x00, 0x60, // Request type and message length + 0x21, 0x12, 0xa4, 0x42, // Magic cookie + 0x78, 0xad, 0x34, 0x33, // } + 0xc6, 0xad, 0x72, 0xc0, // } Transaction ID + 0x29, 0xda, 0x41, 0x2e, // } + 0x00, 0x06, 0x00, 0x12, // USERNAME attribute header + 0xe3, 0x83, 0x9e, 0xe3, // } + 0x83, 0x88, 0xe3, 0x83, // } + 0xaa, 0xe3, 0x83, 0x83, // } Username value (18 bytes) and padding (2 bytes) + 0xe3, 0x82, 0xaf, 0xe3, // } + 0x82, 0xb9, 0x00, 0x00, // } + 0x00, 0x15, 0x00, 0x1c, // NONCE attribute header + 0x66, 0x2f, 0x2f, 0x34, // } + 0x39, 0x39, 0x6b, 0x39, // } + 0x35, 0x34, 0x64, 0x36, // } + 0x4f, 0x4c, 0x33, 0x34, // } Nonce value + 0x6f, 0x4c, 0x39, 0x46, // } + 0x53, 0x54, 0x76, 0x79, // } + 0x36, 0x34, 0x73, 0x41, // } + 0x00, 0x14, 0x00, 0x0b, // REALM attribute header + 0x65, 0x78, 0x61, 0x6d, // } + 0x70, 0x6c, 0x65, 0x2e, // } Realm value (11 bytes) and padding (1 byte) + 0x6f, 0x72, 0x67, 0x00, // } + 0x00, 0x08, 0x00, 0x14, // MESSAGE-INTEGRITY attribute header + 0xf6, 0x70, 0x24, 0x65, // } + 0x6d, 0xd6, 0x4a, 0x3e, // } + 0x02, 0xb8, 0xe0, 0x71, // } HMAC-SHA1 fingerprint + 0x2e, 0x85, 0xc9, 0xa2, // } + 0x8c, 0xa8, 0x96, 0x66 // } +}; + +// Length parameter is changed to 0x38 from 0x58. +// AddMessageIntegrity will add MI information and update the length param +// accordingly. +static const unsigned char kRfc5769SampleRequestWithoutMI[] = { + 0x00, 0x01, 0x00, 0x38, // Request type and message length + 0x21, 0x12, 0xa4, 0x42, // Magic cookie + 0xb7, 0xe7, 0xa7, 0x01, // } + 0xbc, 0x34, 0xd6, 0x86, // } Transaction ID + 0xfa, 0x87, 0xdf, 0xae, // } + 0x80, 0x22, 0x00, 0x10, // SOFTWARE attribute header + 0x53, 0x54, 0x55, 0x4e, // } + 0x20, 0x74, 0x65, 0x73, // } User-agent... + 0x74, 0x20, 0x63, 0x6c, // } ...name + 0x69, 0x65, 0x6e, 0x74, // } + 0x00, 0x24, 0x00, 0x04, // PRIORITY attribute header + 0x6e, 0x00, 0x01, 0xff, // ICE priority value + 0x80, 0x29, 0x00, 0x08, // ICE-CONTROLLED attribute header + 0x93, 0x2f, 0xf9, 0xb1, // } Pseudo-random tie breaker... + 0x51, 0x26, 0x3b, 0x36, // } ...for ICE control + 0x00, 0x06, 0x00, 0x09, // USERNAME attribute header + 0x65, 0x76, 0x74, 0x6a, // } + 0x3a, 0x68, 0x36, 0x76, // } Username (9 bytes) and padding (3 bytes) + 0x59, 0x20, 0x20, 0x20 // } +}; + +// This HMAC differs from the RFC 5769 SampleRequest message. This differs +// because spec uses 0x20 for the padding where as our implementation uses 0. +static const unsigned char kCalculatedHmac1[] = { + 0x79, 0x07, 0xc2, 0xd2, // } + 0xed, 0xbf, 0xea, 0x48, // } + 0x0e, 0x4c, 0x76, 0xd8, // } HMAC-SHA1 fingerprint + 0x29, 0x62, 0xd5, 0xc3, // } + 0x74, 0x2a, 0xf9, 0xe3 // } +}; + +// Length parameter is changed to 0x1c from 0x3c. +// AddMessageIntegrity will add MI information and update the length param +// accordingly. +static const unsigned char kRfc5769SampleResponseWithoutMI[] = { + 0x01, 0x01, 0x00, 0x1c, // Response type and message length + 0x21, 0x12, 0xa4, 0x42, // Magic cookie + 0xb7, 0xe7, 0xa7, 0x01, // } + 0xbc, 0x34, 0xd6, 0x86, // } Transaction ID + 0xfa, 0x87, 0xdf, 0xae, // } + 0x80, 0x22, 0x00, 0x0b, // SOFTWARE attribute header + 0x74, 0x65, 0x73, 0x74, // } + 0x20, 0x76, 0x65, 0x63, // } UTF-8 server name + 0x74, 0x6f, 0x72, 0x20, // } + 0x00, 0x20, 0x00, 0x08, // XOR-MAPPED-ADDRESS attribute header + 0x00, 0x01, 0xa1, 0x47, // Address family (IPv4) and xor'd mapped port + 0xe1, 0x12, 0xa6, 0x43 // Xor'd mapped IPv4 address +}; + +// This HMAC differs from the RFC 5769 SampleResponse message. This differs +// because spec uses 0x20 for the padding where as our implementation uses 0. +static const unsigned char kCalculatedHmac2[] = { + 0x5d, 0x6b, 0x58, 0xbe, // } + 0xad, 0x94, 0xe0, 0x7e, // } + 0xef, 0x0d, 0xfc, 0x12, // } HMAC-SHA1 fingerprint + 0x82, 0xa2, 0xbd, 0x08, // } + 0x43, 0x14, 0x10, 0x28 // } +}; + +// A transaction ID without the 'magic cookie' portion +// pjnat's test programs use this transaction ID a lot. +const unsigned char kTestTransactionId1[] = { 0x029, 0x01f, 0x0cd, 0x07c, + 0x0ba, 0x058, 0x0ab, 0x0d7, + 0x0f2, 0x041, 0x001, 0x000 }; + +// They use this one sometimes too. +const unsigned char kTestTransactionId2[] = { 0x0e3, 0x0a9, 0x046, 0x0e1, + 0x07c, 0x000, 0x0c2, 0x062, + 0x054, 0x008, 0x001, 0x000 }; + +const in6_addr kIPv6TestAddress1 = { { { 0x24, 0x01, 0xfa, 0x00, + 0x00, 0x04, 0x10, 0x00, + 0xbe, 0x30, 0x5b, 0xff, + 0xfe, 0xe5, 0x00, 0xc3 } } }; +const in6_addr kIPv6TestAddress2 = { { { 0x24, 0x01, 0xfa, 0x00, + 0x00, 0x04, 0x10, 0x12, + 0x06, 0x0c, 0xce, 0xff, + 0xfe, 0x1f, 0x61, 0xa4 } } }; + +#ifdef WEBRTC_POSIX +const in_addr kIPv4TestAddress1 = { 0xe64417ac }; +#elif defined WEBRTC_WIN +// Windows in_addr has a union with a uchar[] array first. +const in_addr kIPv4TestAddress1 = { { 0x0ac, 0x017, 0x044, 0x0e6 } }; +#endif +const char kTestUserName1[] = "abcdefgh"; +const char kTestUserName2[] = "abc"; +const char kTestErrorReason[] = "Unauthorized"; +const char kTestOrigin[] = "http://example.com"; +const int kTestErrorClass = 4; +const int kTestErrorNumber = 1; +const int kTestErrorCode = 401; + +const int kTestMessagePort1 = 59977; +const int kTestMessagePort2 = 47233; +const int kTestMessagePort3 = 56743; +const int kTestMessagePort4 = 40444; + +#define ReadStunMessage(X, Y) ReadStunMessageTestCase(X, Y, sizeof(Y)); + +// Test that the GetStun*Type and IsStun*Type methods work as expected. +TEST_F(StunTest, MessageTypes) { + EXPECT_EQ(STUN_BINDING_RESPONSE, + GetStunSuccessResponseType(STUN_BINDING_REQUEST)); + EXPECT_EQ(STUN_BINDING_ERROR_RESPONSE, + GetStunErrorResponseType(STUN_BINDING_REQUEST)); + EXPECT_EQ(-1, GetStunSuccessResponseType(STUN_BINDING_INDICATION)); + EXPECT_EQ(-1, GetStunSuccessResponseType(STUN_BINDING_RESPONSE)); + EXPECT_EQ(-1, GetStunSuccessResponseType(STUN_BINDING_ERROR_RESPONSE)); + EXPECT_EQ(-1, GetStunErrorResponseType(STUN_BINDING_INDICATION)); + EXPECT_EQ(-1, GetStunErrorResponseType(STUN_BINDING_RESPONSE)); + EXPECT_EQ(-1, GetStunErrorResponseType(STUN_BINDING_ERROR_RESPONSE)); + + int types[] = { + STUN_BINDING_REQUEST, STUN_BINDING_INDICATION, + STUN_BINDING_RESPONSE, STUN_BINDING_ERROR_RESPONSE + }; + for (int i = 0; i < ARRAY_SIZE(types); ++i) { + EXPECT_EQ(i == 0, IsStunRequestType(types[i])); + EXPECT_EQ(i == 1, IsStunIndicationType(types[i])); + EXPECT_EQ(i == 2, IsStunSuccessResponseType(types[i])); + EXPECT_EQ(i == 3, IsStunErrorResponseType(types[i])); + EXPECT_EQ(1, types[i] & 0xFEEF); + } +} + +TEST_F(StunTest, ReadMessageWithIPv4AddressAttribute) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kStunMessageWithIPv4MappedAddress); + CheckStunHeader(msg, STUN_BINDING_RESPONSE, size); + CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength); + + const StunAddressAttribute* addr = msg.GetAddress(STUN_ATTR_MAPPED_ADDRESS); + rtc::IPAddress test_address(kIPv4TestAddress1); + CheckStunAddressAttribute(addr, STUN_ADDRESS_IPV4, + kTestMessagePort4, test_address); +} + +TEST_F(StunTest, ReadMessageWithIPv4XorAddressAttribute) { + StunMessage msg; + StunMessage msg2; + size_t size = ReadStunMessage(&msg, kStunMessageWithIPv4XorMappedAddress); + CheckStunHeader(msg, STUN_BINDING_RESPONSE, size); + CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength); + + const StunAddressAttribute* addr = + msg.GetAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + rtc::IPAddress test_address(kIPv4TestAddress1); + CheckStunAddressAttribute(addr, STUN_ADDRESS_IPV4, + kTestMessagePort3, test_address); +} + +TEST_F(StunTest, ReadMessageWithIPv6AddressAttribute) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kStunMessageWithIPv6MappedAddress); + CheckStunHeader(msg, STUN_BINDING_REQUEST, size); + CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength); + + rtc::IPAddress test_address(kIPv6TestAddress1); + + const StunAddressAttribute* addr = msg.GetAddress(STUN_ATTR_MAPPED_ADDRESS); + CheckStunAddressAttribute(addr, STUN_ADDRESS_IPV6, + kTestMessagePort2, test_address); +} + +TEST_F(StunTest, ReadMessageWithInvalidAddressAttribute) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kStunMessageWithIPv6MappedAddress); + CheckStunHeader(msg, STUN_BINDING_REQUEST, size); + CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength); + + rtc::IPAddress test_address(kIPv6TestAddress1); + + const StunAddressAttribute* addr = msg.GetAddress(STUN_ATTR_MAPPED_ADDRESS); + CheckStunAddressAttribute(addr, STUN_ADDRESS_IPV6, + kTestMessagePort2, test_address); +} + +TEST_F(StunTest, ReadMessageWithIPv6XorAddressAttribute) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kStunMessageWithIPv6XorMappedAddress); + + rtc::IPAddress test_address(kIPv6TestAddress1); + + CheckStunHeader(msg, STUN_BINDING_RESPONSE, size); + CheckStunTransactionID(msg, kTestTransactionId2, kStunTransactionIdLength); + + const StunAddressAttribute* addr = + msg.GetAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + CheckStunAddressAttribute(addr, STUN_ADDRESS_IPV6, + kTestMessagePort1, test_address); +} + +// Read the RFC5389 fields from the RFC5769 sample STUN request. +TEST_F(StunTest, ReadRfc5769RequestMessage) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kRfc5769SampleRequest); + CheckStunHeader(msg, STUN_BINDING_REQUEST, size); + CheckStunTransactionID(msg, kRfc5769SampleMsgTransactionId, + kStunTransactionIdLength); + + const StunByteStringAttribute* software = + msg.GetByteString(STUN_ATTR_SOFTWARE); + ASSERT_TRUE(software != NULL); + EXPECT_EQ(kRfc5769SampleMsgClientSoftware, software->GetString()); + + const StunByteStringAttribute* username = + msg.GetByteString(STUN_ATTR_USERNAME); + ASSERT_TRUE(username != NULL); + EXPECT_EQ(kRfc5769SampleMsgUsername, username->GetString()); + + // Actual M-I value checked in a later test. + ASSERT_TRUE(msg.GetByteString(STUN_ATTR_MESSAGE_INTEGRITY) != NULL); + + // Fingerprint checked in a later test, but double-check the value here. + const StunUInt32Attribute* fingerprint = + msg.GetUInt32(STUN_ATTR_FINGERPRINT); + ASSERT_TRUE(fingerprint != NULL); + EXPECT_EQ(0xe57a3bcf, fingerprint->value()); +} + +// Read the RFC5389 fields from the RFC5769 sample STUN response. +TEST_F(StunTest, ReadRfc5769ResponseMessage) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kRfc5769SampleResponse); + CheckStunHeader(msg, STUN_BINDING_RESPONSE, size); + CheckStunTransactionID(msg, kRfc5769SampleMsgTransactionId, + kStunTransactionIdLength); + + const StunByteStringAttribute* software = + msg.GetByteString(STUN_ATTR_SOFTWARE); + ASSERT_TRUE(software != NULL); + EXPECT_EQ(kRfc5769SampleMsgServerSoftware, software->GetString()); + + const StunAddressAttribute* mapped_address = + msg.GetAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + ASSERT_TRUE(mapped_address != NULL); + EXPECT_EQ(kRfc5769SampleMsgMappedAddress, mapped_address->GetAddress()); + + // Actual M-I and fingerprint checked in later tests. + ASSERT_TRUE(msg.GetByteString(STUN_ATTR_MESSAGE_INTEGRITY) != NULL); + ASSERT_TRUE(msg.GetUInt32(STUN_ATTR_FINGERPRINT) != NULL); +} + +// Read the RFC5389 fields from the RFC5769 sample STUN response for IPv6. +TEST_F(StunTest, ReadRfc5769ResponseMessageIPv6) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kRfc5769SampleResponseIPv6); + CheckStunHeader(msg, STUN_BINDING_RESPONSE, size); + CheckStunTransactionID(msg, kRfc5769SampleMsgTransactionId, + kStunTransactionIdLength); + + const StunByteStringAttribute* software = + msg.GetByteString(STUN_ATTR_SOFTWARE); + ASSERT_TRUE(software != NULL); + EXPECT_EQ(kRfc5769SampleMsgServerSoftware, software->GetString()); + + const StunAddressAttribute* mapped_address = + msg.GetAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + ASSERT_TRUE(mapped_address != NULL); + EXPECT_EQ(kRfc5769SampleMsgIPv6MappedAddress, mapped_address->GetAddress()); + + // Actual M-I and fingerprint checked in later tests. + ASSERT_TRUE(msg.GetByteString(STUN_ATTR_MESSAGE_INTEGRITY) != NULL); + ASSERT_TRUE(msg.GetUInt32(STUN_ATTR_FINGERPRINT) != NULL); +} + +// Read the RFC5389 fields from the RFC5769 sample STUN response with auth. +TEST_F(StunTest, ReadRfc5769RequestMessageLongTermAuth) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kRfc5769SampleRequestLongTermAuth); + CheckStunHeader(msg, STUN_BINDING_REQUEST, size); + CheckStunTransactionID(msg, kRfc5769SampleMsgWithAuthTransactionId, + kStunTransactionIdLength); + + const StunByteStringAttribute* username = + msg.GetByteString(STUN_ATTR_USERNAME); + ASSERT_TRUE(username != NULL); + EXPECT_EQ(kRfc5769SampleMsgWithAuthUsername, username->GetString()); + + const StunByteStringAttribute* nonce = + msg.GetByteString(STUN_ATTR_NONCE); + ASSERT_TRUE(nonce != NULL); + EXPECT_EQ(kRfc5769SampleMsgWithAuthNonce, nonce->GetString()); + + const StunByteStringAttribute* realm = + msg.GetByteString(STUN_ATTR_REALM); + ASSERT_TRUE(realm != NULL); + EXPECT_EQ(kRfc5769SampleMsgWithAuthRealm, realm->GetString()); + + // No fingerprint, actual M-I checked in later tests. + ASSERT_TRUE(msg.GetByteString(STUN_ATTR_MESSAGE_INTEGRITY) != NULL); + ASSERT_TRUE(msg.GetUInt32(STUN_ATTR_FINGERPRINT) == NULL); +} + +// The RFC3489 packet in this test is the same as +// kStunMessageWithIPv4MappedAddress, but with a different value where the +// magic cookie was. +TEST_F(StunTest, ReadLegacyMessage) { + unsigned char rfc3489_packet[sizeof(kStunMessageWithIPv4MappedAddress)]; + memcpy(rfc3489_packet, kStunMessageWithIPv4MappedAddress, + sizeof(kStunMessageWithIPv4MappedAddress)); + // Overwrite the magic cookie here. + memcpy(&rfc3489_packet[4], "ABCD", 4); + + StunMessage msg; + size_t size = ReadStunMessage(&msg, rfc3489_packet); + CheckStunHeader(msg, STUN_BINDING_RESPONSE, size); + CheckStunTransactionID(msg, &rfc3489_packet[4], kStunTransactionIdLength + 4); + + const StunAddressAttribute* addr = msg.GetAddress(STUN_ATTR_MAPPED_ADDRESS); + rtc::IPAddress test_address(kIPv4TestAddress1); + CheckStunAddressAttribute(addr, STUN_ADDRESS_IPV4, + kTestMessagePort4, test_address); +} + +TEST_F(StunTest, SetIPv6XorAddressAttributeOwner) { + StunMessage msg; + StunMessage msg2; + size_t size = ReadStunMessage(&msg, kStunMessageWithIPv6XorMappedAddress); + + rtc::IPAddress test_address(kIPv6TestAddress1); + + CheckStunHeader(msg, STUN_BINDING_RESPONSE, size); + CheckStunTransactionID(msg, kTestTransactionId2, kStunTransactionIdLength); + + const StunAddressAttribute* addr = + msg.GetAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + CheckStunAddressAttribute(addr, STUN_ADDRESS_IPV6, + kTestMessagePort1, test_address); + + // Owner with a different transaction ID. + msg2.SetTransactionID("ABCDABCDABCD"); + StunXorAddressAttribute addr2(STUN_ATTR_XOR_MAPPED_ADDRESS, 20, NULL); + addr2.SetIP(addr->ipaddr()); + addr2.SetPort(addr->port()); + addr2.SetOwner(&msg2); + // The internal IP address shouldn't change. + ASSERT_EQ(addr2.ipaddr(), addr->ipaddr()); + + rtc::ByteBuffer correct_buf; + rtc::ByteBuffer wrong_buf; + EXPECT_TRUE(addr->Write(&correct_buf)); + EXPECT_TRUE(addr2.Write(&wrong_buf)); + // But when written out, the buffers should look different. + ASSERT_NE(0, + memcmp(correct_buf.Data(), wrong_buf.Data(), wrong_buf.Length())); + // And when reading a known good value, the address should be wrong. + addr2.Read(&correct_buf); + ASSERT_NE(addr->ipaddr(), addr2.ipaddr()); + addr2.SetIP(addr->ipaddr()); + addr2.SetPort(addr->port()); + // Try writing with no owner at all, should fail and write nothing. + addr2.SetOwner(NULL); + ASSERT_EQ(addr2.ipaddr(), addr->ipaddr()); + wrong_buf.Consume(wrong_buf.Length()); + EXPECT_FALSE(addr2.Write(&wrong_buf)); + ASSERT_EQ(0U, wrong_buf.Length()); +} + +TEST_F(StunTest, SetIPv4XorAddressAttributeOwner) { + // Unlike the IPv6XorAddressAttributeOwner test, IPv4 XOR address attributes + // should _not_ be affected by a change in owner. IPv4 XOR address uses the + // magic cookie value which is fixed. + StunMessage msg; + StunMessage msg2; + size_t size = ReadStunMessage(&msg, kStunMessageWithIPv4XorMappedAddress); + + rtc::IPAddress test_address(kIPv4TestAddress1); + + CheckStunHeader(msg, STUN_BINDING_RESPONSE, size); + CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength); + + const StunAddressAttribute* addr = + msg.GetAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + CheckStunAddressAttribute(addr, STUN_ADDRESS_IPV4, + kTestMessagePort3, test_address); + + // Owner with a different transaction ID. + msg2.SetTransactionID("ABCDABCDABCD"); + StunXorAddressAttribute addr2(STUN_ATTR_XOR_MAPPED_ADDRESS, 20, NULL); + addr2.SetIP(addr->ipaddr()); + addr2.SetPort(addr->port()); + addr2.SetOwner(&msg2); + // The internal IP address shouldn't change. + ASSERT_EQ(addr2.ipaddr(), addr->ipaddr()); + + rtc::ByteBuffer correct_buf; + rtc::ByteBuffer wrong_buf; + EXPECT_TRUE(addr->Write(&correct_buf)); + EXPECT_TRUE(addr2.Write(&wrong_buf)); + // The same address data should be written. + ASSERT_EQ(0, + memcmp(correct_buf.Data(), wrong_buf.Data(), wrong_buf.Length())); + // And an attribute should be able to un-XOR an address belonging to a message + // with a different transaction ID. + EXPECT_TRUE(addr2.Read(&correct_buf)); + ASSERT_EQ(addr->ipaddr(), addr2.ipaddr()); + + // However, no owner is still an error, should fail and write nothing. + addr2.SetOwner(NULL); + ASSERT_EQ(addr2.ipaddr(), addr->ipaddr()); + wrong_buf.Consume(wrong_buf.Length()); + EXPECT_FALSE(addr2.Write(&wrong_buf)); +} + +TEST_F(StunTest, CreateIPv6AddressAttribute) { + rtc::IPAddress test_ip(kIPv6TestAddress2); + + StunAddressAttribute* addr = + StunAttribute::CreateAddress(STUN_ATTR_MAPPED_ADDRESS); + rtc::SocketAddress test_addr(test_ip, kTestMessagePort2); + addr->SetAddress(test_addr); + + CheckStunAddressAttribute(addr, STUN_ADDRESS_IPV6, + kTestMessagePort2, test_ip); + delete addr; +} + +TEST_F(StunTest, CreateIPv4AddressAttribute) { + struct in_addr test_in_addr; + test_in_addr.s_addr = 0xBEB0B0BE; + rtc::IPAddress test_ip(test_in_addr); + + StunAddressAttribute* addr = + StunAttribute::CreateAddress(STUN_ATTR_MAPPED_ADDRESS); + rtc::SocketAddress test_addr(test_ip, kTestMessagePort2); + addr->SetAddress(test_addr); + + CheckStunAddressAttribute(addr, STUN_ADDRESS_IPV4, + kTestMessagePort2, test_ip); + delete addr; +} + +// Test that we don't care what order we set the parts of an address +TEST_F(StunTest, CreateAddressInArbitraryOrder) { + StunAddressAttribute* addr = + StunAttribute::CreateAddress(STUN_ATTR_DESTINATION_ADDRESS); + // Port first + addr->SetPort(kTestMessagePort1); + addr->SetIP(rtc::IPAddress(kIPv4TestAddress1)); + ASSERT_EQ(kTestMessagePort1, addr->port()); + ASSERT_EQ(rtc::IPAddress(kIPv4TestAddress1), addr->ipaddr()); + + StunAddressAttribute* addr2 = + StunAttribute::CreateAddress(STUN_ATTR_DESTINATION_ADDRESS); + // IP first + addr2->SetIP(rtc::IPAddress(kIPv4TestAddress1)); + addr2->SetPort(kTestMessagePort2); + ASSERT_EQ(kTestMessagePort2, addr2->port()); + ASSERT_EQ(rtc::IPAddress(kIPv4TestAddress1), addr2->ipaddr()); + + delete addr; + delete addr2; +} + +TEST_F(StunTest, WriteMessageWithIPv6AddressAttribute) { + StunMessage msg; + size_t size = sizeof(kStunMessageWithIPv6MappedAddress); + + rtc::IPAddress test_ip(kIPv6TestAddress1); + + msg.SetType(STUN_BINDING_REQUEST); + msg.SetTransactionID( + std::string(reinterpret_cast<const char*>(kTestTransactionId1), + kStunTransactionIdLength)); + CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength); + + StunAddressAttribute* addr = + StunAttribute::CreateAddress(STUN_ATTR_MAPPED_ADDRESS); + rtc::SocketAddress test_addr(test_ip, kTestMessagePort2); + addr->SetAddress(test_addr); + EXPECT_TRUE(msg.AddAttribute(addr)); + + CheckStunHeader(msg, STUN_BINDING_REQUEST, (size - 20)); + + rtc::ByteBuffer out; + EXPECT_TRUE(msg.Write(&out)); + ASSERT_EQ(out.Length(), sizeof(kStunMessageWithIPv6MappedAddress)); + int len1 = static_cast<int>(out.Length()); + std::string bytes; + out.ReadString(&bytes, len1); + ASSERT_EQ(0, memcmp(bytes.c_str(), kStunMessageWithIPv6MappedAddress, len1)); +} + +TEST_F(StunTest, WriteMessageWithIPv4AddressAttribute) { + StunMessage msg; + size_t size = sizeof(kStunMessageWithIPv4MappedAddress); + + rtc::IPAddress test_ip(kIPv4TestAddress1); + + msg.SetType(STUN_BINDING_RESPONSE); + msg.SetTransactionID( + std::string(reinterpret_cast<const char*>(kTestTransactionId1), + kStunTransactionIdLength)); + CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength); + + StunAddressAttribute* addr = + StunAttribute::CreateAddress(STUN_ATTR_MAPPED_ADDRESS); + rtc::SocketAddress test_addr(test_ip, kTestMessagePort4); + addr->SetAddress(test_addr); + EXPECT_TRUE(msg.AddAttribute(addr)); + + CheckStunHeader(msg, STUN_BINDING_RESPONSE, (size - 20)); + + rtc::ByteBuffer out; + EXPECT_TRUE(msg.Write(&out)); + ASSERT_EQ(out.Length(), sizeof(kStunMessageWithIPv4MappedAddress)); + int len1 = static_cast<int>(out.Length()); + std::string bytes; + out.ReadString(&bytes, len1); + ASSERT_EQ(0, memcmp(bytes.c_str(), kStunMessageWithIPv4MappedAddress, len1)); +} + +TEST_F(StunTest, WriteMessageWithIPv6XorAddressAttribute) { + StunMessage msg; + size_t size = sizeof(kStunMessageWithIPv6XorMappedAddress); + + rtc::IPAddress test_ip(kIPv6TestAddress1); + + msg.SetType(STUN_BINDING_RESPONSE); + msg.SetTransactionID( + std::string(reinterpret_cast<const char*>(kTestTransactionId2), + kStunTransactionIdLength)); + CheckStunTransactionID(msg, kTestTransactionId2, kStunTransactionIdLength); + + StunAddressAttribute* addr = + StunAttribute::CreateXorAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + rtc::SocketAddress test_addr(test_ip, kTestMessagePort1); + addr->SetAddress(test_addr); + EXPECT_TRUE(msg.AddAttribute(addr)); + + CheckStunHeader(msg, STUN_BINDING_RESPONSE, (size - 20)); + + rtc::ByteBuffer out; + EXPECT_TRUE(msg.Write(&out)); + ASSERT_EQ(out.Length(), sizeof(kStunMessageWithIPv6XorMappedAddress)); + int len1 = static_cast<int>(out.Length()); + std::string bytes; + out.ReadString(&bytes, len1); + ASSERT_EQ(0, + memcmp(bytes.c_str(), kStunMessageWithIPv6XorMappedAddress, len1)); +} + +TEST_F(StunTest, WriteMessageWithIPv4XoreAddressAttribute) { + StunMessage msg; + size_t size = sizeof(kStunMessageWithIPv4XorMappedAddress); + + rtc::IPAddress test_ip(kIPv4TestAddress1); + + msg.SetType(STUN_BINDING_RESPONSE); + msg.SetTransactionID( + std::string(reinterpret_cast<const char*>(kTestTransactionId1), + kStunTransactionIdLength)); + CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength); + + StunAddressAttribute* addr = + StunAttribute::CreateXorAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + rtc::SocketAddress test_addr(test_ip, kTestMessagePort3); + addr->SetAddress(test_addr); + EXPECT_TRUE(msg.AddAttribute(addr)); + + CheckStunHeader(msg, STUN_BINDING_RESPONSE, (size - 20)); + + rtc::ByteBuffer out; + EXPECT_TRUE(msg.Write(&out)); + ASSERT_EQ(out.Length(), sizeof(kStunMessageWithIPv4XorMappedAddress)); + int len1 = static_cast<int>(out.Length()); + std::string bytes; + out.ReadString(&bytes, len1); + ASSERT_EQ(0, + memcmp(bytes.c_str(), kStunMessageWithIPv4XorMappedAddress, len1)); +} + +TEST_F(StunTest, ReadByteStringAttribute) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kStunMessageWithByteStringAttribute); + + CheckStunHeader(msg, STUN_BINDING_REQUEST, size); + CheckStunTransactionID(msg, kTestTransactionId2, kStunTransactionIdLength); + const StunByteStringAttribute* username = + msg.GetByteString(STUN_ATTR_USERNAME); + ASSERT_TRUE(username != NULL); + EXPECT_EQ(kTestUserName1, username->GetString()); +} + +TEST_F(StunTest, ReadPaddedByteStringAttribute) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, + kStunMessageWithPaddedByteStringAttribute); + ASSERT_NE(0U, size); + CheckStunHeader(msg, STUN_BINDING_REQUEST, size); + CheckStunTransactionID(msg, kTestTransactionId2, kStunTransactionIdLength); + const StunByteStringAttribute* username = + msg.GetByteString(STUN_ATTR_USERNAME); + ASSERT_TRUE(username != NULL); + EXPECT_EQ(kTestUserName2, username->GetString()); +} + +TEST_F(StunTest, ReadErrorCodeAttribute) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kStunMessageWithErrorAttribute); + + CheckStunHeader(msg, STUN_BINDING_ERROR_RESPONSE, size); + CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength); + const StunErrorCodeAttribute* errorcode = msg.GetErrorCode(); + ASSERT_TRUE(errorcode != NULL); + EXPECT_EQ(kTestErrorClass, errorcode->eclass()); + EXPECT_EQ(kTestErrorNumber, errorcode->number()); + EXPECT_EQ(kTestErrorReason, errorcode->reason()); + EXPECT_EQ(kTestErrorCode, errorcode->code()); +} + +TEST_F(StunTest, ReadMessageWithAUInt16ListAttribute) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kStunMessageWithUInt16ListAttribute); + CheckStunHeader(msg, STUN_BINDING_REQUEST, size); + const StunUInt16ListAttribute* types = msg.GetUnknownAttributes(); + ASSERT_TRUE(types != NULL); + EXPECT_EQ(3U, types->Size()); + EXPECT_EQ(0x1U, types->GetType(0)); + EXPECT_EQ(0x1000U, types->GetType(1)); + EXPECT_EQ(0xAB0CU, types->GetType(2)); +} + +TEST_F(StunTest, ReadMessageWithAnUnknownAttribute) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kStunMessageWithUnknownAttribute); + CheckStunHeader(msg, STUN_BINDING_REQUEST, size); + + // Parsing should have succeeded and there should be a USERNAME attribute + const StunByteStringAttribute* username = + msg.GetByteString(STUN_ATTR_USERNAME); + ASSERT_TRUE(username != NULL); + EXPECT_EQ(kTestUserName2, username->GetString()); +} + +TEST_F(StunTest, ReadMessageWithOriginAttribute) { + StunMessage msg; + size_t size = ReadStunMessage(&msg, kStunMessageWithOriginAttribute); + CheckStunHeader(msg, STUN_BINDING_REQUEST, size); + const StunByteStringAttribute* origin = + msg.GetByteString(STUN_ATTR_ORIGIN); + ASSERT_TRUE(origin != NULL); + EXPECT_EQ(kTestOrigin, origin->GetString()); +} + +TEST_F(StunTest, WriteMessageWithAnErrorCodeAttribute) { + StunMessage msg; + size_t size = sizeof(kStunMessageWithErrorAttribute); + + msg.SetType(STUN_BINDING_ERROR_RESPONSE); + msg.SetTransactionID( + std::string(reinterpret_cast<const char*>(kTestTransactionId1), + kStunTransactionIdLength)); + CheckStunTransactionID(msg, kTestTransactionId1, kStunTransactionIdLength); + StunErrorCodeAttribute* errorcode = StunAttribute::CreateErrorCode(); + errorcode->SetCode(kTestErrorCode); + errorcode->SetReason(kTestErrorReason); + EXPECT_TRUE(msg.AddAttribute(errorcode)); + CheckStunHeader(msg, STUN_BINDING_ERROR_RESPONSE, (size - 20)); + + rtc::ByteBuffer out; + EXPECT_TRUE(msg.Write(&out)); + ASSERT_EQ(size, out.Length()); + // No padding. + ASSERT_EQ(0, memcmp(out.Data(), kStunMessageWithErrorAttribute, size)); +} + +TEST_F(StunTest, WriteMessageWithAUInt16ListAttribute) { + StunMessage msg; + size_t size = sizeof(kStunMessageWithUInt16ListAttribute); + + msg.SetType(STUN_BINDING_REQUEST); + msg.SetTransactionID( + std::string(reinterpret_cast<const char*>(kTestTransactionId2), + kStunTransactionIdLength)); + CheckStunTransactionID(msg, kTestTransactionId2, kStunTransactionIdLength); + StunUInt16ListAttribute* list = StunAttribute::CreateUnknownAttributes(); + list->AddType(0x1U); + list->AddType(0x1000U); + list->AddType(0xAB0CU); + EXPECT_TRUE(msg.AddAttribute(list)); + CheckStunHeader(msg, STUN_BINDING_REQUEST, (size - 20)); + + rtc::ByteBuffer out; + EXPECT_TRUE(msg.Write(&out)); + ASSERT_EQ(size, out.Length()); + // Check everything up to the padding. + ASSERT_EQ(0, + memcmp(out.Data(), kStunMessageWithUInt16ListAttribute, size - 2)); +} + +TEST_F(StunTest, WriteMessageWithOriginAttribute) { + StunMessage msg; + size_t size = sizeof(kStunMessageWithOriginAttribute); + + msg.SetType(STUN_BINDING_REQUEST); + msg.SetTransactionID( + std::string(reinterpret_cast<const char*>(kTestTransactionId1), + kStunTransactionIdLength)); + StunByteStringAttribute* origin = + new StunByteStringAttribute(STUN_ATTR_ORIGIN, kTestOrigin); + EXPECT_TRUE(msg.AddAttribute(origin)); + + rtc::ByteBuffer out; + EXPECT_TRUE(msg.Write(&out)); + ASSERT_EQ(size, out.Length()); + // Check everything up to the padding + ASSERT_EQ(0, memcmp(out.Data(), kStunMessageWithOriginAttribute, size - 2)); +} + +// Test that we fail to read messages with invalid lengths. +void CheckFailureToRead(const unsigned char* testcase, size_t length) { + StunMessage msg; + const char* input = reinterpret_cast<const char*>(testcase); + rtc::ByteBuffer buf(input, length); + ASSERT_FALSE(msg.Read(&buf)); +} + +TEST_F(StunTest, FailToReadInvalidMessages) { + CheckFailureToRead(kStunMessageWithZeroLength, + kRealLengthOfInvalidLengthTestCases); + CheckFailureToRead(kStunMessageWithSmallLength, + kRealLengthOfInvalidLengthTestCases); + CheckFailureToRead(kStunMessageWithExcessLength, + kRealLengthOfInvalidLengthTestCases); +} + +// Test that we properly fail to read a non-STUN message. +TEST_F(StunTest, FailToReadRtcpPacket) { + CheckFailureToRead(kRtcpPacket, sizeof(kRtcpPacket)); +} + +// Check our STUN message validation code against the RFC5769 test messages. +TEST_F(StunTest, ValidateMessageIntegrity) { + // Try the messages from RFC 5769. + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kRfc5769SampleRequest), + sizeof(kRfc5769SampleRequest), + kRfc5769SampleMsgPassword)); + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kRfc5769SampleRequest), + sizeof(kRfc5769SampleRequest), + "InvalidPassword")); + + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kRfc5769SampleResponse), + sizeof(kRfc5769SampleResponse), + kRfc5769SampleMsgPassword)); + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kRfc5769SampleResponse), + sizeof(kRfc5769SampleResponse), + "InvalidPassword")); + + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kRfc5769SampleResponseIPv6), + sizeof(kRfc5769SampleResponseIPv6), + kRfc5769SampleMsgPassword)); + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kRfc5769SampleResponseIPv6), + sizeof(kRfc5769SampleResponseIPv6), + "InvalidPassword")); + + // We first need to compute the key for the long-term authentication HMAC. + std::string key; + ComputeStunCredentialHash(kRfc5769SampleMsgWithAuthUsername, + kRfc5769SampleMsgWithAuthRealm, kRfc5769SampleMsgWithAuthPassword, &key); + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kRfc5769SampleRequestLongTermAuth), + sizeof(kRfc5769SampleRequestLongTermAuth), key)); + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kRfc5769SampleRequestLongTermAuth), + sizeof(kRfc5769SampleRequestLongTermAuth), + "InvalidPassword")); + + // Try some edge cases. + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kStunMessageWithZeroLength), + sizeof(kStunMessageWithZeroLength), + kRfc5769SampleMsgPassword)); + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kStunMessageWithExcessLength), + sizeof(kStunMessageWithExcessLength), + kRfc5769SampleMsgPassword)); + EXPECT_FALSE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(kStunMessageWithSmallLength), + sizeof(kStunMessageWithSmallLength), + kRfc5769SampleMsgPassword)); + + // Test that munging a single bit anywhere in the message causes the + // message-integrity check to fail, unless it is after the M-I attribute. + char buf[sizeof(kRfc5769SampleRequest)]; + memcpy(buf, kRfc5769SampleRequest, sizeof(kRfc5769SampleRequest)); + for (size_t i = 0; i < sizeof(buf); ++i) { + buf[i] ^= 0x01; + if (i > 0) + buf[i - 1] ^= 0x01; + EXPECT_EQ(i >= sizeof(buf) - 8, StunMessage::ValidateMessageIntegrity( + buf, sizeof(buf), kRfc5769SampleMsgPassword)); + } +} + +// Validate that we generate correct MESSAGE-INTEGRITY attributes. +// Note the use of IceMessage instead of StunMessage; this is necessary because +// the RFC5769 test messages used include attributes not found in basic STUN. +TEST_F(StunTest, AddMessageIntegrity) { + IceMessage msg; + rtc::ByteBuffer buf( + reinterpret_cast<const char*>(kRfc5769SampleRequestWithoutMI), + sizeof(kRfc5769SampleRequestWithoutMI)); + EXPECT_TRUE(msg.Read(&buf)); + EXPECT_TRUE(msg.AddMessageIntegrity(kRfc5769SampleMsgPassword)); + const StunByteStringAttribute* mi_attr = + msg.GetByteString(STUN_ATTR_MESSAGE_INTEGRITY); + EXPECT_EQ(20U, mi_attr->length()); + EXPECT_EQ(0, memcmp( + mi_attr->bytes(), kCalculatedHmac1, sizeof(kCalculatedHmac1))); + + rtc::ByteBuffer buf1; + EXPECT_TRUE(msg.Write(&buf1)); + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(buf1.Data()), buf1.Length(), + kRfc5769SampleMsgPassword)); + + IceMessage msg2; + rtc::ByteBuffer buf2( + reinterpret_cast<const char*>(kRfc5769SampleResponseWithoutMI), + sizeof(kRfc5769SampleResponseWithoutMI)); + EXPECT_TRUE(msg2.Read(&buf2)); + EXPECT_TRUE(msg2.AddMessageIntegrity(kRfc5769SampleMsgPassword)); + const StunByteStringAttribute* mi_attr2 = + msg2.GetByteString(STUN_ATTR_MESSAGE_INTEGRITY); + EXPECT_EQ(20U, mi_attr2->length()); + EXPECT_EQ( + 0, memcmp(mi_attr2->bytes(), kCalculatedHmac2, sizeof(kCalculatedHmac2))); + + rtc::ByteBuffer buf3; + EXPECT_TRUE(msg2.Write(&buf3)); + EXPECT_TRUE(StunMessage::ValidateMessageIntegrity( + reinterpret_cast<const char*>(buf3.Data()), buf3.Length(), + kRfc5769SampleMsgPassword)); +} + +// Check our STUN message validation code against the RFC5769 test messages. +TEST_F(StunTest, ValidateFingerprint) { + EXPECT_TRUE(StunMessage::ValidateFingerprint( + reinterpret_cast<const char*>(kRfc5769SampleRequest), + sizeof(kRfc5769SampleRequest))); + EXPECT_TRUE(StunMessage::ValidateFingerprint( + reinterpret_cast<const char*>(kRfc5769SampleResponse), + sizeof(kRfc5769SampleResponse))); + EXPECT_TRUE(StunMessage::ValidateFingerprint( + reinterpret_cast<const char*>(kRfc5769SampleResponseIPv6), + sizeof(kRfc5769SampleResponseIPv6))); + + EXPECT_FALSE(StunMessage::ValidateFingerprint( + reinterpret_cast<const char*>(kStunMessageWithZeroLength), + sizeof(kStunMessageWithZeroLength))); + EXPECT_FALSE(StunMessage::ValidateFingerprint( + reinterpret_cast<const char*>(kStunMessageWithExcessLength), + sizeof(kStunMessageWithExcessLength))); + EXPECT_FALSE(StunMessage::ValidateFingerprint( + reinterpret_cast<const char*>(kStunMessageWithSmallLength), + sizeof(kStunMessageWithSmallLength))); + + // Test that munging a single bit anywhere in the message causes the + // fingerprint check to fail. + char buf[sizeof(kRfc5769SampleRequest)]; + memcpy(buf, kRfc5769SampleRequest, sizeof(kRfc5769SampleRequest)); + for (size_t i = 0; i < sizeof(buf); ++i) { + buf[i] ^= 0x01; + if (i > 0) + buf[i - 1] ^= 0x01; + EXPECT_FALSE(StunMessage::ValidateFingerprint(buf, sizeof(buf))); + } + // Put them all back to normal and the check should pass again. + buf[sizeof(buf) - 1] ^= 0x01; + EXPECT_TRUE(StunMessage::ValidateFingerprint(buf, sizeof(buf))); +} + +TEST_F(StunTest, AddFingerprint) { + IceMessage msg; + rtc::ByteBuffer buf( + reinterpret_cast<const char*>(kRfc5769SampleRequestWithoutMI), + sizeof(kRfc5769SampleRequestWithoutMI)); + EXPECT_TRUE(msg.Read(&buf)); + EXPECT_TRUE(msg.AddFingerprint()); + + rtc::ByteBuffer buf1; + EXPECT_TRUE(msg.Write(&buf1)); + EXPECT_TRUE(StunMessage::ValidateFingerprint( + reinterpret_cast<const char*>(buf1.Data()), buf1.Length())); +} + +// Sample "GTURN" relay message. +static const unsigned char kRelayMessage[] = { + 0x00, 0x01, 0x00, 88, // message header + 0x21, 0x12, 0xA4, 0x42, // magic cookie + '0', '1', '2', '3', // transaction id + '4', '5', '6', '7', + '8', '9', 'a', 'b', + 0x00, 0x01, 0x00, 8, // mapped address + 0x00, 0x01, 0x00, 13, + 0x00, 0x00, 0x00, 17, + 0x00, 0x06, 0x00, 12, // username + 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', + 0x00, 0x0d, 0x00, 4, // lifetime + 0x00, 0x00, 0x00, 11, + 0x00, 0x0f, 0x00, 4, // magic cookie + 0x72, 0xc6, 0x4b, 0xc6, + 0x00, 0x10, 0x00, 4, // bandwidth + 0x00, 0x00, 0x00, 6, + 0x00, 0x11, 0x00, 8, // destination address + 0x00, 0x01, 0x00, 13, + 0x00, 0x00, 0x00, 17, + 0x00, 0x12, 0x00, 8, // source address 2 + 0x00, 0x01, 0x00, 13, + 0x00, 0x00, 0x00, 17, + 0x00, 0x13, 0x00, 7, // data + 'a', 'b', 'c', 'd', + 'e', 'f', 'g', 0 // DATA must be padded per rfc5766. +}; + +// Test that we can read the GTURN-specific fields. +TEST_F(StunTest, ReadRelayMessage) { + RelayMessage msg, msg2; + + const char* input = reinterpret_cast<const char*>(kRelayMessage); + size_t size = sizeof(kRelayMessage); + rtc::ByteBuffer buf(input, size); + EXPECT_TRUE(msg.Read(&buf)); + + EXPECT_EQ(STUN_BINDING_REQUEST, msg.type()); + EXPECT_EQ(size - 20, msg.length()); + EXPECT_EQ("0123456789ab", msg.transaction_id()); + + msg2.SetType(STUN_BINDING_REQUEST); + msg2.SetTransactionID("0123456789ab"); + + in_addr legacy_in_addr; + legacy_in_addr.s_addr = htonl(17U); + rtc::IPAddress legacy_ip(legacy_in_addr); + + const StunAddressAttribute* addr = msg.GetAddress(STUN_ATTR_MAPPED_ADDRESS); + ASSERT_TRUE(addr != NULL); + EXPECT_EQ(1, addr->family()); + EXPECT_EQ(13, addr->port()); + EXPECT_EQ(legacy_ip, addr->ipaddr()); + + StunAddressAttribute* addr2 = + StunAttribute::CreateAddress(STUN_ATTR_MAPPED_ADDRESS); + addr2->SetPort(13); + addr2->SetIP(legacy_ip); + EXPECT_TRUE(msg2.AddAttribute(addr2)); + + const StunByteStringAttribute* bytes = msg.GetByteString(STUN_ATTR_USERNAME); + ASSERT_TRUE(bytes != NULL); + EXPECT_EQ(12U, bytes->length()); + EXPECT_EQ("abcdefghijkl", bytes->GetString()); + + StunByteStringAttribute* bytes2 = + StunAttribute::CreateByteString(STUN_ATTR_USERNAME); + bytes2->CopyBytes("abcdefghijkl"); + EXPECT_TRUE(msg2.AddAttribute(bytes2)); + + const StunUInt32Attribute* uval = msg.GetUInt32(STUN_ATTR_LIFETIME); + ASSERT_TRUE(uval != NULL); + EXPECT_EQ(11U, uval->value()); + + StunUInt32Attribute* uval2 = StunAttribute::CreateUInt32(STUN_ATTR_LIFETIME); + uval2->SetValue(11); + EXPECT_TRUE(msg2.AddAttribute(uval2)); + + bytes = msg.GetByteString(STUN_ATTR_MAGIC_COOKIE); + ASSERT_TRUE(bytes != NULL); + EXPECT_EQ(4U, bytes->length()); + EXPECT_EQ(0, + memcmp(bytes->bytes(), + TURN_MAGIC_COOKIE_VALUE, + sizeof(TURN_MAGIC_COOKIE_VALUE))); + + bytes2 = StunAttribute::CreateByteString(STUN_ATTR_MAGIC_COOKIE); + bytes2->CopyBytes(reinterpret_cast<const char*>(TURN_MAGIC_COOKIE_VALUE), + sizeof(TURN_MAGIC_COOKIE_VALUE)); + EXPECT_TRUE(msg2.AddAttribute(bytes2)); + + uval = msg.GetUInt32(STUN_ATTR_BANDWIDTH); + ASSERT_TRUE(uval != NULL); + EXPECT_EQ(6U, uval->value()); + + uval2 = StunAttribute::CreateUInt32(STUN_ATTR_BANDWIDTH); + uval2->SetValue(6); + EXPECT_TRUE(msg2.AddAttribute(uval2)); + + addr = msg.GetAddress(STUN_ATTR_DESTINATION_ADDRESS); + ASSERT_TRUE(addr != NULL); + EXPECT_EQ(1, addr->family()); + EXPECT_EQ(13, addr->port()); + EXPECT_EQ(legacy_ip, addr->ipaddr()); + + addr2 = StunAttribute::CreateAddress(STUN_ATTR_DESTINATION_ADDRESS); + addr2->SetPort(13); + addr2->SetIP(legacy_ip); + EXPECT_TRUE(msg2.AddAttribute(addr2)); + + addr = msg.GetAddress(STUN_ATTR_SOURCE_ADDRESS2); + ASSERT_TRUE(addr != NULL); + EXPECT_EQ(1, addr->family()); + EXPECT_EQ(13, addr->port()); + EXPECT_EQ(legacy_ip, addr->ipaddr()); + + addr2 = StunAttribute::CreateAddress(STUN_ATTR_SOURCE_ADDRESS2); + addr2->SetPort(13); + addr2->SetIP(legacy_ip); + EXPECT_TRUE(msg2.AddAttribute(addr2)); + + bytes = msg.GetByteString(STUN_ATTR_DATA); + ASSERT_TRUE(bytes != NULL); + EXPECT_EQ(7U, bytes->length()); + EXPECT_EQ("abcdefg", bytes->GetString()); + + bytes2 = StunAttribute::CreateByteString(STUN_ATTR_DATA); + bytes2->CopyBytes("abcdefg"); + EXPECT_TRUE(msg2.AddAttribute(bytes2)); + + rtc::ByteBuffer out; + EXPECT_TRUE(msg.Write(&out)); + EXPECT_EQ(size, out.Length()); + size_t len1 = out.Length(); + std::string outstring; + out.ReadString(&outstring, len1); + EXPECT_EQ(0, memcmp(outstring.c_str(), input, len1)); + + rtc::ByteBuffer out2; + EXPECT_TRUE(msg2.Write(&out2)); + EXPECT_EQ(size, out2.Length()); + size_t len2 = out2.Length(); + std::string outstring2; + out2.ReadString(&outstring2, len2); + EXPECT_EQ(0, memcmp(outstring2.c_str(), input, len2)); +} + +} // namespace cricket diff --git a/webrtc/p2p/base/stunport.cc b/webrtc/p2p/base/stunport.cc new file mode 100644 index 0000000000..1598fe43ce --- /dev/null +++ b/webrtc/p2p/base/stunport.cc @@ -0,0 +1,485 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/stunport.h" + +#include "webrtc/p2p/base/common.h" +#include "webrtc/p2p/base/portallocator.h" +#include "webrtc/p2p/base/stun.h" +#include "webrtc/base/common.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/ipaddress.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/nethelpers.h" + +namespace cricket { + +// TODO: Move these to a common place (used in relayport too) +const int KEEPALIVE_DELAY = 10 * 1000; // 10 seconds - sort timeouts +const int RETRY_DELAY = 50; // 50ms, from ICE spec +const int RETRY_TIMEOUT = 50 * 1000; // ICE says 50 secs + +// Handles a binding request sent to the STUN server. +class StunBindingRequest : public StunRequest { + public: + StunBindingRequest(UDPPort* port, bool keep_alive, + const rtc::SocketAddress& addr) + : port_(port), keep_alive_(keep_alive), server_addr_(addr) { + start_time_ = rtc::Time(); + } + + virtual ~StunBindingRequest() { + } + + const rtc::SocketAddress& server_addr() const { return server_addr_; } + + virtual void Prepare(StunMessage* request) override { + request->SetType(STUN_BINDING_REQUEST); + } + + virtual void OnResponse(StunMessage* response) override { + const StunAddressAttribute* addr_attr = + response->GetAddress(STUN_ATTR_MAPPED_ADDRESS); + if (!addr_attr) { + LOG(LS_ERROR) << "Binding response missing mapped address."; + } else if (addr_attr->family() != STUN_ADDRESS_IPV4 && + addr_attr->family() != STUN_ADDRESS_IPV6) { + LOG(LS_ERROR) << "Binding address has bad family"; + } else { + rtc::SocketAddress addr(addr_attr->ipaddr(), addr_attr->port()); + port_->OnStunBindingRequestSucceeded(server_addr_, addr); + } + + // We will do a keep-alive regardless of whether this request succeeds. + // This should have almost no impact on network usage. + if (keep_alive_) { + port_->requests_.SendDelayed( + new StunBindingRequest(port_, true, server_addr_), + port_->stun_keepalive_delay()); + } + } + + virtual void OnErrorResponse(StunMessage* response) override { + const StunErrorCodeAttribute* attr = response->GetErrorCode(); + if (!attr) { + LOG(LS_ERROR) << "Bad allocate response error code"; + } else { + LOG(LS_ERROR) << "Binding error response:" + << " class=" << attr->eclass() + << " number=" << attr->number() + << " reason='" << attr->reason() << "'"; + } + + port_->OnStunBindingOrResolveRequestFailed(server_addr_); + + if (keep_alive_ + && (rtc::TimeSince(start_time_) <= RETRY_TIMEOUT)) { + port_->requests_.SendDelayed( + new StunBindingRequest(port_, true, server_addr_), + port_->stun_keepalive_delay()); + } + } + + virtual void OnTimeout() override { + LOG(LS_ERROR) << "Binding request timed out from " + << port_->GetLocalAddress().ToSensitiveString() + << " (" << port_->Network()->name() << ")"; + + port_->OnStunBindingOrResolveRequestFailed(server_addr_); + + if (keep_alive_ + && (rtc::TimeSince(start_time_) <= RETRY_TIMEOUT)) { + port_->requests_.SendDelayed( + new StunBindingRequest(port_, true, server_addr_), + RETRY_DELAY); + } + } + + private: + UDPPort* port_; + bool keep_alive_; + const rtc::SocketAddress server_addr_; + uint32_t start_time_; +}; + +UDPPort::AddressResolver::AddressResolver( + rtc::PacketSocketFactory* factory) + : socket_factory_(factory) {} + +UDPPort::AddressResolver::~AddressResolver() { + for (ResolverMap::iterator it = resolvers_.begin(); + it != resolvers_.end(); ++it) { + it->second->Destroy(true); + } +} + +void UDPPort::AddressResolver::Resolve( + const rtc::SocketAddress& address) { + if (resolvers_.find(address) != resolvers_.end()) + return; + + rtc::AsyncResolverInterface* resolver = + socket_factory_->CreateAsyncResolver(); + resolvers_.insert( + std::pair<rtc::SocketAddress, rtc::AsyncResolverInterface*>( + address, resolver)); + + resolver->SignalDone.connect(this, + &UDPPort::AddressResolver::OnResolveResult); + + resolver->Start(address); +} + +bool UDPPort::AddressResolver::GetResolvedAddress( + const rtc::SocketAddress& input, + int family, + rtc::SocketAddress* output) const { + ResolverMap::const_iterator it = resolvers_.find(input); + if (it == resolvers_.end()) + return false; + + return it->second->GetResolvedAddress(family, output); +} + +void UDPPort::AddressResolver::OnResolveResult( + rtc::AsyncResolverInterface* resolver) { + for (ResolverMap::iterator it = resolvers_.begin(); + it != resolvers_.end(); ++it) { + if (it->second == resolver) { + SignalDone(it->first, resolver->GetError()); + return; + } + } +} + +UDPPort::UDPPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + rtc::AsyncPacketSocket* socket, + const std::string& username, + const std::string& password, + const std::string& origin, + bool emit_localhost_for_anyaddress) + : Port(thread, factory, network, socket->GetLocalAddress().ipaddr(), + username, password), + requests_(thread), + socket_(socket), + error_(0), + ready_(false), + stun_keepalive_delay_(KEEPALIVE_DELAY), + emit_localhost_for_anyaddress_(emit_localhost_for_anyaddress) { + requests_.set_origin(origin); +} + +UDPPort::UDPPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password, + const std::string& origin, + bool emit_localhost_for_anyaddress) + : Port(thread, + LOCAL_PORT_TYPE, + factory, + network, + ip, + min_port, + max_port, + username, + password), + requests_(thread), + socket_(NULL), + error_(0), + ready_(false), + stun_keepalive_delay_(KEEPALIVE_DELAY), + emit_localhost_for_anyaddress_(emit_localhost_for_anyaddress) { + requests_.set_origin(origin); +} + +bool UDPPort::Init() { + if (!SharedSocket()) { + ASSERT(socket_ == NULL); + socket_ = socket_factory()->CreateUdpSocket( + rtc::SocketAddress(ip(), 0), min_port(), max_port()); + if (!socket_) { + LOG_J(LS_WARNING, this) << "UDP socket creation failed"; + return false; + } + socket_->SignalReadPacket.connect(this, &UDPPort::OnReadPacket); + } + socket_->SignalSentPacket.connect(this, &UDPPort::OnSentPacket); + socket_->SignalReadyToSend.connect(this, &UDPPort::OnReadyToSend); + socket_->SignalAddressReady.connect(this, &UDPPort::OnLocalAddressReady); + requests_.SignalSendPacket.connect(this, &UDPPort::OnSendPacket); + return true; +} + +UDPPort::~UDPPort() { + if (!SharedSocket()) + delete socket_; +} + +void UDPPort::PrepareAddress() { + ASSERT(requests_.empty()); + if (socket_->GetState() == rtc::AsyncPacketSocket::STATE_BOUND) { + OnLocalAddressReady(socket_, socket_->GetLocalAddress()); + } +} + +void UDPPort::MaybePrepareStunCandidate() { + // Sending binding request to the STUN server if address is available to + // prepare STUN candidate. + if (!server_addresses_.empty()) { + SendStunBindingRequests(); + } else { + // Port is done allocating candidates. + MaybeSetPortCompleteOrError(); + } +} + +Connection* UDPPort::CreateConnection(const Candidate& address, + CandidateOrigin origin) { + if (address.protocol() != "udp") + return NULL; + + if (!IsCompatibleAddress(address.address())) { + return NULL; + } + + if (SharedSocket() && Candidates()[0].type() != LOCAL_PORT_TYPE) { + ASSERT(false); + return NULL; + } + + Connection* conn = new ProxyConnection(this, 0, address); + AddConnection(conn); + return conn; +} + +int UDPPort::SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options, + bool payload) { + int sent = socket_->SendTo(data, size, addr, options); + if (sent < 0) { + error_ = socket_->GetError(); + LOG_J(LS_ERROR, this) << "UDP send of " << size + << " bytes failed with error " << error_; + } + return sent; +} + +int UDPPort::SetOption(rtc::Socket::Option opt, int value) { + return socket_->SetOption(opt, value); +} + +int UDPPort::GetOption(rtc::Socket::Option opt, int* value) { + return socket_->GetOption(opt, value); +} + +int UDPPort::GetError() { + return error_; +} + +void UDPPort::OnLocalAddressReady(rtc::AsyncPacketSocket* socket, + const rtc::SocketAddress& address) { + // When adapter enumeration is disabled and binding to the any address, the + // loopback address will be issued as a candidate instead if + // |emit_localhost_for_anyaddress| is true. This is to allow connectivity on + // demo pages without STUN/TURN to work. + rtc::SocketAddress addr = address; + if (addr.IsAnyIP() && emit_localhost_for_anyaddress_) { + addr.SetIP(rtc::GetLoopbackIP(addr.family())); + } + + AddAddress(addr, addr, rtc::SocketAddress(), UDP_PROTOCOL_NAME, "", "", + LOCAL_PORT_TYPE, ICE_TYPE_PREFERENCE_HOST, 0, false); + MaybePrepareStunCandidate(); +} + +void UDPPort::OnReadPacket( + rtc::AsyncPacketSocket* socket, const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + ASSERT(socket == socket_); + ASSERT(!remote_addr.IsUnresolved()); + + // Look for a response from the STUN server. + // Even if the response doesn't match one of our outstanding requests, we + // will eat it because it might be a response to a retransmitted packet, and + // we already cleared the request when we got the first response. + if (server_addresses_.find(remote_addr) != server_addresses_.end()) { + requests_.CheckResponse(data, size); + return; + } + + if (Connection* conn = GetConnection(remote_addr)) { + conn->OnReadPacket(data, size, packet_time); + } else { + Port::OnReadPacket(data, size, remote_addr, PROTO_UDP); + } +} + +void UDPPort::OnSentPacket(rtc::AsyncPacketSocket* socket, + const rtc::SentPacket& sent_packet) { + Port::OnSentPacket(sent_packet); +} + +void UDPPort::OnReadyToSend(rtc::AsyncPacketSocket* socket) { + Port::OnReadyToSend(); +} + +void UDPPort::SendStunBindingRequests() { + // We will keep pinging the stun server to make sure our NAT pin-hole stays + // open during the call. + ASSERT(requests_.empty()); + + for (ServerAddresses::const_iterator it = server_addresses_.begin(); + it != server_addresses_.end(); ++it) { + SendStunBindingRequest(*it); + } +} + +void UDPPort::ResolveStunAddress(const rtc::SocketAddress& stun_addr) { + if (!resolver_) { + resolver_.reset(new AddressResolver(socket_factory())); + resolver_->SignalDone.connect(this, &UDPPort::OnResolveResult); + } + + resolver_->Resolve(stun_addr); +} + +void UDPPort::OnResolveResult(const rtc::SocketAddress& input, + int error) { + ASSERT(resolver_.get() != NULL); + + rtc::SocketAddress resolved; + if (error != 0 || + !resolver_->GetResolvedAddress(input, ip().family(), &resolved)) { + LOG_J(LS_WARNING, this) << "StunPort: stun host lookup received error " + << error; + OnStunBindingOrResolveRequestFailed(input); + return; + } + + server_addresses_.erase(input); + + if (server_addresses_.find(resolved) == server_addresses_.end()) { + server_addresses_.insert(resolved); + SendStunBindingRequest(resolved); + } +} + +void UDPPort::SendStunBindingRequest( + const rtc::SocketAddress& stun_addr) { + if (stun_addr.IsUnresolved()) { + ResolveStunAddress(stun_addr); + + } else if (socket_->GetState() == rtc::AsyncPacketSocket::STATE_BOUND) { + // Check if |server_addr_| is compatible with the port's ip. + if (IsCompatibleAddress(stun_addr)) { + requests_.Send(new StunBindingRequest(this, true, stun_addr)); + } else { + // Since we can't send stun messages to the server, we should mark this + // port ready. + LOG(LS_WARNING) << "STUN server address is incompatible."; + OnStunBindingOrResolveRequestFailed(stun_addr); + } + } +} + +void UDPPort::OnStunBindingRequestSucceeded( + const rtc::SocketAddress& stun_server_addr, + const rtc::SocketAddress& stun_reflected_addr) { + if (bind_request_succeeded_servers_.find(stun_server_addr) != + bind_request_succeeded_servers_.end()) { + return; + } + bind_request_succeeded_servers_.insert(stun_server_addr); + + // If socket is shared and |stun_reflected_addr| is equal to local socket + // address, or if the same address has been added by another STUN server, + // then discarding the stun address. + // For STUN, related address is the local socket address. + if ((!SharedSocket() || stun_reflected_addr != socket_->GetLocalAddress()) && + !HasCandidateWithAddress(stun_reflected_addr)) { + + rtc::SocketAddress related_address = socket_->GetLocalAddress(); + if (!(candidate_filter() & CF_HOST)) { + // If candidate filter doesn't have CF_HOST specified, empty raddr to + // avoid local address leakage. + related_address = rtc::EmptySocketAddressWithFamily( + related_address.family()); + } + + AddAddress(stun_reflected_addr, socket_->GetLocalAddress(), related_address, + UDP_PROTOCOL_NAME, "", "", STUN_PORT_TYPE, + ICE_TYPE_PREFERENCE_SRFLX, 0, false); + } + MaybeSetPortCompleteOrError(); +} + +void UDPPort::OnStunBindingOrResolveRequestFailed( + const rtc::SocketAddress& stun_server_addr) { + if (bind_request_failed_servers_.find(stun_server_addr) != + bind_request_failed_servers_.end()) { + return; + } + bind_request_failed_servers_.insert(stun_server_addr); + MaybeSetPortCompleteOrError(); +} + +void UDPPort::MaybeSetPortCompleteOrError() { + if (ready_) + return; + + // Do not set port ready if we are still waiting for bind responses. + const size_t servers_done_bind_request = bind_request_failed_servers_.size() + + bind_request_succeeded_servers_.size(); + if (server_addresses_.size() != servers_done_bind_request) { + return; + } + + // Setting ready status. + ready_ = true; + + // The port is "completed" if there is no stun server provided, or the bind + // request succeeded for any stun server, or the socket is shared. + if (server_addresses_.empty() || + bind_request_succeeded_servers_.size() > 0 || + SharedSocket()) { + SignalPortComplete(this); + } else { + SignalPortError(this); + } +} + +// TODO: merge this with SendTo above. +void UDPPort::OnSendPacket(const void* data, size_t size, StunRequest* req) { + StunBindingRequest* sreq = static_cast<StunBindingRequest*>(req); + rtc::PacketOptions options(DefaultDscpValue()); + if (socket_->SendTo(data, size, sreq->server_addr(), options) < 0) + PLOG(LERROR, socket_->GetError()) << "sendto"; +} + +bool UDPPort::HasCandidateWithAddress(const rtc::SocketAddress& addr) const { + const std::vector<Candidate>& existing_candidates = Candidates(); + std::vector<Candidate>::const_iterator it = existing_candidates.begin(); + for (; it != existing_candidates.end(); ++it) { + if (it->address() == addr) + return true; + } + return false; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/stunport.h b/webrtc/p2p/base/stunport.h new file mode 100644 index 0000000000..62b23cf074 --- /dev/null +++ b/webrtc/p2p/base/stunport.h @@ -0,0 +1,278 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_STUNPORT_H_ +#define WEBRTC_P2P_BASE_STUNPORT_H_ + +#include <string> + +#include "webrtc/p2p/base/port.h" +#include "webrtc/p2p/base/stunrequest.h" +#include "webrtc/base/asyncpacketsocket.h" + +// TODO(mallinath) - Rename stunport.cc|h to udpport.cc|h. +namespace rtc { +class AsyncResolver; +class SignalThread; +} + +namespace cricket { + +// Communicates using the address on the outside of a NAT. +class UDPPort : public Port { + public: + static UDPPort* Create(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + rtc::AsyncPacketSocket* socket, + const std::string& username, + const std::string& password, + const std::string& origin, + bool emit_localhost_for_anyaddress) { + UDPPort* port = new UDPPort(thread, factory, network, socket, + username, password, origin, + emit_localhost_for_anyaddress); + if (!port->Init()) { + delete port; + port = NULL; + } + return port; + } + + static UDPPort* Create(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password, + const std::string& origin, + bool emit_localhost_for_anyaddress) { + UDPPort* port = new UDPPort(thread, factory, network, + ip, min_port, max_port, + username, password, origin, + emit_localhost_for_anyaddress); + if (!port->Init()) { + delete port; + port = NULL; + } + return port; + } + + virtual ~UDPPort(); + + rtc::SocketAddress GetLocalAddress() const { + return socket_->GetLocalAddress(); + } + + const ServerAddresses& server_addresses() const { + return server_addresses_; + } + void + set_server_addresses(const ServerAddresses& addresses) { + server_addresses_ = addresses; + } + + virtual void PrepareAddress(); + + virtual Connection* CreateConnection(const Candidate& address, + CandidateOrigin origin); + virtual int SetOption(rtc::Socket::Option opt, int value); + virtual int GetOption(rtc::Socket::Option opt, int* value); + virtual int GetError(); + + virtual bool HandleIncomingPacket( + rtc::AsyncPacketSocket* socket, const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + // All packets given to UDP port will be consumed. + OnReadPacket(socket, data, size, remote_addr, packet_time); + return true; + } + + void set_stun_keepalive_delay(int delay) { + stun_keepalive_delay_ = delay; + } + int stun_keepalive_delay() const { + return stun_keepalive_delay_; + } + + protected: + UDPPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password, + const std::string& origin, + bool emit_localhost_for_anyaddress); + + UDPPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + rtc::AsyncPacketSocket* socket, + const std::string& username, + const std::string& password, + const std::string& origin, + bool emit_localhost_for_anyaddress); + + bool Init(); + + virtual int SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options, + bool payload); + + void OnLocalAddressReady(rtc::AsyncPacketSocket* socket, + const rtc::SocketAddress& address); + void OnReadPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time); + + void OnSentPacket(rtc::AsyncPacketSocket* socket, + const rtc::SentPacket& sent_packet); + + void OnReadyToSend(rtc::AsyncPacketSocket* socket); + + // This method will send STUN binding request if STUN server address is set. + void MaybePrepareStunCandidate(); + + void SendStunBindingRequests(); + + private: + // A helper class which can be called repeatedly to resolve multiple + // addresses, as opposed to rtc::AsyncResolverInterface, which can only + // resolve one address per instance. + class AddressResolver : public sigslot::has_slots<> { + public: + explicit AddressResolver(rtc::PacketSocketFactory* factory); + ~AddressResolver(); + + void Resolve(const rtc::SocketAddress& address); + bool GetResolvedAddress(const rtc::SocketAddress& input, + int family, + rtc::SocketAddress* output) const; + + // The signal is sent when resolving the specified address is finished. The + // first argument is the input address, the second argument is the error + // or 0 if it succeeded. + sigslot::signal2<const rtc::SocketAddress&, int> SignalDone; + + private: + typedef std::map<rtc::SocketAddress, + rtc::AsyncResolverInterface*> ResolverMap; + + void OnResolveResult(rtc::AsyncResolverInterface* resolver); + + rtc::PacketSocketFactory* socket_factory_; + ResolverMap resolvers_; + }; + + // DNS resolution of the STUN server. + void ResolveStunAddress(const rtc::SocketAddress& stun_addr); + void OnResolveResult(const rtc::SocketAddress& input, int error); + + void SendStunBindingRequest(const rtc::SocketAddress& stun_addr); + + // Below methods handles binding request responses. + void OnStunBindingRequestSucceeded( + const rtc::SocketAddress& stun_server_addr, + const rtc::SocketAddress& stun_reflected_addr); + void OnStunBindingOrResolveRequestFailed( + const rtc::SocketAddress& stun_server_addr); + + // Sends STUN requests to the server. + void OnSendPacket(const void* data, size_t size, StunRequest* req); + + // TODO(mallinaht) - Move this up to cricket::Port when SignalAddressReady is + // changed to SignalPortReady. + void MaybeSetPortCompleteOrError(); + + bool HasCandidateWithAddress(const rtc::SocketAddress& addr) const; + + ServerAddresses server_addresses_; + ServerAddresses bind_request_succeeded_servers_; + ServerAddresses bind_request_failed_servers_; + StunRequestManager requests_; + rtc::AsyncPacketSocket* socket_; + int error_; + rtc::scoped_ptr<AddressResolver> resolver_; + bool ready_; + int stun_keepalive_delay_; + + // This is true when PORTALLOCATOR_ENABLE_LOCALHOST_CANDIDATE is specified. + bool emit_localhost_for_anyaddress_; + + friend class StunBindingRequest; +}; + +class StunPort : public UDPPort { + public: + static StunPort* Create(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password, + const ServerAddresses& servers, + const std::string& origin) { + StunPort* port = new StunPort(thread, factory, network, + ip, min_port, max_port, + username, password, servers, + origin); + if (!port->Init()) { + delete port; + port = NULL; + } + return port; + } + + virtual ~StunPort() {} + + virtual void PrepareAddress() { + SendStunBindingRequests(); + } + + protected: + StunPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password, + const ServerAddresses& servers, + const std::string& origin) + : UDPPort(thread, + factory, + network, + ip, + min_port, + max_port, + username, + password, + origin, + false) { + // UDPPort will set these to local udp, updating these to STUN. + set_type(STUN_PORT_TYPE); + set_server_addresses(servers); + } +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_STUNPORT_H_ diff --git a/webrtc/p2p/base/stunport_unittest.cc b/webrtc/p2p/base/stunport_unittest.cc new file mode 100644 index 0000000000..037d448b9e --- /dev/null +++ b/webrtc/p2p/base/stunport_unittest.cc @@ -0,0 +1,287 @@ +/* + * Copyright 2009 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/basicpacketsocketfactory.h" +#include "webrtc/p2p/base/stunport.h" +#include "webrtc/p2p/base/teststunserver.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/physicalsocketserver.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/socketaddress.h" +#include "webrtc/base/ssladapter.h" +#include "webrtc/base/virtualsocketserver.h" + +using cricket::ServerAddresses; +using rtc::SocketAddress; + +static const SocketAddress kLocalAddr("127.0.0.1", 0); +static const SocketAddress kStunAddr1("127.0.0.1", 5000); +static const SocketAddress kStunAddr2("127.0.0.1", 4000); +static const SocketAddress kBadAddr("0.0.0.1", 5000); +static const SocketAddress kStunHostnameAddr("localhost", 5000); +static const SocketAddress kBadHostnameAddr("not-a-real-hostname", 5000); +static const int kTimeoutMs = 10000; +// stun prio = 100 << 24 | 30 (IPV4) << 8 | 256 - 0 +static const uint32_t kStunCandidatePriority = 1677729535; + +// Tests connecting a StunPort to a fake STUN server (cricket::StunServer) +// TODO: Use a VirtualSocketServer here. We have to use a +// PhysicalSocketServer right now since DNS is not part of SocketServer yet. +class StunPortTest : public testing::Test, + public sigslot::has_slots<> { + public: + StunPortTest() + : pss_(new rtc::PhysicalSocketServer), + ss_(new rtc::VirtualSocketServer(pss_.get())), + ss_scope_(ss_.get()), + network_("unittest", "unittest", rtc::IPAddress(INADDR_ANY), 32), + socket_factory_(rtc::Thread::Current()), + stun_server_1_(cricket::TestStunServer::Create( + rtc::Thread::Current(), kStunAddr1)), + stun_server_2_(cricket::TestStunServer::Create( + rtc::Thread::Current(), kStunAddr2)), + done_(false), error_(false), stun_keepalive_delay_(0) { + } + + const cricket::Port* port() const { return stun_port_.get(); } + bool done() const { return done_; } + bool error() const { return error_; } + + void CreateStunPort(const rtc::SocketAddress& server_addr) { + ServerAddresses stun_servers; + stun_servers.insert(server_addr); + CreateStunPort(stun_servers); + } + + void CreateStunPort(const ServerAddresses& stun_servers) { + stun_port_.reset(cricket::StunPort::Create( + rtc::Thread::Current(), &socket_factory_, &network_, + kLocalAddr.ipaddr(), 0, 0, rtc::CreateRandomString(16), + rtc::CreateRandomString(22), stun_servers, std::string())); + stun_port_->set_stun_keepalive_delay(stun_keepalive_delay_); + stun_port_->SignalPortComplete.connect(this, + &StunPortTest::OnPortComplete); + stun_port_->SignalPortError.connect(this, + &StunPortTest::OnPortError); + } + + void CreateSharedStunPort(const rtc::SocketAddress& server_addr) { + socket_.reset(socket_factory_.CreateUdpSocket( + rtc::SocketAddress(kLocalAddr.ipaddr(), 0), 0, 0)); + ASSERT_TRUE(socket_ != NULL); + socket_->SignalReadPacket.connect(this, &StunPortTest::OnReadPacket); + stun_port_.reset(cricket::UDPPort::Create( + rtc::Thread::Current(), &socket_factory_, + &network_, socket_.get(), + rtc::CreateRandomString(16), rtc::CreateRandomString(22), + std::string(), false)); + ASSERT_TRUE(stun_port_ != NULL); + ServerAddresses stun_servers; + stun_servers.insert(server_addr); + stun_port_->set_server_addresses(stun_servers); + stun_port_->SignalPortComplete.connect(this, + &StunPortTest::OnPortComplete); + stun_port_->SignalPortError.connect(this, + &StunPortTest::OnPortError); + } + + void PrepareAddress() { + stun_port_->PrepareAddress(); + } + + void OnReadPacket(rtc::AsyncPacketSocket* socket, const char* data, + size_t size, const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + stun_port_->HandleIncomingPacket( + socket, data, size, remote_addr, rtc::PacketTime()); + } + + void SendData(const char* data, size_t len) { + stun_port_->HandleIncomingPacket( + socket_.get(), data, len, rtc::SocketAddress("22.22.22.22", 0), + rtc::PacketTime()); + } + + protected: + static void SetUpTestCase() { + // Ensure the RNG is inited. + rtc::InitRandom(NULL, 0); + + } + + void OnPortComplete(cricket::Port* port) { + ASSERT_FALSE(done_); + done_ = true; + error_ = false; + } + void OnPortError(cricket::Port* port) { + done_ = true; + error_ = true; + } + void SetKeepaliveDelay(int delay) { + stun_keepalive_delay_ = delay; + } + + cricket::TestStunServer* stun_server_1() { + return stun_server_1_.get(); + } + cricket::TestStunServer* stun_server_2() { + return stun_server_2_.get(); + } + + private: + rtc::scoped_ptr<rtc::PhysicalSocketServer> pss_; + rtc::scoped_ptr<rtc::VirtualSocketServer> ss_; + rtc::SocketServerScope ss_scope_; + rtc::Network network_; + rtc::BasicPacketSocketFactory socket_factory_; + rtc::scoped_ptr<cricket::UDPPort> stun_port_; + rtc::scoped_ptr<cricket::TestStunServer> stun_server_1_; + rtc::scoped_ptr<cricket::TestStunServer> stun_server_2_; + rtc::scoped_ptr<rtc::AsyncPacketSocket> socket_; + bool done_; + bool error_; + int stun_keepalive_delay_; +}; + +// Test that we can create a STUN port +TEST_F(StunPortTest, TestBasic) { + CreateStunPort(kStunAddr1); + EXPECT_EQ("stun", port()->Type()); + EXPECT_EQ(0U, port()->Candidates().size()); +} + +// Test that we can get an address from a STUN server. +TEST_F(StunPortTest, TestPrepareAddress) { + CreateStunPort(kStunAddr1); + PrepareAddress(); + EXPECT_TRUE_WAIT(done(), kTimeoutMs); + ASSERT_EQ(1U, port()->Candidates().size()); + EXPECT_TRUE(kLocalAddr.EqualIPs(port()->Candidates()[0].address())); + + // TODO: Add IPv6 tests here, once either physicalsocketserver supports + // IPv6, or this test is changed to use VirtualSocketServer. +} + +// Test that we fail properly if we can't get an address. +TEST_F(StunPortTest, TestPrepareAddressFail) { + CreateStunPort(kBadAddr); + PrepareAddress(); + EXPECT_TRUE_WAIT(done(), kTimeoutMs); + EXPECT_TRUE(error()); + EXPECT_EQ(0U, port()->Candidates().size()); +} + +// Test that we can get an address from a STUN server specified by a hostname. +TEST_F(StunPortTest, TestPrepareAddressHostname) { + CreateStunPort(kStunHostnameAddr); + PrepareAddress(); + EXPECT_TRUE_WAIT(done(), kTimeoutMs); + ASSERT_EQ(1U, port()->Candidates().size()); + EXPECT_TRUE(kLocalAddr.EqualIPs(port()->Candidates()[0].address())); + EXPECT_EQ(kStunCandidatePriority, port()->Candidates()[0].priority()); +} + +// Test that we handle hostname lookup failures properly. +TEST_F(StunPortTest, TestPrepareAddressHostnameFail) { + CreateStunPort(kBadHostnameAddr); + PrepareAddress(); + EXPECT_TRUE_WAIT(done(), kTimeoutMs); + EXPECT_TRUE(error()); + EXPECT_EQ(0U, port()->Candidates().size()); +} + +// This test verifies keepalive response messages don't result in +// additional candidate generation. +TEST_F(StunPortTest, TestKeepAliveResponse) { + SetKeepaliveDelay(500); // 500ms of keepalive delay. + CreateStunPort(kStunHostnameAddr); + PrepareAddress(); + EXPECT_TRUE_WAIT(done(), kTimeoutMs); + ASSERT_EQ(1U, port()->Candidates().size()); + EXPECT_TRUE(kLocalAddr.EqualIPs(port()->Candidates()[0].address())); + // Waiting for 1 seond, which will allow us to process + // response for keepalive binding request. 500 ms is the keepalive delay. + rtc::Thread::Current()->ProcessMessages(1000); + ASSERT_EQ(1U, port()->Candidates().size()); +} + +// Test that a local candidate can be generated using a shared socket. +TEST_F(StunPortTest, TestSharedSocketPrepareAddress) { + CreateSharedStunPort(kStunAddr1); + PrepareAddress(); + EXPECT_TRUE_WAIT(done(), kTimeoutMs); + ASSERT_EQ(1U, port()->Candidates().size()); + EXPECT_TRUE(kLocalAddr.EqualIPs(port()->Candidates()[0].address())); +} + +// Test that we still a get a local candidate with invalid stun server hostname. +// Also verifing that UDPPort can receive packets when stun address can't be +// resolved. +TEST_F(StunPortTest, TestSharedSocketPrepareAddressInvalidHostname) { + CreateSharedStunPort(kBadHostnameAddr); + PrepareAddress(); + EXPECT_TRUE_WAIT(done(), kTimeoutMs); + ASSERT_EQ(1U, port()->Candidates().size()); + EXPECT_TRUE(kLocalAddr.EqualIPs(port()->Candidates()[0].address())); + + // Send data to port after it's ready. This is to make sure, UDP port can + // handle data with unresolved stun server address. + std::string data = "some random data, sending to cricket::Port."; + SendData(data.c_str(), data.length()); + // No crash is success. +} + +// Test that the same address is added only once if two STUN servers are in use. +TEST_F(StunPortTest, TestNoDuplicatedAddressWithTwoStunServers) { + ServerAddresses stun_servers; + stun_servers.insert(kStunAddr1); + stun_servers.insert(kStunAddr2); + CreateStunPort(stun_servers); + EXPECT_EQ("stun", port()->Type()); + PrepareAddress(); + EXPECT_TRUE_WAIT(done(), kTimeoutMs); + EXPECT_EQ(1U, port()->Candidates().size()); + EXPECT_EQ(port()->Candidates()[0].relay_protocol(), ""); +} + +// Test that candidates can be allocated for multiple STUN servers, one of which +// is not reachable. +TEST_F(StunPortTest, TestMultipleStunServersWithBadServer) { + ServerAddresses stun_servers; + stun_servers.insert(kStunAddr1); + stun_servers.insert(kBadAddr); + CreateStunPort(stun_servers); + EXPECT_EQ("stun", port()->Type()); + PrepareAddress(); + EXPECT_TRUE_WAIT(done(), kTimeoutMs); + EXPECT_EQ(1U, port()->Candidates().size()); +} + +// Test that two candidates are allocated if the two STUN servers return +// different mapped addresses. +TEST_F(StunPortTest, TestTwoCandidatesWithTwoStunServersAcrossNat) { + const SocketAddress kStunMappedAddr1("77.77.77.77", 0); + const SocketAddress kStunMappedAddr2("88.77.77.77", 0); + stun_server_1()->set_fake_stun_addr(kStunMappedAddr1); + stun_server_2()->set_fake_stun_addr(kStunMappedAddr2); + + ServerAddresses stun_servers; + stun_servers.insert(kStunAddr1); + stun_servers.insert(kStunAddr2); + CreateStunPort(stun_servers); + EXPECT_EQ("stun", port()->Type()); + PrepareAddress(); + EXPECT_TRUE_WAIT(done(), kTimeoutMs); + EXPECT_EQ(2U, port()->Candidates().size()); + EXPECT_EQ(port()->Candidates()[0].relay_protocol(), ""); + EXPECT_EQ(port()->Candidates()[1].relay_protocol(), ""); +} diff --git a/webrtc/p2p/base/stunrequest.cc b/webrtc/p2p/base/stunrequest.cc new file mode 100644 index 0000000000..df5614d3cc --- /dev/null +++ b/webrtc/p2p/base/stunrequest.cc @@ -0,0 +1,217 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/stunrequest.h" + +#include <algorithm> +#include "webrtc/base/common.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/stringencode.h" + +namespace cricket { + +const uint32_t MSG_STUN_SEND = 1; + +const int MAX_SENDS = 9; +const int DELAY_UNIT = 100; // 100 milliseconds +const int DELAY_MAX_FACTOR = 16; + +StunRequestManager::StunRequestManager(rtc::Thread* thread) + : thread_(thread) { +} + +StunRequestManager::~StunRequestManager() { + while (requests_.begin() != requests_.end()) { + StunRequest *request = requests_.begin()->second; + requests_.erase(requests_.begin()); + delete request; + } +} + +void StunRequestManager::Send(StunRequest* request) { + SendDelayed(request, 0); +} + +void StunRequestManager::SendDelayed(StunRequest* request, int delay) { + request->set_manager(this); + ASSERT(requests_.find(request->id()) == requests_.end()); + request->set_origin(origin_); + request->Construct(); + requests_[request->id()] = request; + if (delay > 0) { + thread_->PostDelayed(delay, request, MSG_STUN_SEND, NULL); + } else { + thread_->Send(request, MSG_STUN_SEND, NULL); + } +} + +void StunRequestManager::Remove(StunRequest* request) { + ASSERT(request->manager() == this); + RequestMap::iterator iter = requests_.find(request->id()); + if (iter != requests_.end()) { + ASSERT(iter->second == request); + requests_.erase(iter); + thread_->Clear(request); + } +} + +void StunRequestManager::Clear() { + std::vector<StunRequest*> requests; + for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) + requests.push_back(i->second); + + for (uint32_t i = 0; i < requests.size(); ++i) { + // StunRequest destructor calls Remove() which deletes requests + // from |requests_|. + delete requests[i]; + } +} + +bool StunRequestManager::CheckResponse(StunMessage* msg) { + RequestMap::iterator iter = requests_.find(msg->transaction_id()); + if (iter == requests_.end()) { + // TODO(pthatcher): Log unknown responses without being too spammy + // in the logs. + return false; + } + + StunRequest* request = iter->second; + if (msg->type() == GetStunSuccessResponseType(request->type())) { + request->OnResponse(msg); + } else if (msg->type() == GetStunErrorResponseType(request->type())) { + request->OnErrorResponse(msg); + } else { + LOG(LERROR) << "Received response with wrong type: " << msg->type() + << " (expecting " + << GetStunSuccessResponseType(request->type()) << ")"; + return false; + } + + delete request; + return true; +} + +bool StunRequestManager::CheckResponse(const char* data, size_t size) { + // Check the appropriate bytes of the stream to see if they match the + // transaction ID of a response we are expecting. + + if (size < 20) + return false; + + std::string id; + id.append(data + kStunTransactionIdOffset, kStunTransactionIdLength); + + RequestMap::iterator iter = requests_.find(id); + if (iter == requests_.end()) { + // TODO(pthatcher): Log unknown responses without being too spammy + // in the logs. + return false; + } + + // Parse the STUN message and continue processing as usual. + + rtc::ByteBuffer buf(data, size); + rtc::scoped_ptr<StunMessage> response(iter->second->msg_->CreateNew()); + if (!response->Read(&buf)) { + LOG(LS_WARNING) << "Failed to read STUN response " << rtc::hex_encode(id); + return false; + } + + return CheckResponse(response.get()); +} + +StunRequest::StunRequest() + : count_(0), timeout_(false), manager_(0), + msg_(new StunMessage()), tstamp_(0) { + msg_->SetTransactionID( + rtc::CreateRandomString(kStunTransactionIdLength)); +} + +StunRequest::StunRequest(StunMessage* request) + : count_(0), timeout_(false), manager_(0), + msg_(request), tstamp_(0) { + msg_->SetTransactionID( + rtc::CreateRandomString(kStunTransactionIdLength)); +} + +StunRequest::~StunRequest() { + ASSERT(manager_ != NULL); + if (manager_) { + manager_->Remove(this); + manager_->thread_->Clear(this); + } + delete msg_; +} + +void StunRequest::Construct() { + if (msg_->type() == 0) { + if (!origin_.empty()) { + msg_->AddAttribute(new StunByteStringAttribute(STUN_ATTR_ORIGIN, + origin_)); + } + Prepare(msg_); + ASSERT(msg_->type() != 0); + } +} + +int StunRequest::type() { + ASSERT(msg_ != NULL); + return msg_->type(); +} + +const StunMessage* StunRequest::msg() const { + return msg_; +} + +uint32_t StunRequest::Elapsed() const { + return rtc::TimeSince(tstamp_); +} + + +void StunRequest::set_manager(StunRequestManager* manager) { + ASSERT(!manager_); + manager_ = manager; +} + +void StunRequest::OnMessage(rtc::Message* pmsg) { + ASSERT(manager_ != NULL); + ASSERT(pmsg->message_id == MSG_STUN_SEND); + + if (timeout_) { + OnTimeout(); + delete this; + return; + } + + tstamp_ = rtc::Time(); + + rtc::ByteBuffer buf; + msg_->Write(&buf); + manager_->SignalSendPacket(buf.Data(), buf.Length(), this); + + OnSent(); + manager_->thread_->PostDelayed(resend_delay(), this, MSG_STUN_SEND, NULL); +} + +void StunRequest::OnSent() { + count_ += 1; + if (count_ == MAX_SENDS) + timeout_ = true; +} + +int StunRequest::resend_delay() { + if (count_ == 0) { + return 0; + } + return DELAY_UNIT * std::min(1 << (count_-1), DELAY_MAX_FACTOR); +} + +} // namespace cricket diff --git a/webrtc/p2p/base/stunrequest.h b/webrtc/p2p/base/stunrequest.h new file mode 100644 index 0000000000..267b4a1959 --- /dev/null +++ b/webrtc/p2p/base/stunrequest.h @@ -0,0 +1,128 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_STUNREQUEST_H_ +#define WEBRTC_P2P_BASE_STUNREQUEST_H_ + +#include <map> +#include <string> +#include "webrtc/p2p/base/stun.h" +#include "webrtc/base/sigslot.h" +#include "webrtc/base/thread.h" + +namespace cricket { + +class StunRequest; + +// Manages a set of STUN requests, sending and resending until we receive a +// response or determine that the request has timed out. +class StunRequestManager { + public: + StunRequestManager(rtc::Thread* thread); + ~StunRequestManager(); + + // Starts sending the given request (perhaps after a delay). + void Send(StunRequest* request); + void SendDelayed(StunRequest* request, int delay); + + // Removes a stun request that was added previously. This will happen + // automatically when a request succeeds, fails, or times out. + void Remove(StunRequest* request); + + // Removes all stun requests that were added previously. + void Clear(); + + // Determines whether the given message is a response to one of the + // outstanding requests, and if so, processes it appropriately. + bool CheckResponse(StunMessage* msg); + bool CheckResponse(const char* data, size_t size); + + bool empty() { return requests_.empty(); } + + // Set the Origin header for outgoing stun messages. + void set_origin(const std::string& origin) { origin_ = origin; } + + // Raised when there are bytes to be sent. + sigslot::signal3<const void*, size_t, StunRequest*> SignalSendPacket; + + private: + typedef std::map<std::string, StunRequest*> RequestMap; + + rtc::Thread* thread_; + RequestMap requests_; + std::string origin_; + + friend class StunRequest; +}; + +// Represents an individual request to be sent. The STUN message can either be +// constructed beforehand or built on demand. +class StunRequest : public rtc::MessageHandler { + public: + StunRequest(); + StunRequest(StunMessage* request); + virtual ~StunRequest(); + + // Causes our wrapped StunMessage to be Prepared + void Construct(); + + // The manager handling this request (if it has been scheduled for sending). + StunRequestManager* manager() { return manager_; } + + // Returns the transaction ID of this request. + const std::string& id() { return msg_->transaction_id(); } + + // the origin value + const std::string& origin() const { return origin_; } + void set_origin(const std::string& origin) { origin_ = origin; } + + // Returns the STUN type of the request message. + int type(); + + // Returns a const pointer to |msg_|. + const StunMessage* msg() const; + + // Time elapsed since last send (in ms) + uint32_t Elapsed() const; + + protected: + int count_; + bool timeout_; + std::string origin_; + + // Fills in a request object to be sent. Note that request's transaction ID + // will already be set and cannot be changed. + virtual void Prepare(StunMessage* request) {} + + // Called when the message receives a response or times out. + virtual void OnResponse(StunMessage* response) {} + virtual void OnErrorResponse(StunMessage* response) {} + virtual void OnTimeout() {} + // Called when the message is sent. + virtual void OnSent(); + // Returns the next delay for resends. + virtual int resend_delay(); + + private: + void set_manager(StunRequestManager* manager); + + // Handles messages for sending and timeout. + void OnMessage(rtc::Message* pmsg); + + StunRequestManager* manager_; + StunMessage* msg_; + uint32_t tstamp_; + + friend class StunRequestManager; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_STUNREQUEST_H_ diff --git a/webrtc/p2p/base/stunrequest_unittest.cc b/webrtc/p2p/base/stunrequest_unittest.cc new file mode 100644 index 0000000000..8a23834891 --- /dev/null +++ b/webrtc/p2p/base/stunrequest_unittest.cc @@ -0,0 +1,203 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/stunrequest.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/ssladapter.h" +#include "webrtc/base/timeutils.h" + +using namespace cricket; + +class StunRequestTest : public testing::Test, + public sigslot::has_slots<> { + public: + StunRequestTest() + : manager_(rtc::Thread::Current()), + request_count_(0), response_(NULL), + success_(false), failure_(false), timeout_(false) { + manager_.SignalSendPacket.connect(this, &StunRequestTest::OnSendPacket); + } + + void OnSendPacket(const void* data, size_t size, StunRequest* req) { + request_count_++; + } + + void OnResponse(StunMessage* res) { + response_ = res; + success_ = true; + } + void OnErrorResponse(StunMessage* res) { + response_ = res; + failure_ = true; + } + void OnTimeout() { + timeout_ = true; + } + + protected: + static StunMessage* CreateStunMessage(StunMessageType type, + StunMessage* req) { + StunMessage* msg = new StunMessage(); + msg->SetType(type); + if (req) { + msg->SetTransactionID(req->transaction_id()); + } + return msg; + } + static int TotalDelay(int sends) { + int total = 0; + for (int i = 0; i < sends; i++) { + if (i < 4) + total += 100 << i; + else + total += 1600; + } + return total; + } + + StunRequestManager manager_; + int request_count_; + StunMessage* response_; + bool success_; + bool failure_; + bool timeout_; +}; + +// Forwards results to the test class. +class StunRequestThunker : public StunRequest { + public: + StunRequestThunker(StunMessage* msg, StunRequestTest* test) + : StunRequest(msg), test_(test) {} + explicit StunRequestThunker(StunRequestTest* test) : test_(test) {} + private: + virtual void OnResponse(StunMessage* res) { + test_->OnResponse(res); + } + virtual void OnErrorResponse(StunMessage* res) { + test_->OnErrorResponse(res); + } + virtual void OnTimeout() { + test_->OnTimeout(); + } + + virtual void Prepare(StunMessage* request) { + request->SetType(STUN_BINDING_REQUEST); + } + + StunRequestTest* test_; +}; + +// Test handling of a normal binding response. +TEST_F(StunRequestTest, TestSuccess) { + StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); + + manager_.Send(new StunRequestThunker(req, this)); + StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, req); + EXPECT_TRUE(manager_.CheckResponse(res)); + + EXPECT_TRUE(response_ == res); + EXPECT_TRUE(success_); + EXPECT_FALSE(failure_); + EXPECT_FALSE(timeout_); + delete res; +} + +// Test handling of an error binding response. +TEST_F(StunRequestTest, TestError) { + StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); + + manager_.Send(new StunRequestThunker(req, this)); + StunMessage* res = CreateStunMessage(STUN_BINDING_ERROR_RESPONSE, req); + EXPECT_TRUE(manager_.CheckResponse(res)); + + EXPECT_TRUE(response_ == res); + EXPECT_FALSE(success_); + EXPECT_TRUE(failure_); + EXPECT_FALSE(timeout_); + delete res; +} + +// Test handling of a binding response with the wrong transaction id. +TEST_F(StunRequestTest, TestUnexpected) { + StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); + + manager_.Send(new StunRequestThunker(req, this)); + StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, NULL); + EXPECT_FALSE(manager_.CheckResponse(res)); + + EXPECT_TRUE(response_ == NULL); + EXPECT_FALSE(success_); + EXPECT_FALSE(failure_); + EXPECT_FALSE(timeout_); + delete res; +} + +// Test that requests are sent at the right times, and that the 9th request +// (sent at 7900 ms) can be properly replied to. +TEST_F(StunRequestTest, TestBackoff) { + StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); + + uint32_t start = rtc::Time(); + manager_.Send(new StunRequestThunker(req, this)); + StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, req); + for (int i = 0; i < 9; ++i) { + while (request_count_ == i) + rtc::Thread::Current()->ProcessMessages(1); + int32_t elapsed = rtc::TimeSince(start); + LOG(LS_INFO) << "STUN request #" << (i + 1) + << " sent at " << elapsed << " ms"; + EXPECT_GE(TotalDelay(i + 1), elapsed); + } + EXPECT_TRUE(manager_.CheckResponse(res)); + + EXPECT_TRUE(response_ == res); + EXPECT_TRUE(success_); + EXPECT_FALSE(failure_); + EXPECT_FALSE(timeout_); + delete res; +} + +// Test that we timeout properly if no response is received in 9500 ms. +TEST_F(StunRequestTest, TestTimeout) { + StunMessage* req = CreateStunMessage(STUN_BINDING_REQUEST, NULL); + StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, req); + + manager_.Send(new StunRequestThunker(req, this)); + rtc::Thread::Current()->ProcessMessages(10000); // > STUN timeout + EXPECT_FALSE(manager_.CheckResponse(res)); + + EXPECT_TRUE(response_ == NULL); + EXPECT_FALSE(success_); + EXPECT_FALSE(failure_); + EXPECT_TRUE(timeout_); + delete res; +} + +// Regression test for specific crash where we receive a response with the +// same id as a request that doesn't have an underlying StunMessage yet. +TEST_F(StunRequestTest, TestNoEmptyRequest) { + StunRequestThunker* request = new StunRequestThunker(this); + + manager_.SendDelayed(request, 100); + + StunMessage dummy_req; + dummy_req.SetTransactionID(request->id()); + StunMessage* res = CreateStunMessage(STUN_BINDING_RESPONSE, &dummy_req); + + EXPECT_TRUE(manager_.CheckResponse(res)); + + EXPECT_TRUE(response_ == res); + EXPECT_TRUE(success_); + EXPECT_FALSE(failure_); + EXPECT_FALSE(timeout_); + delete res; +} diff --git a/webrtc/p2p/base/stunserver.cc b/webrtc/p2p/base/stunserver.cc new file mode 100644 index 0000000000..fbc316bd47 --- /dev/null +++ b/webrtc/p2p/base/stunserver.cc @@ -0,0 +1,99 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/stunserver.h" + +#include "webrtc/base/bytebuffer.h" +#include "webrtc/base/logging.h" + +namespace cricket { + +StunServer::StunServer(rtc::AsyncUDPSocket* socket) : socket_(socket) { + socket_->SignalReadPacket.connect(this, &StunServer::OnPacket); +} + +StunServer::~StunServer() { + socket_->SignalReadPacket.disconnect(this); +} + +void StunServer::OnPacket( + rtc::AsyncPacketSocket* socket, const char* buf, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + // Parse the STUN message; eat any messages that fail to parse. + rtc::ByteBuffer bbuf(buf, size); + StunMessage msg; + if (!msg.Read(&bbuf)) { + return; + } + + // TODO: If unknown non-optional (<= 0x7fff) attributes are found, send a + // 420 "Unknown Attribute" response. + + // Send the message to the appropriate handler function. + switch (msg.type()) { + case STUN_BINDING_REQUEST: + OnBindingRequest(&msg, remote_addr); + break; + + default: + SendErrorResponse(msg, remote_addr, 600, "Operation Not Supported"); + } +} + +void StunServer::OnBindingRequest( + StunMessage* msg, const rtc::SocketAddress& remote_addr) { + StunMessage response; + GetStunBindReqponse(msg, remote_addr, &response); + SendResponse(response, remote_addr); +} + +void StunServer::SendErrorResponse( + const StunMessage& msg, const rtc::SocketAddress& addr, + int error_code, const char* error_desc) { + StunMessage err_msg; + err_msg.SetType(GetStunErrorResponseType(msg.type())); + err_msg.SetTransactionID(msg.transaction_id()); + + StunErrorCodeAttribute* err_code = StunAttribute::CreateErrorCode(); + err_code->SetCode(error_code); + err_code->SetReason(error_desc); + err_msg.AddAttribute(err_code); + + SendResponse(err_msg, addr); +} + +void StunServer::SendResponse( + const StunMessage& msg, const rtc::SocketAddress& addr) { + rtc::ByteBuffer buf; + msg.Write(&buf); + rtc::PacketOptions options; + if (socket_->SendTo(buf.Data(), buf.Length(), addr, options) < 0) + LOG_ERR(LS_ERROR) << "sendto"; +} + +void StunServer::GetStunBindReqponse(StunMessage* request, + const rtc::SocketAddress& remote_addr, + StunMessage* response) const { + response->SetType(STUN_BINDING_RESPONSE); + response->SetTransactionID(request->transaction_id()); + + // Tell the user the address that we received their request from. + StunAddressAttribute* mapped_addr; + if (!request->IsLegacy()) { + mapped_addr = StunAttribute::CreateAddress(STUN_ATTR_MAPPED_ADDRESS); + } else { + mapped_addr = StunAttribute::CreateXorAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + } + mapped_addr->SetAddress(remote_addr); + response->AddAttribute(mapped_addr); +} + +} // namespace cricket diff --git a/webrtc/p2p/base/stunserver.h b/webrtc/p2p/base/stunserver.h new file mode 100644 index 0000000000..a7eeab1544 --- /dev/null +++ b/webrtc/p2p/base/stunserver.h @@ -0,0 +1,66 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_STUNSERVER_H_ +#define WEBRTC_P2P_BASE_STUNSERVER_H_ + +#include "webrtc/p2p/base/stun.h" +#include "webrtc/base/asyncudpsocket.h" +#include "webrtc/base/scoped_ptr.h" + +namespace cricket { + +const int STUN_SERVER_PORT = 3478; + +class StunServer : public sigslot::has_slots<> { + public: + // Creates a STUN server, which will listen on the given socket. + explicit StunServer(rtc::AsyncUDPSocket* socket); + // Removes the STUN server from the socket and deletes the socket. + ~StunServer(); + + protected: + // Slot for AsyncSocket.PacketRead: + void OnPacket( + rtc::AsyncPacketSocket* socket, const char* buf, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time); + + // Handlers for the different types of STUN/TURN requests: + virtual void OnBindingRequest(StunMessage* msg, + const rtc::SocketAddress& addr); + void OnAllocateRequest(StunMessage* msg, + const rtc::SocketAddress& addr); + void OnSharedSecretRequest(StunMessage* msg, + const rtc::SocketAddress& addr); + void OnSendRequest(StunMessage* msg, + const rtc::SocketAddress& addr); + + // Sends an error response to the given message back to the user. + void SendErrorResponse( + const StunMessage& msg, const rtc::SocketAddress& addr, + int error_code, const char* error_desc); + + // Sends the given message to the appropriate destination. + void SendResponse(const StunMessage& msg, + const rtc::SocketAddress& addr); + + // A helper method to compose a STUN binding response. + void GetStunBindReqponse(StunMessage* request, + const rtc::SocketAddress& remote_addr, + StunMessage* response) const; + + private: + rtc::scoped_ptr<rtc::AsyncUDPSocket> socket_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_STUNSERVER_H_ diff --git a/webrtc/p2p/base/stunserver_unittest.cc b/webrtc/p2p/base/stunserver_unittest.cc new file mode 100644 index 0000000000..d405979064 --- /dev/null +++ b/webrtc/p2p/base/stunserver_unittest.cc @@ -0,0 +1,112 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <string> + +#include "webrtc/p2p/base/stunserver.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/physicalsocketserver.h" +#include "webrtc/base/testclient.h" +#include "webrtc/base/thread.h" +#include "webrtc/base/virtualsocketserver.h" + +using namespace cricket; + +static const rtc::SocketAddress server_addr("99.99.99.1", 3478); +static const rtc::SocketAddress client_addr("1.2.3.4", 1234); + +class StunServerTest : public testing::Test { + public: + StunServerTest() + : pss_(new rtc::PhysicalSocketServer), + ss_(new rtc::VirtualSocketServer(pss_.get())), + worker_(ss_.get()) { + } + virtual void SetUp() { + server_.reset(new StunServer( + rtc::AsyncUDPSocket::Create(ss_.get(), server_addr))); + client_.reset(new rtc::TestClient( + rtc::AsyncUDPSocket::Create(ss_.get(), client_addr))); + + worker_.Start(); + } + void Send(const StunMessage& msg) { + rtc::ByteBuffer buf; + msg.Write(&buf); + Send(buf.Data(), static_cast<int>(buf.Length())); + } + void Send(const char* buf, int len) { + client_->SendTo(buf, len, server_addr); + } + bool ReceiveFails() { + return(client_->CheckNoPacket()); + } + StunMessage* Receive() { + StunMessage* msg = NULL; + rtc::TestClient::Packet* packet = + client_->NextPacket(rtc::TestClient::kTimeoutMs); + if (packet) { + rtc::ByteBuffer buf(packet->buf, packet->size); + msg = new StunMessage(); + msg->Read(&buf); + delete packet; + } + return msg; + } + private: + rtc::scoped_ptr<rtc::PhysicalSocketServer> pss_; + rtc::scoped_ptr<rtc::VirtualSocketServer> ss_; + rtc::Thread worker_; + rtc::scoped_ptr<StunServer> server_; + rtc::scoped_ptr<rtc::TestClient> client_; +}; + +// Disable for TSan v2, see +// https://code.google.com/p/webrtc/issues/detail?id=2517 for details. +#if !defined(THREAD_SANITIZER) + +TEST_F(StunServerTest, TestGood) { + StunMessage req; + std::string transaction_id = "0123456789ab"; + req.SetType(STUN_BINDING_REQUEST); + req.SetTransactionID(transaction_id); + Send(req); + + StunMessage* msg = Receive(); + ASSERT_TRUE(msg != NULL); + EXPECT_EQ(STUN_BINDING_RESPONSE, msg->type()); + EXPECT_EQ(req.transaction_id(), msg->transaction_id()); + + const StunAddressAttribute* mapped_addr = + msg->GetAddress(STUN_ATTR_MAPPED_ADDRESS); + EXPECT_TRUE(mapped_addr != NULL); + EXPECT_EQ(1, mapped_addr->family()); + EXPECT_EQ(client_addr.port(), mapped_addr->port()); + if (mapped_addr->ipaddr() != client_addr.ipaddr()) { + LOG(LS_WARNING) << "Warning: mapped IP (" + << mapped_addr->ipaddr() + << ") != local IP (" << client_addr.ipaddr() + << ")"; + } + + delete msg; +} + +#endif // if !defined(THREAD_SANITIZER) + +TEST_F(StunServerTest, TestBad) { + const char* bad = "this is a completely nonsensical message whose only " + "purpose is to make the parser go 'ack'. it doesn't " + "look anything like a normal stun message"; + Send(bad, static_cast<int>(strlen(bad))); + + ASSERT_TRUE(ReceiveFails()); +} diff --git a/webrtc/p2p/base/tcpport.cc b/webrtc/p2p/base/tcpport.cc new file mode 100644 index 0000000000..2590d0aca8 --- /dev/null +++ b/webrtc/p2p/base/tcpport.cc @@ -0,0 +1,507 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +/* + * This is a diagram of how TCP reconnect works for the active side. The + * passive side just waits for an incoming connection. + * + * - Connected: Indicate whether the TCP socket is connected. + * + * - Writable: Whether the stun binding is completed. Sending a data packet + * before stun binding completed will trigger IPC socket layer to shutdown + * the connection. + * + * - PendingTCP: |connection_pending_| indicates whether there is an + * outstanding TCP connection in progress. + * + * - PretendWri: Tracked by |pretending_to_be_writable_|. Marking connection as + * WRITE_TIMEOUT will cause the connection be deleted. Instead, we're + * "pretending" we're still writable for a period of time such that reconnect + * could work. + * + * Data could only be sent in state 3. Sening data during state 2 & 6 will get + * EWOULDBLOCK, 4 & 5 EPIPE. + * + * OS Timeout 7 -------------+ + * +----------------------->|Connected: N | + * | |Writable: N | Timeout + * | Timeout |Connection is |<----------------+ + * | +------------------->|Dead | | + * | | +--------------+ | + * | | ^ | + * | | OnClose | | + * | | +-----------------------+ | | + * | | | | |Timeout | + * | | v | | | + * | 4 +----------+ 5 -----+--+--+ 6 -----+-----+ + * | |Connected: N|Send() or |Connected: N| |Connected: Y| + * | |Writable: Y|Ping() |Writable: Y|OnConnect |Writable: Y| + * | |PendingTCP:N+--------> |PendingTCP:Y+---------> |PendingTCP:N| + * | |PretendWri:Y| |PretendWri:Y| |PretendWri:Y| + * | +-----+------+ +------------+ +---+--+-----+ + * | ^ ^ | | + * | | | OnClose | | + * | | +----------------------------------------------+ | + * | | | + * | | Stun Binding Completed | + * | | | + * | | OnClose | + * | +------------------------------------------------+ | + * | | v + * 1 -----------+ 2 -----------+Stun 3 -----------+ + * |Connected: N| |Connected: Y|Binding |Connected: Y| + * |Writable: N|OnConnect |Writable: N|Completed |Writable: Y| + * |PendingTCP:Y+---------> |PendingTCP:N+--------> |PendingTCP:N| + * |PretendWri:N| |PretendWri:N| |PretendWri:N| + * +------------+ +------------+ +------------+ + * + */ + +#include "webrtc/p2p/base/tcpport.h" + +#include "webrtc/p2p/base/common.h" +#include "webrtc/base/common.h" +#include "webrtc/base/logging.h" + +namespace cricket { + +TCPPort::TCPPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password, + bool allow_listen) + : Port(thread, + LOCAL_PORT_TYPE, + factory, + network, + ip, + min_port, + max_port, + username, + password), + incoming_only_(false), + allow_listen_(allow_listen), + socket_(NULL), + error_(0) { + // TODO(mallinath) - Set preference value as per RFC 6544. + // http://b/issue?id=7141794 +} + +bool TCPPort::Init() { + if (allow_listen_) { + // Treat failure to create or bind a TCP socket as fatal. This + // should never happen. + socket_ = socket_factory()->CreateServerTcpSocket( + rtc::SocketAddress(ip(), 0), min_port(), max_port(), + false /* ssl */); + if (!socket_) { + LOG_J(LS_ERROR, this) << "TCP socket creation failed."; + return false; + } + socket_->SignalNewConnection.connect(this, &TCPPort::OnNewConnection); + socket_->SignalAddressReady.connect(this, &TCPPort::OnAddressReady); + } + return true; +} + +TCPPort::~TCPPort() { + delete socket_; + std::list<Incoming>::iterator it; + for (it = incoming_.begin(); it != incoming_.end(); ++it) + delete it->socket; + incoming_.clear(); +} + +Connection* TCPPort::CreateConnection(const Candidate& address, + CandidateOrigin origin) { + // We only support TCP protocols + if ((address.protocol() != TCP_PROTOCOL_NAME) && + (address.protocol() != SSLTCP_PROTOCOL_NAME)) { + return NULL; + } + + if (address.tcptype() == TCPTYPE_ACTIVE_STR || + (address.tcptype().empty() && address.address().port() == 0)) { + // It's active only candidate, we should not try to create connections + // for these candidates. + return NULL; + } + + // We can't accept TCP connections incoming on other ports + if (origin == ORIGIN_OTHER_PORT) + return NULL; + + // Check if we are allowed to make outgoing TCP connections + if (incoming_only_ && (origin == ORIGIN_MESSAGE)) + return NULL; + + // We don't know how to act as an ssl server yet + if ((address.protocol() == SSLTCP_PROTOCOL_NAME) && + (origin == ORIGIN_THIS_PORT)) { + return NULL; + } + + if (!IsCompatibleAddress(address.address())) { + return NULL; + } + + TCPConnection* conn = NULL; + if (rtc::AsyncPacketSocket* socket = + GetIncoming(address.address(), true)) { + socket->SignalReadPacket.disconnect(this); + conn = new TCPConnection(this, address, socket); + } else { + conn = new TCPConnection(this, address); + } + AddConnection(conn); + return conn; +} + +void TCPPort::PrepareAddress() { + if (socket_) { + // If socket isn't bound yet the address will be added in + // OnAddressReady(). Socket may be in the CLOSED state if Listen() + // failed, we still want to add the socket address. + LOG(LS_VERBOSE) << "Preparing TCP address, current state: " + << socket_->GetState(); + if (socket_->GetState() == rtc::AsyncPacketSocket::STATE_BOUND || + socket_->GetState() == rtc::AsyncPacketSocket::STATE_CLOSED) + AddAddress(socket_->GetLocalAddress(), socket_->GetLocalAddress(), + rtc::SocketAddress(), TCP_PROTOCOL_NAME, "", + TCPTYPE_PASSIVE_STR, LOCAL_PORT_TYPE, + ICE_TYPE_PREFERENCE_HOST_TCP, 0, true); + } else { + LOG_J(LS_INFO, this) << "Not listening due to firewall restrictions."; + // Note: We still add the address, since otherwise the remote side won't + // recognize our incoming TCP connections. + AddAddress(rtc::SocketAddress(ip(), 0), rtc::SocketAddress(ip(), 0), + rtc::SocketAddress(), TCP_PROTOCOL_NAME, "", TCPTYPE_ACTIVE_STR, + LOCAL_PORT_TYPE, ICE_TYPE_PREFERENCE_HOST_TCP, 0, true); + } +} + +int TCPPort::SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options, + bool payload) { + rtc::AsyncPacketSocket * socket = NULL; + TCPConnection* conn = static_cast<TCPConnection*>(GetConnection(addr)); + + // For Connection, this is the code path used by Ping() to establish + // WRITABLE. It has to send through the socket directly as TCPConnection::Send + // checks writability. + if (conn) { + if (!conn->connected()) { + conn->MaybeReconnect(); + return SOCKET_ERROR; + } + socket = conn->socket(); + } else { + socket = GetIncoming(addr); + } + if (!socket) { + LOG_J(LS_ERROR, this) << "Attempted to send to an unknown destination, " + << addr.ToSensitiveString(); + return SOCKET_ERROR; // TODO(tbd): Set error_ + } + + int sent = socket->Send(data, size, options); + if (sent < 0) { + error_ = socket->GetError(); + // Error from this code path for a Connection (instead of from a bare + // socket) will not trigger reconnecting. In theory, this shouldn't matter + // as OnClose should always be called and set connected to false. + LOG_J(LS_ERROR, this) << "TCP send of " << size + << " bytes failed with error " << error_; + } + return sent; +} + +int TCPPort::GetOption(rtc::Socket::Option opt, int* value) { + if (socket_) { + return socket_->GetOption(opt, value); + } else { + return SOCKET_ERROR; + } +} + +int TCPPort::SetOption(rtc::Socket::Option opt, int value) { + if (socket_) { + return socket_->SetOption(opt, value); + } else { + return SOCKET_ERROR; + } +} + +int TCPPort::GetError() { + return error_; +} + +void TCPPort::OnNewConnection(rtc::AsyncPacketSocket* socket, + rtc::AsyncPacketSocket* new_socket) { + ASSERT(socket == socket_); + + Incoming incoming; + incoming.addr = new_socket->GetRemoteAddress(); + incoming.socket = new_socket; + incoming.socket->SignalReadPacket.connect(this, &TCPPort::OnReadPacket); + incoming.socket->SignalReadyToSend.connect(this, &TCPPort::OnReadyToSend); + + LOG_J(LS_VERBOSE, this) << "Accepted connection from " + << incoming.addr.ToSensitiveString(); + incoming_.push_back(incoming); +} + +rtc::AsyncPacketSocket* TCPPort::GetIncoming( + const rtc::SocketAddress& addr, bool remove) { + rtc::AsyncPacketSocket* socket = NULL; + for (std::list<Incoming>::iterator it = incoming_.begin(); + it != incoming_.end(); ++it) { + if (it->addr == addr) { + socket = it->socket; + if (remove) + incoming_.erase(it); + break; + } + } + return socket; +} + +void TCPPort::OnReadPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + Port::OnReadPacket(data, size, remote_addr, PROTO_TCP); +} + +void TCPPort::OnReadyToSend(rtc::AsyncPacketSocket* socket) { + Port::OnReadyToSend(); +} + +void TCPPort::OnAddressReady(rtc::AsyncPacketSocket* socket, + const rtc::SocketAddress& address) { + AddAddress(address, address, rtc::SocketAddress(), TCP_PROTOCOL_NAME, "", + TCPTYPE_PASSIVE_STR, LOCAL_PORT_TYPE, ICE_TYPE_PREFERENCE_HOST_TCP, + 0, true); +} + +TCPConnection::TCPConnection(TCPPort* port, + const Candidate& candidate, + rtc::AsyncPacketSocket* socket) + : Connection(port, 0, candidate), + socket_(socket), + error_(0), + outgoing_(socket == NULL), + connection_pending_(false), + pretending_to_be_writable_(false), + reconnection_timeout_(cricket::CONNECTION_WRITE_CONNECT_TIMEOUT) { + if (outgoing_) { + CreateOutgoingTcpSocket(); + } else { + // Incoming connections should match the network address. + LOG_J(LS_VERBOSE, this) + << "socket ipaddr: " << socket_->GetLocalAddress().ToString() + << ",port() ip:" << port->ip().ToString(); + ASSERT(socket_->GetLocalAddress().ipaddr() == port->ip()); + ConnectSocketSignals(socket); + } +} + +TCPConnection::~TCPConnection() { +} + +int TCPConnection::Send(const void* data, size_t size, + const rtc::PacketOptions& options) { + if (!socket_) { + error_ = ENOTCONN; + return SOCKET_ERROR; + } + + // Sending after OnClose on active side will trigger a reconnect for a + // outgoing connection. Note that the write state is still WRITABLE as we want + // to spend a few seconds attempting a reconnect before saying we're + // unwritable. + if (!connected()) { + MaybeReconnect(); + return SOCKET_ERROR; + } + + // Note that this is important to put this after the previous check to give + // the connection a chance to reconnect. + if (pretending_to_be_writable_ || write_state() != STATE_WRITABLE) { + // TODO: Should STATE_WRITE_TIMEOUT return a non-blocking error? + error_ = EWOULDBLOCK; + return SOCKET_ERROR; + } + sent_packets_total_++; + int sent = socket_->Send(data, size, options); + if (sent < 0) { + sent_packets_discarded_++; + error_ = socket_->GetError(); + } else { + send_rate_tracker_.AddSamples(sent); + } + return sent; +} + +int TCPConnection::GetError() { + return error_; +} + +void TCPConnection::OnConnectionRequestResponse(ConnectionRequest* req, + StunMessage* response) { + // Process the STUN response before we inform upper layer ready to send. + Connection::OnConnectionRequestResponse(req, response); + + // If we're in the state of pretending to be writeable, we should inform the + // upper layer it's ready to send again as previous EWOULDLBLOCK from socket + // would have stopped the outgoing stream. + if (pretending_to_be_writable_) { + Connection::OnReadyToSend(); + } + pretending_to_be_writable_ = false; + ASSERT(write_state() == STATE_WRITABLE); +} + +void TCPConnection::OnConnect(rtc::AsyncPacketSocket* socket) { + ASSERT(socket == socket_); + // Do not use this connection if the socket bound to a different address than + // the one we asked for. This is seen in Chrome, where TCP sockets cannot be + // given a binding address, and the platform is expected to pick the + // correct local address. + const rtc::IPAddress& socket_ip = socket->GetLocalAddress().ipaddr(); + if (socket_ip == port()->ip() || IPIsAny(port()->ip())) { + if (socket_ip == port()->ip()) { + LOG_J(LS_VERBOSE, this) << "Connection established to " + << socket->GetRemoteAddress().ToSensitiveString(); + } else { + LOG(LS_WARNING) << "Socket is bound to a different address:" + << socket->GetLocalAddress().ipaddr().ToString() + << ", rather then the local port:" + << port()->ip().ToString() + << ". Still allowing it since it's any address" + << ", possibly caused by multi-routes being disabled."; + } + set_connected(true); + connection_pending_ = false; + } else { + LOG_J(LS_WARNING, this) << "Dropping connection as TCP socket bound to IP " + << socket_ip.ToSensitiveString() + << ", different from the local candidate IP " + << port()->ip().ToSensitiveString(); + OnClose(socket, 0); + } +} + +void TCPConnection::OnClose(rtc::AsyncPacketSocket* socket, int error) { + ASSERT(socket == socket_); + LOG_J(LS_INFO, this) << "Connection closed with error " << error; + + // Guard against the condition where IPC socket will call OnClose for every + // packet it can't send. + if (connected()) { + set_connected(false); + + // Prevent the connection from being destroyed by redundant SignalClose + // events. + pretending_to_be_writable_ = true; + + // We don't attempt reconnect right here. This is to avoid a case where the + // shutdown is intentional and reconnect is not necessary. We only reconnect + // when the connection is used to Send() or Ping(). + port()->thread()->PostDelayed(reconnection_timeout(), this, + MSG_TCPCONNECTION_DELAYED_ONCLOSE); + } else if (!pretending_to_be_writable_) { + // OnClose could be called when the underneath socket times out during the + // initial connect() (i.e. |pretending_to_be_writable_| is false) . We have + // to manually destroy here as this connection, as never connected, will not + // be scheduled for ping to trigger destroy. + Destroy(); + } +} + +void TCPConnection::OnMessage(rtc::Message* pmsg) { + switch (pmsg->message_id) { + case MSG_TCPCONNECTION_DELAYED_ONCLOSE: + // If this connection can't become connected and writable again in 5 + // seconds, it's time to tear this down. This is the case for the original + // TCP connection on passive side during a reconnect. + if (pretending_to_be_writable_) { + Destroy(); + } + break; + default: + Connection::OnMessage(pmsg); + } +} + +void TCPConnection::MaybeReconnect() { + // Only reconnect for an outgoing TCPConnection when OnClose was signaled and + // no outstanding reconnect is pending. + if (connected() || connection_pending_ || !outgoing_) { + return; + } + + LOG_J(LS_INFO, this) << "TCP Connection with remote is closed, " + << "trying to reconnect"; + + CreateOutgoingTcpSocket(); + error_ = EPIPE; +} + +void TCPConnection::OnReadPacket( + rtc::AsyncPacketSocket* socket, const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + ASSERT(socket == socket_); + Connection::OnReadPacket(data, size, packet_time); +} + +void TCPConnection::OnReadyToSend(rtc::AsyncPacketSocket* socket) { + ASSERT(socket == socket_); + Connection::OnReadyToSend(); +} + +void TCPConnection::CreateOutgoingTcpSocket() { + ASSERT(outgoing_); + // TODO(guoweis): Handle failures here (unlikely since TCP). + int opts = (remote_candidate().protocol() == SSLTCP_PROTOCOL_NAME) + ? rtc::PacketSocketFactory::OPT_SSLTCP + : 0; + socket_.reset(port()->socket_factory()->CreateClientTcpSocket( + rtc::SocketAddress(port()->ip(), 0), remote_candidate().address(), + port()->proxy(), port()->user_agent(), opts)); + if (socket_) { + LOG_J(LS_VERBOSE, this) + << "Connecting from " << socket_->GetLocalAddress().ToSensitiveString() + << " to " << remote_candidate().address().ToSensitiveString(); + set_connected(false); + connection_pending_ = true; + ConnectSocketSignals(socket_.get()); + } else { + LOG_J(LS_WARNING, this) << "Failed to create connection to " + << remote_candidate().address().ToSensitiveString(); + } +} + +void TCPConnection::ConnectSocketSignals(rtc::AsyncPacketSocket* socket) { + if (outgoing_) { + socket->SignalConnect.connect(this, &TCPConnection::OnConnect); + } + socket->SignalReadPacket.connect(this, &TCPConnection::OnReadPacket); + socket->SignalReadyToSend.connect(this, &TCPConnection::OnReadyToSend); + socket->SignalClose.connect(this, &TCPConnection::OnClose); +} + +} // namespace cricket diff --git a/webrtc/p2p/base/tcpport.h b/webrtc/p2p/base/tcpport.h new file mode 100644 index 0000000000..a64c5eeab9 --- /dev/null +++ b/webrtc/p2p/base/tcpport.h @@ -0,0 +1,182 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TCPPORT_H_ +#define WEBRTC_P2P_BASE_TCPPORT_H_ + +#include <list> +#include <string> +#include "webrtc/p2p/base/port.h" +#include "webrtc/base/asyncpacketsocket.h" + +namespace cricket { + +class TCPConnection; + +// Communicates using a local TCP port. +// +// This class is designed to allow subclasses to take advantage of the +// connection management provided by this class. A subclass should take of all +// packet sending and preparation, but when a packet is received, it should +// call this TCPPort::OnReadPacket (3 arg) to dispatch to a connection. +class TCPPort : public Port { + public: + static TCPPort* Create(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password, + bool allow_listen) { + TCPPort* port = new TCPPort(thread, factory, network, ip, min_port, + max_port, username, password, allow_listen); + if (!port->Init()) { + delete port; + port = NULL; + } + return port; + } + virtual ~TCPPort(); + + virtual Connection* CreateConnection(const Candidate& address, + CandidateOrigin origin); + + virtual void PrepareAddress(); + + virtual int GetOption(rtc::Socket::Option opt, int* value); + virtual int SetOption(rtc::Socket::Option opt, int value); + virtual int GetError(); + + protected: + TCPPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password, + bool allow_listen); + bool Init(); + + // Handles sending using the local TCP socket. + virtual int SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options, + bool payload); + + // Accepts incoming TCP connection. + void OnNewConnection(rtc::AsyncPacketSocket* socket, + rtc::AsyncPacketSocket* new_socket); + + private: + struct Incoming { + rtc::SocketAddress addr; + rtc::AsyncPacketSocket* socket; + }; + + rtc::AsyncPacketSocket* GetIncoming( + const rtc::SocketAddress& addr, bool remove = false); + + // Receives packet signal from the local TCP Socket. + void OnReadPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time); + + void OnReadyToSend(rtc::AsyncPacketSocket* socket); + + void OnAddressReady(rtc::AsyncPacketSocket* socket, + const rtc::SocketAddress& address); + + // TODO: Is this still needed? + bool incoming_only_; + bool allow_listen_; + rtc::AsyncPacketSocket* socket_; + int error_; + std::list<Incoming> incoming_; + + friend class TCPConnection; +}; + +class TCPConnection : public Connection { + public: + // Connection is outgoing unless socket is specified + TCPConnection(TCPPort* port, const Candidate& candidate, + rtc::AsyncPacketSocket* socket = 0); + virtual ~TCPConnection(); + + virtual int Send(const void* data, size_t size, + const rtc::PacketOptions& options); + virtual int GetError(); + + rtc::AsyncPacketSocket* socket() { return socket_.get(); } + + void OnMessage(rtc::Message* pmsg); + + // Allow test cases to overwrite the default timeout period. + int reconnection_timeout() const { return reconnection_timeout_; } + void set_reconnection_timeout(int timeout_in_ms) { + reconnection_timeout_ = timeout_in_ms; + } + + protected: + enum { + MSG_TCPCONNECTION_DELAYED_ONCLOSE = Connection::MSG_FIRST_AVAILABLE, + }; + + // Set waiting_for_stun_binding_complete_ to false to allow data packets in + // addition to what Port::OnConnectionRequestResponse does. + virtual void OnConnectionRequestResponse(ConnectionRequest* req, + StunMessage* response); + + private: + // Helper function to handle the case when Ping or Send fails with error + // related to socket close. + void MaybeReconnect(); + + void CreateOutgoingTcpSocket(); + + void ConnectSocketSignals(rtc::AsyncPacketSocket* socket); + + void OnConnect(rtc::AsyncPacketSocket* socket); + void OnClose(rtc::AsyncPacketSocket* socket, int error); + void OnReadPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time); + void OnReadyToSend(rtc::AsyncPacketSocket* socket); + + rtc::scoped_ptr<rtc::AsyncPacketSocket> socket_; + int error_; + bool outgoing_; + + // Guard against multiple outgoing tcp connection during a reconnect. + bool connection_pending_; + + // Guard against data packets sent when we reconnect a TCP connection. During + // reconnecting, when a new tcp connection has being made, we can't send data + // packets out until the STUN binding is completed (i.e. the write state is + // set to WRITABLE again by Connection::OnConnectionRequestResponse). IPC + // socket, when receiving data packets before that, will trigger OnError which + // will terminate the newly created connection. + bool pretending_to_be_writable_; + + // Allow test case to overwrite the default timeout period. + int reconnection_timeout_; + + friend class TCPPort; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TCPPORT_H_ diff --git a/webrtc/p2p/base/testrelayserver.h b/webrtc/p2p/base/testrelayserver.h new file mode 100644 index 0000000000..87cb9e5dc3 --- /dev/null +++ b/webrtc/p2p/base/testrelayserver.h @@ -0,0 +1,101 @@ +/* + * Copyright 2008 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TESTRELAYSERVER_H_ +#define WEBRTC_P2P_BASE_TESTRELAYSERVER_H_ + +#include "webrtc/p2p/base/relayserver.h" +#include "webrtc/base/asynctcpsocket.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/sigslot.h" +#include "webrtc/base/socketadapters.h" +#include "webrtc/base/thread.h" + +namespace cricket { + +// A test relay server. Useful for unit tests. +class TestRelayServer : public sigslot::has_slots<> { + public: + TestRelayServer(rtc::Thread* thread, + const rtc::SocketAddress& udp_int_addr, + const rtc::SocketAddress& udp_ext_addr, + const rtc::SocketAddress& tcp_int_addr, + const rtc::SocketAddress& tcp_ext_addr, + const rtc::SocketAddress& ssl_int_addr, + const rtc::SocketAddress& ssl_ext_addr) + : server_(thread) { + server_.AddInternalSocket(rtc::AsyncUDPSocket::Create( + thread->socketserver(), udp_int_addr)); + server_.AddExternalSocket(rtc::AsyncUDPSocket::Create( + thread->socketserver(), udp_ext_addr)); + + tcp_int_socket_.reset(CreateListenSocket(thread, tcp_int_addr)); + tcp_ext_socket_.reset(CreateListenSocket(thread, tcp_ext_addr)); + ssl_int_socket_.reset(CreateListenSocket(thread, ssl_int_addr)); + ssl_ext_socket_.reset(CreateListenSocket(thread, ssl_ext_addr)); + } + int GetConnectionCount() const { + return server_.GetConnectionCount(); + } + rtc::SocketAddressPair GetConnection(int connection) const { + return server_.GetConnection(connection); + } + bool HasConnection(const rtc::SocketAddress& address) const { + return server_.HasConnection(address); + } + + private: + rtc::AsyncSocket* CreateListenSocket(rtc::Thread* thread, + const rtc::SocketAddress& addr) { + rtc::AsyncSocket* socket = + thread->socketserver()->CreateAsyncSocket(addr.family(), SOCK_STREAM); + socket->Bind(addr); + socket->Listen(5); + socket->SignalReadEvent.connect(this, &TestRelayServer::OnAccept); + return socket; + } + void OnAccept(rtc::AsyncSocket* socket) { + bool external = (socket == tcp_ext_socket_.get() || + socket == ssl_ext_socket_.get()); + bool ssl = (socket == ssl_int_socket_.get() || + socket == ssl_ext_socket_.get()); + rtc::AsyncSocket* raw_socket = socket->Accept(NULL); + if (raw_socket) { + rtc::AsyncTCPSocket* packet_socket = new rtc::AsyncTCPSocket( + (!ssl) ? raw_socket : + new rtc::AsyncSSLServerSocket(raw_socket), false); + if (!external) { + packet_socket->SignalClose.connect(this, + &TestRelayServer::OnInternalClose); + server_.AddInternalSocket(packet_socket); + } else { + packet_socket->SignalClose.connect(this, + &TestRelayServer::OnExternalClose); + server_.AddExternalSocket(packet_socket); + } + } + } + void OnInternalClose(rtc::AsyncPacketSocket* socket, int error) { + server_.RemoveInternalSocket(socket); + } + void OnExternalClose(rtc::AsyncPacketSocket* socket, int error) { + server_.RemoveExternalSocket(socket); + } + private: + cricket::RelayServer server_; + rtc::scoped_ptr<rtc::AsyncSocket> tcp_int_socket_; + rtc::scoped_ptr<rtc::AsyncSocket> tcp_ext_socket_; + rtc::scoped_ptr<rtc::AsyncSocket> ssl_int_socket_; + rtc::scoped_ptr<rtc::AsyncSocket> ssl_ext_socket_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TESTRELAYSERVER_H_ diff --git a/webrtc/p2p/base/teststunserver.h b/webrtc/p2p/base/teststunserver.h new file mode 100644 index 0000000000..1911a0b739 --- /dev/null +++ b/webrtc/p2p/base/teststunserver.h @@ -0,0 +1,58 @@ +/* + * Copyright 2008 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TESTSTUNSERVER_H_ +#define WEBRTC_P2P_BASE_TESTSTUNSERVER_H_ + +#include "webrtc/p2p/base/stunserver.h" +#include "webrtc/base/socketaddress.h" +#include "webrtc/base/thread.h" + +namespace cricket { + +// A test STUN server. Useful for unit tests. +class TestStunServer : StunServer { + public: + static TestStunServer* Create(rtc::Thread* thread, + const rtc::SocketAddress& addr) { + rtc::AsyncSocket* socket = + thread->socketserver()->CreateAsyncSocket(addr.family(), SOCK_DGRAM); + rtc::AsyncUDPSocket* udp_socket = + rtc::AsyncUDPSocket::Create(socket, addr); + + return new TestStunServer(udp_socket); + } + + // Set a fake STUN address to return to the client. + void set_fake_stun_addr(const rtc::SocketAddress& addr) { + fake_stun_addr_ = addr; + } + + private: + explicit TestStunServer(rtc::AsyncUDPSocket* socket) : StunServer(socket) {} + + void OnBindingRequest(StunMessage* msg, + const rtc::SocketAddress& remote_addr) override { + if (fake_stun_addr_.IsNil()) { + StunServer::OnBindingRequest(msg, remote_addr); + } else { + StunMessage response; + GetStunBindReqponse(msg, fake_stun_addr_, &response); + SendResponse(response, remote_addr); + } + } + + private: + rtc::SocketAddress fake_stun_addr_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TESTSTUNSERVER_H_ diff --git a/webrtc/p2p/base/testturnserver.h b/webrtc/p2p/base/testturnserver.h new file mode 100644 index 0000000000..7be35e5340 --- /dev/null +++ b/webrtc/p2p/base/testturnserver.h @@ -0,0 +1,116 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TESTTURNSERVER_H_ +#define WEBRTC_P2P_BASE_TESTTURNSERVER_H_ + +#include <string> +#include <vector> + +#include "webrtc/p2p/base/basicpacketsocketfactory.h" +#include "webrtc/p2p/base/stun.h" +#include "webrtc/p2p/base/turnserver.h" +#include "webrtc/base/asyncudpsocket.h" +#include "webrtc/base/thread.h" + +namespace cricket { + +static const char kTestRealm[] = "example.org"; +static const char kTestSoftware[] = "TestTurnServer"; + +class TestTurnRedirector : public TurnRedirectInterface { + public: + explicit TestTurnRedirector(const std::vector<rtc::SocketAddress>& addresses) + : alternate_server_addresses_(addresses), + iter_(alternate_server_addresses_.begin()) { + } + + virtual bool ShouldRedirect(const rtc::SocketAddress&, + rtc::SocketAddress* out) { + if (!out || iter_ == alternate_server_addresses_.end()) { + return false; + } + *out = *iter_++; + return true; + } + + private: + const std::vector<rtc::SocketAddress>& alternate_server_addresses_; + std::vector<rtc::SocketAddress>::const_iterator iter_; +}; + +class TestTurnServer : public TurnAuthInterface { + public: + TestTurnServer(rtc::Thread* thread, + const rtc::SocketAddress& udp_int_addr, + const rtc::SocketAddress& udp_ext_addr) + : server_(thread) { + AddInternalSocket(udp_int_addr, cricket::PROTO_UDP); + server_.SetExternalSocketFactory(new rtc::BasicPacketSocketFactory(), + udp_ext_addr); + server_.set_realm(kTestRealm); + server_.set_software(kTestSoftware); + server_.set_auth_hook(this); + } + + void set_enable_otu_nonce(bool enable) { + server_.set_enable_otu_nonce(enable); + } + + TurnServer* server() { return &server_; } + + void set_redirect_hook(TurnRedirectInterface* redirect_hook) { + server_.set_redirect_hook(redirect_hook); + } + + void AddInternalSocket(const rtc::SocketAddress& int_addr, + ProtocolType proto) { + rtc::Thread* thread = rtc::Thread::Current(); + if (proto == cricket::PROTO_UDP) { + server_.AddInternalSocket(rtc::AsyncUDPSocket::Create( + thread->socketserver(), int_addr), proto); + } else if (proto == cricket::PROTO_TCP) { + // For TCP we need to create a server socket which can listen for incoming + // new connections. + rtc::AsyncSocket* socket = + thread->socketserver()->CreateAsyncSocket(SOCK_STREAM); + socket->Bind(int_addr); + socket->Listen(5); + server_.AddInternalServerSocket(socket, proto); + } + } + + // Finds the first allocation in the server allocation map with a source + // ip and port matching the socket address provided. + TurnServerAllocation* FindAllocation(const rtc::SocketAddress& src) { + const TurnServer::AllocationMap& map = server_.allocations(); + for (TurnServer::AllocationMap::const_iterator it = map.begin(); + it != map.end(); ++it) { + if (src == it->first.src()) { + return it->second; + } + } + return NULL; + } + + private: + // For this test server, succeed if the password is the same as the username. + // Obviously, do not use this in a production environment. + virtual bool GetKey(const std::string& username, const std::string& realm, + std::string* key) { + return ComputeStunCredentialHash(username, realm, username, key); + } + + TurnServer server_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TESTTURNSERVER_H_ diff --git a/webrtc/p2p/base/transport.cc b/webrtc/p2p/base/transport.cc new file mode 100644 index 0000000000..2328e4587c --- /dev/null +++ b/webrtc/p2p/base/transport.cc @@ -0,0 +1,391 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <utility> // for std::pair + +#include "webrtc/p2p/base/transport.h" + +#include "webrtc/p2p/base/candidate.h" +#include "webrtc/p2p/base/constants.h" +#include "webrtc/p2p/base/port.h" +#include "webrtc/p2p/base/transportchannelimpl.h" +#include "webrtc/base/bind.h" +#include "webrtc/base/checks.h" +#include "webrtc/base/logging.h" + +namespace cricket { + +static bool VerifyIceParams(const TransportDescription& desc) { + // For legacy protocols. + if (desc.ice_ufrag.empty() && desc.ice_pwd.empty()) + return true; + + if (desc.ice_ufrag.length() < ICE_UFRAG_MIN_LENGTH || + desc.ice_ufrag.length() > ICE_UFRAG_MAX_LENGTH) { + return false; + } + if (desc.ice_pwd.length() < ICE_PWD_MIN_LENGTH || + desc.ice_pwd.length() > ICE_PWD_MAX_LENGTH) { + return false; + } + return true; +} + +bool BadTransportDescription(const std::string& desc, std::string* err_desc) { + if (err_desc) { + *err_desc = desc; + } + LOG(LS_ERROR) << desc; + return false; +} + +bool IceCredentialsChanged(const std::string& old_ufrag, + const std::string& old_pwd, + const std::string& new_ufrag, + const std::string& new_pwd) { + // TODO(jiayl): The standard (RFC 5245 Section 9.1.1.1) says that ICE should + // restart when both the ufrag and password are changed, but we do restart + // when either ufrag or passwrod is changed to keep compatible with GICE. We + // should clean this up when GICE is no longer used. + return (old_ufrag != new_ufrag) || (old_pwd != new_pwd); +} + +static bool IceCredentialsChanged(const TransportDescription& old_desc, + const TransportDescription& new_desc) { + return IceCredentialsChanged(old_desc.ice_ufrag, old_desc.ice_pwd, + new_desc.ice_ufrag, new_desc.ice_pwd); +} + +Transport::Transport(const std::string& name, PortAllocator* allocator) + : name_(name), allocator_(allocator) {} + +Transport::~Transport() { + RTC_DCHECK(channels_destroyed_); +} + +void Transport::SetIceRole(IceRole role) { + ice_role_ = role; + for (const auto& kv : channels_) { + kv.second->SetIceRole(ice_role_); + } +} + +bool Transport::GetRemoteSSLCertificate(rtc::SSLCertificate** cert) { + if (channels_.empty()) { + return false; + } + + auto iter = channels_.begin(); + return iter->second->GetRemoteSSLCertificate(cert); +} + +void Transport::SetIceConfig(const IceConfig& config) { + ice_config_ = config; + for (const auto& kv : channels_) { + kv.second->SetIceConfig(ice_config_); + } +} + +bool Transport::SetLocalTransportDescription( + const TransportDescription& description, + ContentAction action, + std::string* error_desc) { + bool ret = true; + + if (!VerifyIceParams(description)) { + return BadTransportDescription("Invalid ice-ufrag or ice-pwd length", + error_desc); + } + + if (local_description_ && + IceCredentialsChanged(*local_description_, description)) { + IceRole new_ice_role = + (action == CA_OFFER) ? ICEROLE_CONTROLLING : ICEROLE_CONTROLLED; + + // It must be called before ApplyLocalTransportDescription, which may + // trigger an ICE restart and depends on the new ICE role. + SetIceRole(new_ice_role); + } + + local_description_.reset(new TransportDescription(description)); + + for (const auto& kv : channels_) { + ret &= ApplyLocalTransportDescription(kv.second, error_desc); + } + if (!ret) { + return false; + } + + // If PRANSWER/ANSWER is set, we should decide transport protocol type. + if (action == CA_PRANSWER || action == CA_ANSWER) { + ret &= NegotiateTransportDescription(action, error_desc); + } + if (ret) { + local_description_set_ = true; + ConnectChannels(); + } + + return ret; +} + +bool Transport::SetRemoteTransportDescription( + const TransportDescription& description, + ContentAction action, + std::string* error_desc) { + bool ret = true; + + if (!VerifyIceParams(description)) { + return BadTransportDescription("Invalid ice-ufrag or ice-pwd length", + error_desc); + } + + remote_description_.reset(new TransportDescription(description)); + for (const auto& kv : channels_) { + ret &= ApplyRemoteTransportDescription(kv.second, error_desc); + } + + // If PRANSWER/ANSWER is set, we should decide transport protocol type. + if (action == CA_PRANSWER || action == CA_ANSWER) { + ret = NegotiateTransportDescription(CA_OFFER, error_desc); + } + if (ret) { + remote_description_set_ = true; + } + + return ret; +} + +TransportChannelImpl* Transport::CreateChannel(int component) { + TransportChannelImpl* channel; + + // Create the entry if it does not exist. + bool channel_exists = false; + auto iter = channels_.find(component); + if (iter == channels_.end()) { + channel = CreateTransportChannel(component); + channels_.insert(std::pair<int, TransportChannelImpl*>(component, channel)); + } else { + channel = iter->second; + channel_exists = true; + } + + channels_destroyed_ = false; + + if (channel_exists) { + // If this is an existing channel, we should just return it. + return channel; + } + + // Push down our transport state to the new channel. + channel->SetIceRole(ice_role_); + channel->SetIceTiebreaker(tiebreaker_); + channel->SetIceConfig(ice_config_); + // TODO(ronghuawu): Change CreateChannel to be able to return error since + // below Apply**Description calls can fail. + if (local_description_) + ApplyLocalTransportDescription(channel, nullptr); + if (remote_description_) + ApplyRemoteTransportDescription(channel, nullptr); + if (local_description_ && remote_description_) + ApplyNegotiatedTransportDescription(channel, nullptr); + + if (connect_requested_) { + channel->Connect(); + } + return channel; +} + +TransportChannelImpl* Transport::GetChannel(int component) { + auto iter = channels_.find(component); + return (iter != channels_.end()) ? iter->second : nullptr; +} + +bool Transport::HasChannels() { + return !channels_.empty(); +} + +void Transport::DestroyChannel(int component) { + auto iter = channels_.find(component); + if (iter == channels_.end()) + return; + + TransportChannelImpl* channel = iter->second; + channels_.erase(iter); + DestroyTransportChannel(channel); +} + +void Transport::ConnectChannels() { + if (connect_requested_ || channels_.empty()) + return; + + connect_requested_ = true; + + if (!local_description_) { + // TOOD(mallinath) : TransportDescription(TD) shouldn't be generated here. + // As Transport must know TD is offer or answer and cricket::Transport + // doesn't have the capability to decide it. This should be set by the + // Session. + // Session must generate local TD before remote candidates pushed when + // initiate request initiated by the remote. + LOG(LS_INFO) << "Transport::ConnectChannels: No local description has " + << "been set. Will generate one."; + TransportDescription desc( + std::vector<std::string>(), rtc::CreateRandomString(ICE_UFRAG_LENGTH), + rtc::CreateRandomString(ICE_PWD_LENGTH), ICEMODE_FULL, + CONNECTIONROLE_NONE, nullptr, Candidates()); + SetLocalTransportDescription(desc, CA_OFFER, nullptr); + } + + CallChannels(&TransportChannelImpl::Connect); +} + +void Transport::MaybeStartGathering() { + if (connect_requested_) { + CallChannels(&TransportChannelImpl::MaybeStartGathering); + } +} + +void Transport::DestroyAllChannels() { + for (const auto& kv : channels_) { + DestroyTransportChannel(kv.second); + } + channels_.clear(); + channels_destroyed_ = true; +} + +void Transport::CallChannels(TransportChannelFunc func) { + for (const auto& kv : channels_) { + (kv.second->*func)(); + } +} + +bool Transport::VerifyCandidate(const Candidate& cand, std::string* error) { + // No address zero. + if (cand.address().IsNil() || cand.address().IsAnyIP()) { + *error = "candidate has address of zero"; + return false; + } + + // Disallow all ports below 1024, except for 80 and 443 on public addresses. + int port = cand.address().port(); + if (cand.protocol() == TCP_PROTOCOL_NAME && + (cand.tcptype() == TCPTYPE_ACTIVE_STR || port == 0)) { + // Expected for active-only candidates per + // http://tools.ietf.org/html/rfc6544#section-4.5 so no error. + // Libjingle clients emit port 0, in "active" mode. + return true; + } + if (port < 1024) { + if ((port != 80) && (port != 443)) { + *error = "candidate has port below 1024, but not 80 or 443"; + return false; + } + + if (cand.address().IsPrivateIP()) { + *error = "candidate has port of 80 or 443 with private IP address"; + return false; + } + } + + return true; +} + + +bool Transport::GetStats(TransportStats* stats) { + stats->transport_name = name(); + stats->channel_stats.clear(); + for (auto kv : channels_) { + TransportChannelImpl* channel = kv.second; + TransportChannelStats substats; + substats.component = channel->component(); + channel->GetSrtpCryptoSuite(&substats.srtp_cipher); + channel->GetSslCipherSuite(&substats.ssl_cipher); + if (!channel->GetStats(&substats.connection_infos)) { + return false; + } + stats->channel_stats.push_back(substats); + } + return true; +} + +bool Transport::AddRemoteCandidates(const std::vector<Candidate>& candidates, + std::string* error) { + ASSERT(!channels_destroyed_); + // Verify each candidate before passing down to transport layer. + for (const Candidate& cand : candidates) { + if (!VerifyCandidate(cand, error)) { + return false; + } + if (!HasChannel(cand.component())) { + *error = "Candidate has unknown component: " + cand.ToString() + + " for content: " + name(); + return false; + } + } + + for (const Candidate& candidate : candidates) { + TransportChannelImpl* channel = GetChannel(candidate.component()); + if (channel != nullptr) { + channel->AddRemoteCandidate(candidate); + } + } + return true; +} + +bool Transport::ApplyLocalTransportDescription(TransportChannelImpl* ch, + std::string* error_desc) { + ch->SetIceCredentials(local_description_->ice_ufrag, + local_description_->ice_pwd); + return true; +} + +bool Transport::ApplyRemoteTransportDescription(TransportChannelImpl* ch, + std::string* error_desc) { + ch->SetRemoteIceCredentials(remote_description_->ice_ufrag, + remote_description_->ice_pwd); + return true; +} + +bool Transport::ApplyNegotiatedTransportDescription( + TransportChannelImpl* channel, + std::string* error_desc) { + channel->SetRemoteIceMode(remote_ice_mode_); + return true; +} + +bool Transport::NegotiateTransportDescription(ContentAction local_role, + std::string* error_desc) { + // TODO(ekr@rtfm.com): This is ICE-specific stuff. Refactor into + // P2PTransport. + + // If transport is in ICEROLE_CONTROLLED and remote end point supports only + // ice_lite, this local end point should take CONTROLLING role. + if (ice_role_ == ICEROLE_CONTROLLED && + remote_description_->ice_mode == ICEMODE_LITE) { + SetIceRole(ICEROLE_CONTROLLING); + } + + // Update remote ice_mode to all existing channels. + remote_ice_mode_ = remote_description_->ice_mode; + + // Now that we have negotiated everything, push it downward. + // Note that we cache the result so that if we have race conditions + // between future SetRemote/SetLocal invocations and new channel + // creation, we have the negotiation state saved until a new + // negotiation happens. + for (const auto& kv : channels_) { + if (!ApplyNegotiatedTransportDescription(kv.second, error_desc)) { + return false; + } + } + return true; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/transport.h b/webrtc/p2p/base/transport.h new file mode 100644 index 0000000000..955eb42098 --- /dev/null +++ b/webrtc/p2p/base/transport.h @@ -0,0 +1,322 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +// A Transport manages a set of named channels of the same type. +// +// Subclasses choose the appropriate class to instantiate for each channel; +// however, this base class keeps track of the channels by name, watches their +// state changes (in order to update the manager's state), and forwards +// requests to begin connecting or to reset to each of the channels. +// +// On Threading: Transport performs work solely on the worker thread, and so +// its methods should only be called on the worker thread. +// +// Note: Subclasses must call DestroyChannels() in their own destructors. +// It is not possible to do so here because the subclass destructor will +// already have run. + +#ifndef WEBRTC_P2P_BASE_TRANSPORT_H_ +#define WEBRTC_P2P_BASE_TRANSPORT_H_ + +#include <map> +#include <string> +#include <vector> +#include "webrtc/p2p/base/candidate.h" +#include "webrtc/p2p/base/constants.h" +#include "webrtc/p2p/base/sessiondescription.h" +#include "webrtc/p2p/base/transportinfo.h" +#include "webrtc/base/messagequeue.h" +#include "webrtc/base/rtccertificate.h" +#include "webrtc/base/sigslot.h" +#include "webrtc/base/sslstreamadapter.h" + +namespace cricket { + +class PortAllocator; +class TransportChannel; +class TransportChannelImpl; + +typedef std::vector<Candidate> Candidates; + +// TODO(deadbeef): Unify with PeerConnectionInterface::IceConnectionState +// once /talk/ and /webrtc/ are combined, and also switch to ENUM_NAME naming +// style. +enum IceConnectionState { + kIceConnectionConnecting = 0, + kIceConnectionFailed, + kIceConnectionConnected, // Writable, but still checking one or more + // connections + kIceConnectionCompleted, +}; + +enum DtlsTransportState { + // Haven't started negotiating. + DTLS_TRANSPORT_NEW = 0, + // Have started negotiating. + DTLS_TRANSPORT_CONNECTING, + // Negotiated, and has a secure connection. + DTLS_TRANSPORT_CONNECTED, + // Transport is closed. + DTLS_TRANSPORT_CLOSED, + // Failed due to some error in the handshake process. + DTLS_TRANSPORT_FAILED, +}; + +// TODO(deadbeef): Unify with PeerConnectionInterface::IceConnectionState +// once /talk/ and /webrtc/ are combined, and also switch to ENUM_NAME naming +// style. +enum IceGatheringState { + kIceGatheringNew = 0, + kIceGatheringGathering, + kIceGatheringComplete, +}; + +// Stats that we can return about the connections for a transport channel. +// TODO(hta): Rename to ConnectionStats +struct ConnectionInfo { + ConnectionInfo() + : best_connection(false), + writable(false), + receiving(false), + timeout(false), + new_connection(false), + rtt(0), + sent_total_bytes(0), + sent_bytes_second(0), + sent_discarded_packets(0), + sent_total_packets(0), + recv_total_bytes(0), + recv_bytes_second(0), + key(NULL) {} + + bool best_connection; // Is this the best connection we have? + bool writable; // Has this connection received a STUN response? + bool receiving; // Has this connection received anything? + bool timeout; // Has this connection timed out? + bool new_connection; // Is this a newly created connection? + size_t rtt; // The STUN RTT for this connection. + size_t sent_total_bytes; // Total bytes sent on this connection. + size_t sent_bytes_second; // Bps over the last measurement interval. + size_t sent_discarded_packets; // Number of outgoing packets discarded due to + // socket errors. + size_t sent_total_packets; // Number of total outgoing packets attempted for + // sending. + + size_t recv_total_bytes; // Total bytes received on this connection. + size_t recv_bytes_second; // Bps over the last measurement interval. + Candidate local_candidate; // The local candidate for this connection. + Candidate remote_candidate; // The remote candidate for this connection. + void* key; // A static value that identifies this conn. +}; + +// Information about all the connections of a channel. +typedef std::vector<ConnectionInfo> ConnectionInfos; + +// Information about a specific channel +struct TransportChannelStats { + int component = 0; + ConnectionInfos connection_infos; + std::string srtp_cipher; + int ssl_cipher = 0; +}; + +// Information about all the channels of a transport. +// TODO(hta): Consider if a simple vector is as good as a map. +typedef std::vector<TransportChannelStats> TransportChannelStatsList; + +// Information about the stats of a transport. +struct TransportStats { + std::string transport_name; + TransportChannelStatsList channel_stats; +}; + +// Information about ICE configuration. +struct IceConfig { + // The ICE connection receiving timeout value. + int receiving_timeout_ms = -1; + // If true, the most recent port allocator session will keep on running. + bool gather_continually = false; +}; + +bool BadTransportDescription(const std::string& desc, std::string* err_desc); + +bool IceCredentialsChanged(const std::string& old_ufrag, + const std::string& old_pwd, + const std::string& new_ufrag, + const std::string& new_pwd); + +class Transport : public sigslot::has_slots<> { + public: + Transport(const std::string& name, PortAllocator* allocator); + virtual ~Transport(); + + // Returns the name of this transport. + const std::string& name() const { return name_; } + + // Returns the port allocator object for this transport. + PortAllocator* port_allocator() { return allocator_; } + + bool ready_for_remote_candidates() const { + return local_description_set_ && remote_description_set_; + } + + // Returns whether the client has requested the channels to connect. + bool connect_requested() const { return connect_requested_; } + + void SetIceRole(IceRole role); + IceRole ice_role() const { return ice_role_; } + + void SetIceTiebreaker(uint64_t IceTiebreaker) { tiebreaker_ = IceTiebreaker; } + uint64_t IceTiebreaker() { return tiebreaker_; } + + void SetIceConfig(const IceConfig& config); + + // Must be called before applying local session description. + virtual void SetLocalCertificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {} + + // Get a copy of the local certificate provided by SetLocalCertificate. + virtual bool GetLocalCertificate( + rtc::scoped_refptr<rtc::RTCCertificate>* certificate) { + return false; + } + + // Get a copy of the remote certificate in use by the specified channel. + bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert); + + // Create, destroy, and lookup the channels of this type by their components. + TransportChannelImpl* CreateChannel(int component); + + TransportChannelImpl* GetChannel(int component); + + bool HasChannel(int component) { + return (NULL != GetChannel(component)); + } + bool HasChannels(); + + void DestroyChannel(int component); + + // Set the local TransportDescription to be used by TransportChannels. + bool SetLocalTransportDescription(const TransportDescription& description, + ContentAction action, + std::string* error_desc); + + // Set the remote TransportDescription to be used by TransportChannels. + bool SetRemoteTransportDescription(const TransportDescription& description, + ContentAction action, + std::string* error_desc); + + // Tells all current and future channels to start connecting. + void ConnectChannels(); + + // Tells channels to start gathering candidates if necessary. + // Should be called after ConnectChannels() has been called at least once, + // which will happen in SetLocalTransportDescription. + void MaybeStartGathering(); + + // Resets all of the channels back to their initial state. They are no + // longer connecting. + void ResetChannels(); + + // Destroys every channel created so far. + void DestroyAllChannels(); + + bool GetStats(TransportStats* stats); + + // Called when one or more candidates are ready from the remote peer. + bool AddRemoteCandidates(const std::vector<Candidate>& candidates, + std::string* error); + + // If candidate is not acceptable, returns false and sets error. + // Call this before calling OnRemoteCandidates. + virtual bool VerifyCandidate(const Candidate& candidate, + std::string* error); + + virtual bool GetSslRole(rtc::SSLRole* ssl_role) const { return false; } + + // Must be called before channel is starting to connect. + virtual bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) { + return false; + } + + protected: + // These are called by Create/DestroyChannel above in order to create or + // destroy the appropriate type of channel. + virtual TransportChannelImpl* CreateTransportChannel(int component) = 0; + virtual void DestroyTransportChannel(TransportChannelImpl* channel) = 0; + + // The current local transport description, for use by derived classes + // when performing transport description negotiation. + const TransportDescription* local_description() const { + return local_description_.get(); + } + + // The current remote transport description, for use by derived classes + // when performing transport description negotiation. + const TransportDescription* remote_description() const { + return remote_description_.get(); + } + + // Pushes down the transport parameters from the local description, such + // as the ICE ufrag and pwd. + // Derived classes can override, but must call the base as well. + virtual bool ApplyLocalTransportDescription(TransportChannelImpl* channel, + std::string* error_desc); + + // Pushes down remote ice credentials from the remote description to the + // transport channel. + virtual bool ApplyRemoteTransportDescription(TransportChannelImpl* ch, + std::string* error_desc); + + // Negotiates the transport parameters based on the current local and remote + // transport description, such as the ICE role to use, and whether DTLS + // should be activated. + // Derived classes can negotiate their specific parameters here, but must call + // the base as well. + virtual bool NegotiateTransportDescription(ContentAction local_role, + std::string* error_desc); + + // Pushes down the transport parameters obtained via negotiation. + // Derived classes can set their specific parameters here, but must call the + // base as well. + virtual bool ApplyNegotiatedTransportDescription( + TransportChannelImpl* channel, + std::string* error_desc); + + private: + // Candidate component => TransportChannelImpl* + typedef std::map<int, TransportChannelImpl*> ChannelMap; + + // Helper function that invokes the given function on every channel. + typedef void (TransportChannelImpl::* TransportChannelFunc)(); + void CallChannels(TransportChannelFunc func); + + const std::string name_; + PortAllocator* const allocator_; + bool channels_destroyed_ = false; + bool connect_requested_ = false; + IceRole ice_role_ = ICEROLE_UNKNOWN; + uint64_t tiebreaker_ = 0; + IceMode remote_ice_mode_ = ICEMODE_FULL; + IceConfig ice_config_; + rtc::scoped_ptr<TransportDescription> local_description_; + rtc::scoped_ptr<TransportDescription> remote_description_; + bool local_description_set_ = false; + bool remote_description_set_ = false; + + ChannelMap channels_; + + RTC_DISALLOW_COPY_AND_ASSIGN(Transport); +}; + + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TRANSPORT_H_ diff --git a/webrtc/p2p/base/transport_unittest.cc b/webrtc/p2p/base/transport_unittest.cc new file mode 100644 index 0000000000..1f66a47c99 --- /dev/null +++ b/webrtc/p2p/base/transport_unittest.cc @@ -0,0 +1,232 @@ +/* + * Copyright 2011 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/base/fakesslidentity.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/network.h" +#include "webrtc/p2p/base/faketransportcontroller.h" +#include "webrtc/p2p/base/p2ptransport.h" + +using cricket::Candidate; +using cricket::Candidates; +using cricket::Transport; +using cricket::FakeTransport; +using cricket::TransportChannel; +using cricket::FakeTransportChannel; +using cricket::IceRole; +using cricket::TransportDescription; +using rtc::SocketAddress; + +static const char kIceUfrag1[] = "TESTICEUFRAG0001"; +static const char kIcePwd1[] = "TESTICEPWD00000000000001"; + +static const char kIceUfrag2[] = "TESTICEUFRAG0002"; +static const char kIcePwd2[] = "TESTICEPWD00000000000002"; + +class TransportTest : public testing::Test, + public sigslot::has_slots<> { + public: + TransportTest() + : transport_(new FakeTransport("test content name")), channel_(NULL) {} + ~TransportTest() { + transport_->DestroyAllChannels(); + } + bool SetupChannel() { + channel_ = CreateChannel(1); + return (channel_ != NULL); + } + FakeTransportChannel* CreateChannel(int component) { + return static_cast<FakeTransportChannel*>( + transport_->CreateChannel(component)); + } + void DestroyChannel() { + transport_->DestroyChannel(1); + channel_ = NULL; + } + + protected: + rtc::scoped_ptr<FakeTransport> transport_; + FakeTransportChannel* channel_; +}; + +// This test verifies channels are created with proper ICE +// role, tiebreaker and remote ice mode and credentials after offer and +// answer negotiations. +TEST_F(TransportTest, TestChannelIceParameters) { + transport_->SetIceRole(cricket::ICEROLE_CONTROLLING); + transport_->SetIceTiebreaker(99U); + cricket::TransportDescription local_desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetLocalTransportDescription(local_desc, + cricket::CA_OFFER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + EXPECT_TRUE(SetupChannel()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); + EXPECT_EQ(cricket::ICEMODE_FULL, channel_->remote_ice_mode()); + EXPECT_EQ(kIceUfrag1, channel_->ice_ufrag()); + EXPECT_EQ(kIcePwd1, channel_->ice_pwd()); + + cricket::TransportDescription remote_desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(remote_desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); + EXPECT_EQ(99U, channel_->IceTiebreaker()); + EXPECT_EQ(cricket::ICEMODE_FULL, channel_->remote_ice_mode()); + // Changing the transport role from CONTROLLING to CONTROLLED. + transport_->SetIceRole(cricket::ICEROLE_CONTROLLED); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, channel_->GetIceRole()); + EXPECT_EQ(cricket::ICEMODE_FULL, channel_->remote_ice_mode()); + EXPECT_EQ(kIceUfrag1, channel_->remote_ice_ufrag()); + EXPECT_EQ(kIcePwd1, channel_->remote_ice_pwd()); +} + +// Verifies that IceCredentialsChanged returns true when either ufrag or pwd +// changed, and false in other cases. +TEST_F(TransportTest, TestIceCredentialsChanged) { + EXPECT_TRUE(cricket::IceCredentialsChanged("u1", "p1", "u2", "p2")); + EXPECT_TRUE(cricket::IceCredentialsChanged("u1", "p1", "u2", "p1")); + EXPECT_TRUE(cricket::IceCredentialsChanged("u1", "p1", "u1", "p2")); + EXPECT_FALSE(cricket::IceCredentialsChanged("u1", "p1", "u1", "p1")); +} + +// This test verifies that the callee's ICE role changes from controlled to +// controlling when the callee triggers an ICE restart. +TEST_F(TransportTest, TestIceControlledToControllingOnIceRestart) { + EXPECT_TRUE(SetupChannel()); + transport_->SetIceRole(cricket::ICEROLE_CONTROLLED); + + cricket::TransportDescription desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(desc, + cricket::CA_OFFER, + NULL)); + ASSERT_TRUE(transport_->SetLocalTransportDescription(desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, transport_->ice_role()); + + cricket::TransportDescription new_local_desc(kIceUfrag2, kIcePwd2); + ASSERT_TRUE(transport_->SetLocalTransportDescription(new_local_desc, + cricket::CA_OFFER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); +} + +// This test verifies that the caller's ICE role changes from controlling to +// controlled when the callee triggers an ICE restart. +TEST_F(TransportTest, TestIceControllingToControlledOnIceRestart) { + EXPECT_TRUE(SetupChannel()); + transport_->SetIceRole(cricket::ICEROLE_CONTROLLING); + + cricket::TransportDescription desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetLocalTransportDescription(desc, + cricket::CA_OFFER, + NULL)); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + + cricket::TransportDescription new_local_desc(kIceUfrag2, kIcePwd2); + ASSERT_TRUE(transport_->SetLocalTransportDescription(new_local_desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, transport_->ice_role()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, channel_->GetIceRole()); +} + +// This test verifies that the caller's ICE role is still controlling after the +// callee triggers ICE restart if the callee's ICE mode is LITE. +TEST_F(TransportTest, TestIceControllingOnIceRestartIfRemoteIsIceLite) { + EXPECT_TRUE(SetupChannel()); + transport_->SetIceRole(cricket::ICEROLE_CONTROLLING); + + cricket::TransportDescription desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetLocalTransportDescription(desc, + cricket::CA_OFFER, + NULL)); + + cricket::TransportDescription remote_desc( + std::vector<std::string>(), + kIceUfrag1, kIcePwd1, cricket::ICEMODE_LITE, + cricket::CONNECTIONROLE_NONE, NULL, cricket::Candidates()); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(remote_desc, + cricket::CA_ANSWER, + NULL)); + + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + + cricket::TransportDescription new_local_desc(kIceUfrag2, kIcePwd2); + ASSERT_TRUE(transport_->SetLocalTransportDescription(new_local_desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); +} + +// Tests channel role is reversed after receiving ice-lite from remote. +TEST_F(TransportTest, TestSetRemoteIceLiteInOffer) { + transport_->SetIceRole(cricket::ICEROLE_CONTROLLED); + cricket::TransportDescription remote_desc( + std::vector<std::string>(), + kIceUfrag1, kIcePwd1, cricket::ICEMODE_LITE, + cricket::CONNECTIONROLE_ACTPASS, NULL, cricket::Candidates()); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(remote_desc, + cricket::CA_OFFER, + NULL)); + cricket::TransportDescription local_desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetLocalTransportDescription(local_desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + EXPECT_TRUE(SetupChannel()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); + EXPECT_EQ(cricket::ICEMODE_LITE, channel_->remote_ice_mode()); +} + +// Tests ice-lite in remote answer. +TEST_F(TransportTest, TestSetRemoteIceLiteInAnswer) { + transport_->SetIceRole(cricket::ICEROLE_CONTROLLING); + cricket::TransportDescription local_desc(kIceUfrag1, kIcePwd1); + ASSERT_TRUE(transport_->SetLocalTransportDescription(local_desc, + cricket::CA_OFFER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, transport_->ice_role()); + EXPECT_TRUE(SetupChannel()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); + // Channels will be created in ICEFULL_MODE. + EXPECT_EQ(cricket::ICEMODE_FULL, channel_->remote_ice_mode()); + cricket::TransportDescription remote_desc( + std::vector<std::string>(), + kIceUfrag1, kIcePwd1, cricket::ICEMODE_LITE, + cricket::CONNECTIONROLE_NONE, NULL, cricket::Candidates()); + ASSERT_TRUE(transport_->SetRemoteTransportDescription(remote_desc, + cricket::CA_ANSWER, + NULL)); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel_->GetIceRole()); + // After receiving remote description with ICEMODE_LITE, channel should + // have mode set to ICEMODE_LITE. + EXPECT_EQ(cricket::ICEMODE_LITE, channel_->remote_ice_mode()); +} + +TEST_F(TransportTest, TestGetStats) { + EXPECT_TRUE(SetupChannel()); + cricket::TransportStats stats; + EXPECT_TRUE(transport_->GetStats(&stats)); + // Note that this tests the behavior of a FakeTransportChannel. + ASSERT_EQ(1U, stats.channel_stats.size()); + EXPECT_EQ(1, stats.channel_stats[0].component); + transport_->ConnectChannels(); + EXPECT_TRUE(transport_->GetStats(&stats)); + ASSERT_EQ(1U, stats.channel_stats.size()); + EXPECT_EQ(1, stats.channel_stats[0].component); +} + diff --git a/webrtc/p2p/base/transportchannel.cc b/webrtc/p2p/base/transportchannel.cc new file mode 100644 index 0000000000..63d84494e5 --- /dev/null +++ b/webrtc/p2p/base/transportchannel.cc @@ -0,0 +1,57 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <sstream> +#include "webrtc/p2p/base/common.h" +#include "webrtc/p2p/base/transportchannel.h" + +namespace cricket { + +std::string TransportChannel::ToString() const { + const char RECEIVING_ABBREV[2] = { '_', 'R' }; + const char WRITABLE_ABBREV[2] = { '_', 'W' }; + std::stringstream ss; + ss << "Channel[" << transport_name_ << "|" << component_ << "|" + << RECEIVING_ABBREV[receiving_] << WRITABLE_ABBREV[writable_] << "]"; + return ss.str(); +} + +void TransportChannel::set_receiving(bool receiving) { + if (receiving_ == receiving) { + return; + } + receiving_ = receiving; + SignalReceivingState(this); +} + +void TransportChannel::set_writable(bool writable) { + if (writable_ == writable) { + return; + } + LOG_J(LS_VERBOSE, this) << "set_writable from:" << writable_ << " to " + << writable; + writable_ = writable; + if (writable_) { + SignalReadyToSend(this); + } + SignalWritableState(this); +} + +void TransportChannel::set_dtls_state(DtlsTransportState state) { + if (dtls_state_ == state) { + return; + } + LOG_J(LS_VERBOSE, this) << "set_dtls_state from:" << dtls_state_ << " to " + << state; + dtls_state_ = state; + SignalDtlsState(this); +} + +} // namespace cricket diff --git a/webrtc/p2p/base/transportchannel.h b/webrtc/p2p/base/transportchannel.h new file mode 100644 index 0000000000..767a5f68bf --- /dev/null +++ b/webrtc/p2p/base/transportchannel.h @@ -0,0 +1,182 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TRANSPORTCHANNEL_H_ +#define WEBRTC_P2P_BASE_TRANSPORTCHANNEL_H_ + +#include <string> +#include <vector> + +#include "webrtc/p2p/base/candidate.h" +#include "webrtc/p2p/base/transport.h" +#include "webrtc/p2p/base/transportdescription.h" +#include "webrtc/base/asyncpacketsocket.h" +#include "webrtc/base/basictypes.h" +#include "webrtc/base/dscp.h" +#include "webrtc/base/sigslot.h" +#include "webrtc/base/socket.h" +#include "webrtc/base/sslidentity.h" +#include "webrtc/base/sslstreamadapter.h" + +namespace cricket { + +class Candidate; + +// Flags for SendPacket/SignalReadPacket. +enum PacketFlags { + PF_NORMAL = 0x00, // A normal packet. + PF_SRTP_BYPASS = 0x01, // An encrypted SRTP packet; bypass any additional + // crypto provided by the transport (e.g. DTLS) +}; + +// Used to indicate channel's connection state. +enum TransportChannelState { + STATE_INIT, + STATE_CONNECTING, // Will enter this state once a connection is created + STATE_COMPLETED, + STATE_FAILED +}; + +// A TransportChannel represents one logical stream of packets that are sent +// between the two sides of a session. +// TODO(deadbeef): This interface currently represents the unity of an ICE +// transport and a DTLS transport. They need to be separated apart. +class TransportChannel : public sigslot::has_slots<> { + public: + TransportChannel(const std::string& transport_name, int component) + : transport_name_(transport_name), + component_(component), + writable_(false), + receiving_(false) {} + virtual ~TransportChannel() {} + + // TODO(guoweis) - Make this pure virtual once all subclasses of + // TransportChannel have this defined. + virtual TransportChannelState GetState() const { + return TransportChannelState::STATE_CONNECTING; + } + + // TODO(mallinath) - Remove this API, as it's no longer useful. + // Returns the session id of this channel. + virtual const std::string SessionId() const { return std::string(); } + + const std::string& transport_name() const { return transport_name_; } + int component() const { return component_; } + + // Returns the states of this channel. Each time one of these states changes, + // a signal is raised. These states are aggregated by the TransportManager. + bool writable() const { return writable_; } + bool receiving() const { return receiving_; } + DtlsTransportState dtls_state() const { return dtls_state_; } + sigslot::signal1<TransportChannel*> SignalWritableState; + // Emitted when the TransportChannel's ability to send has changed. + sigslot::signal1<TransportChannel*> SignalReadyToSend; + sigslot::signal1<TransportChannel*> SignalReceivingState; + // Emitted when the DtlsTransportState has changed. + sigslot::signal1<TransportChannel*> SignalDtlsState; + + // Attempts to send the given packet. The return value is < 0 on failure. + // TODO: Remove the default argument once channel code is updated. + virtual int SendPacket(const char* data, size_t len, + const rtc::PacketOptions& options, + int flags = 0) = 0; + + // Sets a socket option on this channel. Note that not all options are + // supported by all transport types. + virtual int SetOption(rtc::Socket::Option opt, int value) = 0; + // TODO(pthatcher): Once Chrome's MockTransportChannel implments + // this, remove the default implementation. + virtual bool GetOption(rtc::Socket::Option opt, int* value) { return false; } + + // Returns the most recent error that occurred on this channel. + virtual int GetError() = 0; + + // Returns the current stats for this connection. + virtual bool GetStats(ConnectionInfos* infos) = 0; + + // Is DTLS active? + virtual bool IsDtlsActive() const = 0; + + // Default implementation. + virtual bool GetSslRole(rtc::SSLRole* role) const = 0; + + // Sets up the ciphers to use for DTLS-SRTP. + virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers) = 0; + + // Finds out which DTLS-SRTP cipher was negotiated. + // TODO(guoweis): Remove this once all dependencies implement this. + virtual bool GetSrtpCryptoSuite(std::string* cipher) { + return false; + } + + // Finds out which DTLS cipher was negotiated. + // TODO(guoweis): Remove this once all dependencies implement this. + virtual bool GetSslCipherSuite(int* cipher) { return false; } + + // Gets the local RTCCertificate used for DTLS. + virtual rtc::scoped_refptr<rtc::RTCCertificate> + GetLocalCertificate() const = 0; + + // Gets a copy of the remote side's SSL certificate, owned by the caller. + virtual bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const = 0; + + // Allows key material to be extracted for external encryption. + virtual bool ExportKeyingMaterial(const std::string& label, + const uint8_t* context, + size_t context_len, + bool use_context, + uint8_t* result, + size_t result_len) = 0; + + // Signalled each time a packet is received on this channel. + sigslot::signal5<TransportChannel*, const char*, + size_t, const rtc::PacketTime&, int> SignalReadPacket; + + // Signalled each time a packet is sent on this channel. + sigslot::signal2<TransportChannel*, const rtc::SentPacket&> SignalSentPacket; + + // This signal occurs when there is a change in the way that packets are + // being routed, i.e. to a different remote location. The candidate + // indicates where and how we are currently sending media. + sigslot::signal2<TransportChannel*, const Candidate&> SignalRouteChange; + + // Invoked when the channel is being destroyed. + sigslot::signal1<TransportChannel*> SignalDestroyed; + + // Debugging description of this transport channel. + std::string ToString() const; + + protected: + // TODO(honghaiz): Remove this once chromium's unit tests no longer call it. + void set_readable(bool readable) { set_receiving(readable); } + + // Sets the writable state, signaling if necessary. + void set_writable(bool writable); + + // Sets the receiving state, signaling if necessary. + void set_receiving(bool receiving); + + // Sets the DTLS state, signaling if necessary. + void set_dtls_state(DtlsTransportState state); + + private: + // Used mostly for debugging. + std::string transport_name_; + int component_; + bool writable_; + bool receiving_; + DtlsTransportState dtls_state_ = DTLS_TRANSPORT_NEW; + + RTC_DISALLOW_COPY_AND_ASSIGN(TransportChannel); +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TRANSPORTCHANNEL_H_ diff --git a/webrtc/p2p/base/transportchannelimpl.h b/webrtc/p2p/base/transportchannelimpl.h new file mode 100644 index 0000000000..8d4d4bb728 --- /dev/null +++ b/webrtc/p2p/base/transportchannelimpl.h @@ -0,0 +1,112 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TRANSPORTCHANNELIMPL_H_ +#define WEBRTC_P2P_BASE_TRANSPORTCHANNELIMPL_H_ + +#include <string> +#include "webrtc/p2p/base/transport.h" +#include "webrtc/p2p/base/transportchannel.h" + +namespace buzz { class XmlElement; } + +namespace cricket { + +class Candidate; + +// TODO(pthatcher): Remove this once it's no longer used in +// remoting/protocol/libjingle_transport_factory.cc +enum IceProtocolType { + ICEPROTO_RFC5245 // Standard RFC 5245 version of ICE. +}; + +// Base class for real implementations of TransportChannel. This includes some +// methods called only by Transport, which do not need to be exposed to the +// client. +class TransportChannelImpl : public TransportChannel { + public: + explicit TransportChannelImpl(const std::string& transport_name, + int component) + : TransportChannel(transport_name, component) {} + + // Returns the transport that created this channel. + virtual Transport* GetTransport() = 0; + + // For ICE channels. + virtual IceRole GetIceRole() const = 0; + virtual void SetIceRole(IceRole role) = 0; + virtual void SetIceTiebreaker(uint64_t tiebreaker) = 0; + // TODO(pthatcher): Remove this once it's no longer called in + // remoting/protocol/libjingle_transport_factory.cc + virtual void SetIceProtocolType(IceProtocolType type) {} + // SetIceCredentials only need to be implemented by the ICE + // transport channels. Non-ICE transport channels can just ignore. + // The ufrag and pwd should be set before the Connect() is called. + virtual void SetIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) = 0; + // SetRemoteIceCredentials only need to be implemented by the ICE + // transport channels. Non-ICE transport channels can just ignore. + virtual void SetRemoteIceCredentials(const std::string& ice_ufrag, + const std::string& ice_pwd) = 0; + + // SetRemoteIceMode must be implemented only by the ICE transport channels. + virtual void SetRemoteIceMode(IceMode mode) = 0; + + virtual void SetIceConfig(const IceConfig& config) = 0; + + // Begins the process of attempting to make a connection to the other client. + virtual void Connect() = 0; + + // Start gathering candidates if not already started, or if an ICE restart + // occurred. + virtual void MaybeStartGathering() = 0; + + sigslot::signal1<TransportChannelImpl*> SignalGatheringState; + + // Handles sending and receiving of candidates. The Transport + // receives the candidates and may forward them to the relevant + // channel. + // + // Note: Since candidates are delivered asynchronously to the + // channel, they cannot return an error if the message is invalid. + // It is assumed that the Transport will have checked validity + // before forwarding. + sigslot::signal2<TransportChannelImpl*, const Candidate&> + SignalCandidateGathered; + virtual void AddRemoteCandidate(const Candidate& candidate) = 0; + + virtual IceGatheringState gathering_state() const = 0; + + // DTLS methods + virtual bool SetLocalCertificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) = 0; + + // Set DTLS Remote fingerprint. Must be after local identity set. + virtual bool SetRemoteFingerprint(const std::string& digest_alg, + const uint8_t* digest, + size_t digest_len) = 0; + + virtual bool SetSslRole(rtc::SSLRole role) = 0; + + // Invoked when there is conflict in the ICE role between local and remote + // agents. + sigslot::signal1<TransportChannelImpl*> SignalRoleConflict; + + // Emitted whenever the number of connections available to the transport + // channel decreases. + sigslot::signal1<TransportChannelImpl*> SignalConnectionRemoved; + + private: + RTC_DISALLOW_COPY_AND_ASSIGN(TransportChannelImpl); +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TRANSPORTCHANNELIMPL_H_ diff --git a/webrtc/p2p/base/transportcontroller.cc b/webrtc/p2p/base/transportcontroller.cc new file mode 100644 index 0000000000..22b827a1a5 --- /dev/null +++ b/webrtc/p2p/base/transportcontroller.cc @@ -0,0 +1,605 @@ +/* + * Copyright 2015 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/transportcontroller.h" + +#include <algorithm> + +#include "webrtc/base/bind.h" +#include "webrtc/base/checks.h" +#include "webrtc/base/thread.h" +#include "webrtc/p2p/base/dtlstransport.h" +#include "webrtc/p2p/base/p2ptransport.h" +#include "webrtc/p2p/base/port.h" + +namespace cricket { + +enum { + MSG_ICECONNECTIONSTATE, + MSG_RECEIVING, + MSG_ICEGATHERINGSTATE, + MSG_CANDIDATESGATHERED, +}; + +struct CandidatesData : public rtc::MessageData { + CandidatesData(const std::string& transport_name, + const Candidates& candidates) + : transport_name(transport_name), candidates(candidates) {} + + std::string transport_name; + Candidates candidates; +}; + +TransportController::TransportController(rtc::Thread* signaling_thread, + rtc::Thread* worker_thread, + PortAllocator* port_allocator) + : signaling_thread_(signaling_thread), + worker_thread_(worker_thread), + port_allocator_(port_allocator) {} + +TransportController::~TransportController() { + worker_thread_->Invoke<void>( + rtc::Bind(&TransportController::DestroyAllTransports_w, this)); + signaling_thread_->Clear(this); +} + +bool TransportController::SetSslMaxProtocolVersion( + rtc::SSLProtocolVersion version) { + return worker_thread_->Invoke<bool>(rtc::Bind( + &TransportController::SetSslMaxProtocolVersion_w, this, version)); +} + +void TransportController::SetIceConfig(const IceConfig& config) { + worker_thread_->Invoke<void>( + rtc::Bind(&TransportController::SetIceConfig_w, this, config)); +} + +void TransportController::SetIceRole(IceRole ice_role) { + worker_thread_->Invoke<void>( + rtc::Bind(&TransportController::SetIceRole_w, this, ice_role)); +} + +bool TransportController::GetSslRole(rtc::SSLRole* role) { + return worker_thread_->Invoke<bool>( + rtc::Bind(&TransportController::GetSslRole_w, this, role)); +} + +bool TransportController::SetLocalCertificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) { + return worker_thread_->Invoke<bool>(rtc::Bind( + &TransportController::SetLocalCertificate_w, this, certificate)); +} + +bool TransportController::GetLocalCertificate( + const std::string& transport_name, + rtc::scoped_refptr<rtc::RTCCertificate>* certificate) { + return worker_thread_->Invoke<bool>( + rtc::Bind(&TransportController::GetLocalCertificate_w, this, + transport_name, certificate)); +} + +bool TransportController::GetRemoteSSLCertificate( + const std::string& transport_name, + rtc::SSLCertificate** cert) { + return worker_thread_->Invoke<bool>( + rtc::Bind(&TransportController::GetRemoteSSLCertificate_w, this, + transport_name, cert)); +} + +bool TransportController::SetLocalTransportDescription( + const std::string& transport_name, + const TransportDescription& tdesc, + ContentAction action, + std::string* err) { + return worker_thread_->Invoke<bool>( + rtc::Bind(&TransportController::SetLocalTransportDescription_w, this, + transport_name, tdesc, action, err)); +} + +bool TransportController::SetRemoteTransportDescription( + const std::string& transport_name, + const TransportDescription& tdesc, + ContentAction action, + std::string* err) { + return worker_thread_->Invoke<bool>( + rtc::Bind(&TransportController::SetRemoteTransportDescription_w, this, + transport_name, tdesc, action, err)); +} + +void TransportController::MaybeStartGathering() { + worker_thread_->Invoke<void>( + rtc::Bind(&TransportController::MaybeStartGathering_w, this)); +} + +bool TransportController::AddRemoteCandidates(const std::string& transport_name, + const Candidates& candidates, + std::string* err) { + return worker_thread_->Invoke<bool>( + rtc::Bind(&TransportController::AddRemoteCandidates_w, this, + transport_name, candidates, err)); +} + +bool TransportController::ReadyForRemoteCandidates( + const std::string& transport_name) { + return worker_thread_->Invoke<bool>(rtc::Bind( + &TransportController::ReadyForRemoteCandidates_w, this, transport_name)); +} + +bool TransportController::GetStats(const std::string& transport_name, + TransportStats* stats) { + return worker_thread_->Invoke<bool>( + rtc::Bind(&TransportController::GetStats_w, this, transport_name, stats)); +} + +TransportChannel* TransportController::CreateTransportChannel_w( + const std::string& transport_name, + int component) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + auto it = FindChannel_w(transport_name, component); + if (it != channels_.end()) { + // Channel already exists; increment reference count and return. + it->AddRef(); + return it->get(); + } + + // Need to create a new channel. + Transport* transport = GetOrCreateTransport_w(transport_name); + TransportChannelImpl* channel = transport->CreateChannel(component); + channel->SignalWritableState.connect( + this, &TransportController::OnChannelWritableState_w); + channel->SignalReceivingState.connect( + this, &TransportController::OnChannelReceivingState_w); + channel->SignalGatheringState.connect( + this, &TransportController::OnChannelGatheringState_w); + channel->SignalCandidateGathered.connect( + this, &TransportController::OnChannelCandidateGathered_w); + channel->SignalRoleConflict.connect( + this, &TransportController::OnChannelRoleConflict_w); + channel->SignalConnectionRemoved.connect( + this, &TransportController::OnChannelConnectionRemoved_w); + channels_.insert(channels_.end(), RefCountedChannel(channel))->AddRef(); + // Adding a channel could cause aggregate state to change. + UpdateAggregateStates_w(); + return channel; +} + +void TransportController::DestroyTransportChannel_w( + const std::string& transport_name, + int component) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + auto it = FindChannel_w(transport_name, component); + if (it == channels_.end()) { + LOG(LS_WARNING) << "Attempting to delete " << transport_name + << " TransportChannel " << component + << ", which doesn't exist."; + return; + } + + it->DecRef(); + if (it->ref() > 0) { + return; + } + + channels_.erase(it); + Transport* transport = GetTransport_w(transport_name); + transport->DestroyChannel(component); + // Just as we create a Transport when its first channel is created, + // we delete it when its last channel is deleted. + if (!transport->HasChannels()) { + DestroyTransport_w(transport_name); + } + // Removing a channel could cause aggregate state to change. + UpdateAggregateStates_w(); +} + +const rtc::scoped_refptr<rtc::RTCCertificate>& +TransportController::certificate_for_testing() { + return certificate_; +} + +Transport* TransportController::CreateTransport_w( + const std::string& transport_name) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + Transport* transport = new DtlsTransport<P2PTransport>( + transport_name, port_allocator(), certificate_); + return transport; +} + +Transport* TransportController::GetTransport_w( + const std::string& transport_name) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + auto iter = transports_.find(transport_name); + return (iter != transports_.end()) ? iter->second : nullptr; +} + +void TransportController::OnMessage(rtc::Message* pmsg) { + RTC_DCHECK(signaling_thread_->IsCurrent()); + + switch (pmsg->message_id) { + case MSG_ICECONNECTIONSTATE: { + rtc::TypedMessageData<IceConnectionState>* data = + static_cast<rtc::TypedMessageData<IceConnectionState>*>(pmsg->pdata); + SignalConnectionState(data->data()); + delete data; + break; + } + case MSG_RECEIVING: { + rtc::TypedMessageData<bool>* data = + static_cast<rtc::TypedMessageData<bool>*>(pmsg->pdata); + SignalReceiving(data->data()); + delete data; + break; + } + case MSG_ICEGATHERINGSTATE: { + rtc::TypedMessageData<IceGatheringState>* data = + static_cast<rtc::TypedMessageData<IceGatheringState>*>(pmsg->pdata); + SignalGatheringState(data->data()); + delete data; + break; + } + case MSG_CANDIDATESGATHERED: { + CandidatesData* data = static_cast<CandidatesData*>(pmsg->pdata); + SignalCandidatesGathered(data->transport_name, data->candidates); + delete data; + break; + } + default: + ASSERT(false); + } +} + +std::vector<TransportController::RefCountedChannel>::iterator +TransportController::FindChannel_w(const std::string& transport_name, + int component) { + return std::find_if( + channels_.begin(), channels_.end(), + [transport_name, component](const RefCountedChannel& channel) { + return channel->transport_name() == transport_name && + channel->component() == component; + }); +} + +Transport* TransportController::GetOrCreateTransport_w( + const std::string& transport_name) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + Transport* transport = GetTransport_w(transport_name); + if (transport) { + return transport; + } + + transport = CreateTransport_w(transport_name); + // The stuff below happens outside of CreateTransport_w so that unit tests + // can override CreateTransport_w to return a different type of transport. + transport->SetSslMaxProtocolVersion(ssl_max_version_); + transport->SetIceConfig(ice_config_); + transport->SetIceRole(ice_role_); + transport->SetIceTiebreaker(ice_tiebreaker_); + if (certificate_) { + transport->SetLocalCertificate(certificate_); + } + transports_[transport_name] = transport; + + return transport; +} + +void TransportController::DestroyTransport_w( + const std::string& transport_name) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + auto iter = transports_.find(transport_name); + if (iter != transports_.end()) { + delete iter->second; + transports_.erase(transport_name); + } +} + +void TransportController::DestroyAllTransports_w() { + RTC_DCHECK(worker_thread_->IsCurrent()); + + for (const auto& kv : transports_) { + delete kv.second; + } + transports_.clear(); +} + +bool TransportController::SetSslMaxProtocolVersion_w( + rtc::SSLProtocolVersion version) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + // Max SSL version can only be set before transports are created. + if (!transports_.empty()) { + return false; + } + + ssl_max_version_ = version; + return true; +} + +void TransportController::SetIceConfig_w(const IceConfig& config) { + RTC_DCHECK(worker_thread_->IsCurrent()); + ice_config_ = config; + for (const auto& kv : transports_) { + kv.second->SetIceConfig(ice_config_); + } +} + +void TransportController::SetIceRole_w(IceRole ice_role) { + RTC_DCHECK(worker_thread_->IsCurrent()); + ice_role_ = ice_role; + for (const auto& kv : transports_) { + kv.second->SetIceRole(ice_role_); + } +} + +bool TransportController::GetSslRole_w(rtc::SSLRole* role) { + RTC_DCHECK(worker_thread()->IsCurrent()); + + if (transports_.empty()) { + return false; + } + return transports_.begin()->second->GetSslRole(role); +} + +bool TransportController::SetLocalCertificate_w( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + if (certificate_) { + return false; + } + if (!certificate) { + return false; + } + certificate_ = certificate; + + for (const auto& kv : transports_) { + kv.second->SetLocalCertificate(certificate_); + } + return true; +} + +bool TransportController::GetLocalCertificate_w( + const std::string& transport_name, + rtc::scoped_refptr<rtc::RTCCertificate>* certificate) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + Transport* t = GetTransport_w(transport_name); + if (!t) { + return false; + } + + return t->GetLocalCertificate(certificate); +} + +bool TransportController::GetRemoteSSLCertificate_w( + const std::string& transport_name, + rtc::SSLCertificate** cert) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + Transport* t = GetTransport_w(transport_name); + if (!t) { + return false; + } + + return t->GetRemoteSSLCertificate(cert); +} + +bool TransportController::SetLocalTransportDescription_w( + const std::string& transport_name, + const TransportDescription& tdesc, + ContentAction action, + std::string* err) { + RTC_DCHECK(worker_thread()->IsCurrent()); + + Transport* transport = GetTransport_w(transport_name); + if (!transport) { + // If we didn't find a transport, that's not an error; + // it could have been deleted as a result of bundling. + // TODO(deadbeef): Make callers smarter so they won't attempt to set a + // description on a deleted transport. + return true; + } + + return transport->SetLocalTransportDescription(tdesc, action, err); +} + +bool TransportController::SetRemoteTransportDescription_w( + const std::string& transport_name, + const TransportDescription& tdesc, + ContentAction action, + std::string* err) { + RTC_DCHECK(worker_thread()->IsCurrent()); + + Transport* transport = GetTransport_w(transport_name); + if (!transport) { + // If we didn't find a transport, that's not an error; + // it could have been deleted as a result of bundling. + // TODO(deadbeef): Make callers smarter so they won't attempt to set a + // description on a deleted transport. + return true; + } + + return transport->SetRemoteTransportDescription(tdesc, action, err); +} + +void TransportController::MaybeStartGathering_w() { + for (const auto& kv : transports_) { + kv.second->MaybeStartGathering(); + } +} + +bool TransportController::AddRemoteCandidates_w( + const std::string& transport_name, + const Candidates& candidates, + std::string* err) { + RTC_DCHECK(worker_thread()->IsCurrent()); + + Transport* transport = GetTransport_w(transport_name); + if (!transport) { + // If we didn't find a transport, that's not an error; + // it could have been deleted as a result of bundling. + return true; + } + + return transport->AddRemoteCandidates(candidates, err); +} + +bool TransportController::ReadyForRemoteCandidates_w( + const std::string& transport_name) { + RTC_DCHECK(worker_thread()->IsCurrent()); + + Transport* transport = GetTransport_w(transport_name); + if (!transport) { + return false; + } + return transport->ready_for_remote_candidates(); +} + +bool TransportController::GetStats_w(const std::string& transport_name, + TransportStats* stats) { + RTC_DCHECK(worker_thread()->IsCurrent()); + + Transport* transport = GetTransport_w(transport_name); + if (!transport) { + return false; + } + return transport->GetStats(stats); +} + +void TransportController::OnChannelWritableState_w(TransportChannel* channel) { + RTC_DCHECK(worker_thread_->IsCurrent()); + LOG(LS_INFO) << channel->transport_name() << " TransportChannel " + << channel->component() << " writability changed to " + << channel->writable() << "."; + UpdateAggregateStates_w(); +} + +void TransportController::OnChannelReceivingState_w(TransportChannel* channel) { + RTC_DCHECK(worker_thread_->IsCurrent()); + UpdateAggregateStates_w(); +} + +void TransportController::OnChannelGatheringState_w( + TransportChannelImpl* channel) { + RTC_DCHECK(worker_thread_->IsCurrent()); + UpdateAggregateStates_w(); +} + +void TransportController::OnChannelCandidateGathered_w( + TransportChannelImpl* channel, + const Candidate& candidate) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + // We should never signal peer-reflexive candidates. + if (candidate.type() == PRFLX_PORT_TYPE) { + RTC_DCHECK(false); + return; + } + std::vector<Candidate> candidates; + candidates.push_back(candidate); + CandidatesData* data = + new CandidatesData(channel->transport_name(), candidates); + signaling_thread_->Post(this, MSG_CANDIDATESGATHERED, data); +} + +void TransportController::OnChannelRoleConflict_w( + TransportChannelImpl* channel) { + RTC_DCHECK(worker_thread_->IsCurrent()); + + if (ice_role_switch_) { + LOG(LS_WARNING) + << "Repeat of role conflict signal from TransportChannelImpl."; + return; + } + + ice_role_switch_ = true; + IceRole reversed_role = (ice_role_ == ICEROLE_CONTROLLING) + ? ICEROLE_CONTROLLED + : ICEROLE_CONTROLLING; + for (const auto& kv : transports_) { + kv.second->SetIceRole(reversed_role); + } +} + +void TransportController::OnChannelConnectionRemoved_w( + TransportChannelImpl* channel) { + RTC_DCHECK(worker_thread_->IsCurrent()); + LOG(LS_INFO) << channel->transport_name() << " TransportChannel " + << channel->component() + << " connection removed. Check if state is complete."; + UpdateAggregateStates_w(); +} + +void TransportController::UpdateAggregateStates_w() { + RTC_DCHECK(worker_thread_->IsCurrent()); + + IceConnectionState new_connection_state = kIceConnectionConnecting; + IceGatheringState new_gathering_state = kIceGatheringNew; + bool any_receiving = false; + bool any_failed = false; + bool all_connected = !channels_.empty(); + bool all_completed = !channels_.empty(); + bool any_gathering = false; + bool all_done_gathering = !channels_.empty(); + for (const auto& channel : channels_) { + any_receiving = any_receiving || channel->receiving(); + any_failed = any_failed || + channel->GetState() == TransportChannelState::STATE_FAILED; + all_connected = all_connected && channel->writable(); + all_completed = + all_completed && channel->writable() && + channel->GetState() == TransportChannelState::STATE_COMPLETED && + channel->GetIceRole() == ICEROLE_CONTROLLING && + channel->gathering_state() == kIceGatheringComplete; + any_gathering = + any_gathering || channel->gathering_state() != kIceGatheringNew; + all_done_gathering = all_done_gathering && + channel->gathering_state() == kIceGatheringComplete; + } + + if (any_failed) { + new_connection_state = kIceConnectionFailed; + } else if (all_completed) { + new_connection_state = kIceConnectionCompleted; + } else if (all_connected) { + new_connection_state = kIceConnectionConnected; + } + if (connection_state_ != new_connection_state) { + connection_state_ = new_connection_state; + signaling_thread_->Post( + this, MSG_ICECONNECTIONSTATE, + new rtc::TypedMessageData<IceConnectionState>(new_connection_state)); + } + + if (receiving_ != any_receiving) { + receiving_ = any_receiving; + signaling_thread_->Post(this, MSG_RECEIVING, + new rtc::TypedMessageData<bool>(any_receiving)); + } + + if (all_done_gathering) { + new_gathering_state = kIceGatheringComplete; + } else if (any_gathering) { + new_gathering_state = kIceGatheringGathering; + } + if (gathering_state_ != new_gathering_state) { + gathering_state_ = new_gathering_state; + signaling_thread_->Post( + this, MSG_ICEGATHERINGSTATE, + new rtc::TypedMessageData<IceGatheringState>(new_gathering_state)); + } +} + +} // namespace cricket diff --git a/webrtc/p2p/base/transportcontroller.h b/webrtc/p2p/base/transportcontroller.h new file mode 100644 index 0000000000..8d57b460e8 --- /dev/null +++ b/webrtc/p2p/base/transportcontroller.h @@ -0,0 +1,223 @@ +/* + * Copyright 2015 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TRANSPORTCONTROLLER_H_ +#define WEBRTC_P2P_BASE_TRANSPORTCONTROLLER_H_ + +#include <map> +#include <string> +#include <vector> + +#include "webrtc/base/sigslot.h" +#include "webrtc/base/sslstreamadapter.h" +#include "webrtc/p2p/base/candidate.h" +#include "webrtc/p2p/base/transport.h" + +namespace rtc { +class Thread; +} + +namespace cricket { + +class TransportController : public sigslot::has_slots<>, + public rtc::MessageHandler { + public: + TransportController(rtc::Thread* signaling_thread, + rtc::Thread* worker_thread, + PortAllocator* port_allocator); + + virtual ~TransportController(); + + rtc::Thread* signaling_thread() const { return signaling_thread_; } + rtc::Thread* worker_thread() const { return worker_thread_; } + + PortAllocator* port_allocator() const { return port_allocator_; } + + // Can only be set before transports are created. + // TODO(deadbeef): Make this an argument to the constructor once BaseSession + // and WebRtcSession are combined + bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version); + + void SetIceConfig(const IceConfig& config); + void SetIceRole(IceRole ice_role); + + // TODO(deadbeef) - Return role of each transport, as role may differ from + // one another. + // In current implementaion we just return the role of the first transport + // alphabetically. + bool GetSslRole(rtc::SSLRole* role); + + // Specifies the identity to use in this session. + // Can only be called once. + bool SetLocalCertificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate); + bool GetLocalCertificate( + const std::string& transport_name, + rtc::scoped_refptr<rtc::RTCCertificate>* certificate); + // Caller owns returned certificate + bool GetRemoteSSLCertificate(const std::string& transport_name, + rtc::SSLCertificate** cert); + bool SetLocalTransportDescription(const std::string& transport_name, + const TransportDescription& tdesc, + ContentAction action, + std::string* err); + bool SetRemoteTransportDescription(const std::string& transport_name, + const TransportDescription& tdesc, + ContentAction action, + std::string* err); + // Start gathering candidates for any new transports, or transports doing an + // ICE restart. + void MaybeStartGathering(); + bool AddRemoteCandidates(const std::string& transport_name, + const Candidates& candidates, + std::string* err); + bool ReadyForRemoteCandidates(const std::string& transport_name); + bool GetStats(const std::string& transport_name, TransportStats* stats); + + // Creates a channel if it doesn't exist. Otherwise, increments a reference + // count and returns an existing channel. + virtual TransportChannel* CreateTransportChannel_w( + const std::string& transport_name, + int component); + + // Decrements a channel's reference count, and destroys the channel if + // nothing is referencing it. + virtual void DestroyTransportChannel_w(const std::string& transport_name, + int component); + + // All of these signals are fired on the signalling thread. + + // If any transport failed => failed, + // Else if all completed => completed, + // Else if all connected => connected, + // Else => connecting + sigslot::signal1<IceConnectionState> SignalConnectionState; + + // Receiving if any transport is receiving + sigslot::signal1<bool> SignalReceiving; + + // If all transports done gathering => complete, + // Else if any are gathering => gathering, + // Else => new + sigslot::signal1<IceGatheringState> SignalGatheringState; + + // (transport_name, candidates) + sigslot::signal2<const std::string&, const Candidates&> + SignalCandidatesGathered; + + // for unit test + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate_for_testing(); + + protected: + // Protected and virtual so we can override it in unit tests. + virtual Transport* CreateTransport_w(const std::string& transport_name); + + // For unit tests + const std::map<std::string, Transport*>& transports() { return transports_; } + Transport* GetTransport_w(const std::string& transport_name); + + private: + void OnMessage(rtc::Message* pmsg) override; + + // It's the Transport that's currently responsible for creating/destroying + // channels, but the TransportController keeps track of how many external + // objects (BaseChannels) reference each channel. + struct RefCountedChannel { + RefCountedChannel() : impl_(nullptr), ref_(0) {} + explicit RefCountedChannel(TransportChannelImpl* impl) + : impl_(impl), ref_(0) {} + + void AddRef() { ++ref_; } + void DecRef() { + ASSERT(ref_ > 0); + --ref_; + } + int ref() const { return ref_; } + + TransportChannelImpl* get() const { return impl_; } + TransportChannelImpl* operator->() const { return impl_; } + + private: + TransportChannelImpl* impl_; + int ref_; + }; + + std::vector<RefCountedChannel>::iterator FindChannel_w( + const std::string& transport_name, + int component); + + Transport* GetOrCreateTransport_w(const std::string& transport_name); + void DestroyTransport_w(const std::string& transport_name); + void DestroyAllTransports_w(); + + bool SetSslMaxProtocolVersion_w(rtc::SSLProtocolVersion version); + void SetIceConfig_w(const IceConfig& config); + void SetIceRole_w(IceRole ice_role); + bool GetSslRole_w(rtc::SSLRole* role); + bool SetLocalCertificate_w( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate); + bool GetLocalCertificate_w( + const std::string& transport_name, + rtc::scoped_refptr<rtc::RTCCertificate>* certificate); + bool GetRemoteSSLCertificate_w(const std::string& transport_name, + rtc::SSLCertificate** cert); + bool SetLocalTransportDescription_w(const std::string& transport_name, + const TransportDescription& tdesc, + ContentAction action, + std::string* err); + bool SetRemoteTransportDescription_w(const std::string& transport_name, + const TransportDescription& tdesc, + ContentAction action, + std::string* err); + void MaybeStartGathering_w(); + bool AddRemoteCandidates_w(const std::string& transport_name, + const Candidates& candidates, + std::string* err); + bool ReadyForRemoteCandidates_w(const std::string& transport_name); + bool GetStats_w(const std::string& transport_name, TransportStats* stats); + + // Handlers for signals from Transport. + void OnChannelWritableState_w(TransportChannel* channel); + void OnChannelReceivingState_w(TransportChannel* channel); + void OnChannelGatheringState_w(TransportChannelImpl* channel); + void OnChannelCandidateGathered_w(TransportChannelImpl* channel, + const Candidate& candidate); + void OnChannelRoleConflict_w(TransportChannelImpl* channel); + void OnChannelConnectionRemoved_w(TransportChannelImpl* channel); + + void UpdateAggregateStates_w(); + + rtc::Thread* const signaling_thread_ = nullptr; + rtc::Thread* const worker_thread_ = nullptr; + typedef std::map<std::string, Transport*> TransportMap; + TransportMap transports_; + + std::vector<RefCountedChannel> channels_; + + PortAllocator* const port_allocator_ = nullptr; + rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_10; + + // Aggregate state for TransportChannelImpls. + IceConnectionState connection_state_ = kIceConnectionConnecting; + bool receiving_ = false; + IceGatheringState gathering_state_ = kIceGatheringNew; + + // TODO(deadbeef): Move the fields below down to the transports themselves + IceConfig ice_config_; + IceRole ice_role_ = ICEROLE_CONTROLLING; + // Flag which will be set to true after the first role switch + bool ice_role_switch_ = false; + uint64_t ice_tiebreaker_ = rtc::CreateRandomId64(); + rtc::scoped_refptr<rtc::RTCCertificate> certificate_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TRANSPORTCONTROLLER_H_ diff --git a/webrtc/p2p/base/transportcontroller_unittest.cc b/webrtc/p2p/base/transportcontroller_unittest.cc new file mode 100644 index 0000000000..23e4dc8067 --- /dev/null +++ b/webrtc/p2p/base/transportcontroller_unittest.cc @@ -0,0 +1,689 @@ +/* + * Copyright 2015 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <map> + +#include "webrtc/base/fakesslidentity.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/sslidentity.h" +#include "webrtc/base/thread.h" +#include "webrtc/p2p/base/dtlstransportchannel.h" +#include "webrtc/p2p/base/faketransportcontroller.h" +#include "webrtc/p2p/base/p2ptransportchannel.h" +#include "webrtc/p2p/base/portallocator.h" +#include "webrtc/p2p/base/transportcontroller.h" +#include "webrtc/p2p/client/fakeportallocator.h" + +static const int kTimeout = 100; +static const char kIceUfrag1[] = "TESTICEUFRAG0001"; +static const char kIcePwd1[] = "TESTICEPWD00000000000001"; +static const char kIceUfrag2[] = "TESTICEUFRAG0002"; +static const char kIcePwd2[] = "TESTICEPWD00000000000002"; + +using cricket::Candidate; +using cricket::Candidates; +using cricket::FakeTransportChannel; +using cricket::FakeTransportController; +using cricket::IceConnectionState; +using cricket::IceGatheringState; +using cricket::TransportChannel; +using cricket::TransportController; +using cricket::TransportDescription; +using cricket::TransportStats; + +// Only subclassing from FakeTransportController because currently that's the +// only way to have a TransportController with fake TransportChannels. +// +// TODO(deadbeef): Change this once the Transport/TransportChannel class +// heirarchy is cleaned up, and we can pass a "TransportChannelFactory" or +// something similar into TransportController. +typedef FakeTransportController TransportControllerForTest; + +class TransportControllerTest : public testing::Test, + public sigslot::has_slots<> { + public: + TransportControllerTest() + : transport_controller_(new TransportControllerForTest()), + signaling_thread_(rtc::Thread::Current()) { + ConnectTransportControllerSignals(); + } + + void CreateTransportControllerWithWorkerThread() { + if (!worker_thread_) { + worker_thread_.reset(new rtc::Thread()); + worker_thread_->Start(); + } + transport_controller_.reset( + new TransportControllerForTest(worker_thread_.get())); + ConnectTransportControllerSignals(); + } + + void ConnectTransportControllerSignals() { + transport_controller_->SignalConnectionState.connect( + this, &TransportControllerTest::OnConnectionState); + transport_controller_->SignalReceiving.connect( + this, &TransportControllerTest::OnReceiving); + transport_controller_->SignalGatheringState.connect( + this, &TransportControllerTest::OnGatheringState); + transport_controller_->SignalCandidatesGathered.connect( + this, &TransportControllerTest::OnCandidatesGathered); + } + + FakeTransportChannel* CreateChannel(const std::string& content, + int component) { + TransportChannel* channel = + transport_controller_->CreateTransportChannel_w(content, component); + return static_cast<FakeTransportChannel*>(channel); + } + + void DestroyChannel(const std::string& content, int component) { + transport_controller_->DestroyTransportChannel_w(content, component); + } + + Candidate CreateCandidate(int component) { + Candidate c; + c.set_address(rtc::SocketAddress("192.168.1.1", 8000)); + c.set_component(1); + c.set_protocol(cricket::UDP_PROTOCOL_NAME); + c.set_priority(1); + return c; + } + + // Used for thread hopping test. + void CreateChannelsAndCompleteConnectionOnWorkerThread() { + worker_thread_->Invoke<void>(rtc::Bind( + &TransportControllerTest::CreateChannelsAndCompleteConnection_w, this)); + } + + void CreateChannelsAndCompleteConnection_w() { + transport_controller_->SetIceRole(cricket::ICEROLE_CONTROLLING); + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + + TransportDescription local_desc( + std::vector<std::string>(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_FULL, + cricket::CONNECTIONROLE_ACTPASS, nullptr, Candidates()); + std::string err; + transport_controller_->SetLocalTransportDescription( + "audio", local_desc, cricket::CA_OFFER, &err); + transport_controller_->SetLocalTransportDescription( + "video", local_desc, cricket::CA_OFFER, &err); + transport_controller_->MaybeStartGathering(); + channel1->SignalCandidateGathered(channel1, CreateCandidate(1)); + channel2->SignalCandidateGathered(channel2, CreateCandidate(1)); + channel1->SetCandidatesGatheringComplete(); + channel2->SetCandidatesGatheringComplete(); + channel1->SetConnectionCount(2); + channel2->SetConnectionCount(2); + channel1->SetReceiving(true); + channel2->SetReceiving(true); + channel1->SetWritable(true); + channel2->SetWritable(true); + channel1->SetConnectionCount(1); + channel2->SetConnectionCount(1); + } + + cricket::IceConfig CreateIceConfig(int receiving_timeout_ms, + bool gather_continually) { + cricket::IceConfig config; + config.receiving_timeout_ms = receiving_timeout_ms; + config.gather_continually = gather_continually; + return config; + } + + protected: + void OnConnectionState(IceConnectionState state) { + if (!signaling_thread_->IsCurrent()) { + signaled_on_non_signaling_thread_ = true; + } + connection_state_ = state; + ++connection_state_signal_count_; + } + + void OnReceiving(bool receiving) { + if (!signaling_thread_->IsCurrent()) { + signaled_on_non_signaling_thread_ = true; + } + receiving_ = receiving; + ++receiving_signal_count_; + } + + void OnGatheringState(IceGatheringState state) { + if (!signaling_thread_->IsCurrent()) { + signaled_on_non_signaling_thread_ = true; + } + gathering_state_ = state; + ++gathering_state_signal_count_; + } + + void OnCandidatesGathered(const std::string& transport_name, + const Candidates& candidates) { + if (!signaling_thread_->IsCurrent()) { + signaled_on_non_signaling_thread_ = true; + } + candidates_[transport_name].insert(candidates_[transport_name].end(), + candidates.begin(), candidates.end()); + ++candidates_signal_count_; + } + + rtc::scoped_ptr<rtc::Thread> worker_thread_; // Not used for most tests. + rtc::scoped_ptr<TransportControllerForTest> transport_controller_; + + // Information received from signals from transport controller. + IceConnectionState connection_state_ = cricket::kIceConnectionConnecting; + bool receiving_ = false; + IceGatheringState gathering_state_ = cricket::kIceGatheringNew; + // transport_name => candidates + std::map<std::string, Candidates> candidates_; + // Counts of each signal emitted. + int connection_state_signal_count_ = 0; + int receiving_signal_count_ = 0; + int gathering_state_signal_count_ = 0; + int candidates_signal_count_ = 0; + + // Used to make sure signals only come on signaling thread. + rtc::Thread* const signaling_thread_ = nullptr; + bool signaled_on_non_signaling_thread_ = false; +}; + +TEST_F(TransportControllerTest, TestSetIceConfig) { + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + + transport_controller_->SetIceConfig(CreateIceConfig(1000, true)); + EXPECT_EQ(1000, channel1->receiving_timeout()); + EXPECT_TRUE(channel1->gather_continually()); + + // Test that value stored in controller is applied to new channels. + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + EXPECT_EQ(1000, channel2->receiving_timeout()); + EXPECT_TRUE(channel2->gather_continually()); +} + +TEST_F(TransportControllerTest, TestSetSslMaxProtocolVersion) { + EXPECT_TRUE(transport_controller_->SetSslMaxProtocolVersion( + rtc::SSL_PROTOCOL_DTLS_12)); + FakeTransportChannel* channel = CreateChannel("audio", 1); + + ASSERT_NE(nullptr, channel); + EXPECT_EQ(rtc::SSL_PROTOCOL_DTLS_12, channel->ssl_max_protocol_version()); + + // Setting max version after transport is created should fail. + EXPECT_FALSE(transport_controller_->SetSslMaxProtocolVersion( + rtc::SSL_PROTOCOL_DTLS_10)); +} + +TEST_F(TransportControllerTest, TestSetIceRole) { + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + + transport_controller_->SetIceRole(cricket::ICEROLE_CONTROLLING); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel1->GetIceRole()); + transport_controller_->SetIceRole(cricket::ICEROLE_CONTROLLED); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, channel1->GetIceRole()); + + // Test that value stored in controller is applied to new channels. + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, channel2->GetIceRole()); +} + +// Test that when one channel encounters a role conflict, the ICE role is +// swapped on every channel. +TEST_F(TransportControllerTest, TestIceRoleConflict) { + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + + transport_controller_->SetIceRole(cricket::ICEROLE_CONTROLLING); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel1->GetIceRole()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLING, channel2->GetIceRole()); + + channel1->SignalRoleConflict(channel1); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, channel1->GetIceRole()); + EXPECT_EQ(cricket::ICEROLE_CONTROLLED, channel2->GetIceRole()); +} + +TEST_F(TransportControllerTest, TestGetSslRole) { + FakeTransportChannel* channel = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel); + ASSERT_TRUE(channel->SetSslRole(rtc::SSL_CLIENT)); + rtc::SSLRole role; + EXPECT_TRUE(transport_controller_->GetSslRole(&role)); + EXPECT_EQ(rtc::SSL_CLIENT, role); +} + +TEST_F(TransportControllerTest, TestSetAndGetLocalCertificate) { + rtc::scoped_refptr<rtc::RTCCertificate> certificate1 = + rtc::RTCCertificate::Create( + rtc::scoped_ptr<rtc::SSLIdentity>( + rtc::SSLIdentity::Generate("session1", rtc::KT_DEFAULT)) + .Pass()); + rtc::scoped_refptr<rtc::RTCCertificate> certificate2 = + rtc::RTCCertificate::Create( + rtc::scoped_ptr<rtc::SSLIdentity>( + rtc::SSLIdentity::Generate("session2", rtc::KT_DEFAULT)) + .Pass()); + rtc::scoped_refptr<rtc::RTCCertificate> returned_certificate; + + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + + EXPECT_TRUE(transport_controller_->SetLocalCertificate(certificate1)); + EXPECT_TRUE(transport_controller_->GetLocalCertificate( + "audio", &returned_certificate)); + EXPECT_EQ(certificate1->identity()->certificate().ToPEMString(), + returned_certificate->identity()->certificate().ToPEMString()); + + // Should fail if called for a nonexistant transport. + EXPECT_FALSE(transport_controller_->GetLocalCertificate( + "video", &returned_certificate)); + + // Test that identity stored in controller is applied to new channels. + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + EXPECT_TRUE(transport_controller_->GetLocalCertificate( + "video", &returned_certificate)); + EXPECT_EQ(certificate1->identity()->certificate().ToPEMString(), + returned_certificate->identity()->certificate().ToPEMString()); + + // Shouldn't be able to change the identity once set. + EXPECT_FALSE(transport_controller_->SetLocalCertificate(certificate2)); +} + +TEST_F(TransportControllerTest, TestGetRemoteSSLCertificate) { + rtc::FakeSSLCertificate fake_certificate("fake_data"); + rtc::scoped_ptr<rtc::SSLCertificate> returned_certificate; + + FakeTransportChannel* channel = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel); + + channel->SetRemoteSSLCertificate(&fake_certificate); + EXPECT_TRUE(transport_controller_->GetRemoteSSLCertificate( + "audio", returned_certificate.accept())); + EXPECT_EQ(fake_certificate.ToPEMString(), + returned_certificate->ToPEMString()); + + // Should fail if called for a nonexistant transport. + EXPECT_FALSE(transport_controller_->GetRemoteSSLCertificate( + "video", returned_certificate.accept())); +} + +TEST_F(TransportControllerTest, TestSetLocalTransportDescription) { + FakeTransportChannel* channel = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel); + TransportDescription local_desc( + std::vector<std::string>(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_FULL, + cricket::CONNECTIONROLE_ACTPASS, nullptr, Candidates()); + std::string err; + EXPECT_TRUE(transport_controller_->SetLocalTransportDescription( + "audio", local_desc, cricket::CA_OFFER, &err)); + // Check that ICE ufrag and pwd were propagated to channel. + EXPECT_EQ(kIceUfrag1, channel->ice_ufrag()); + EXPECT_EQ(kIcePwd1, channel->ice_pwd()); + // After setting local description, we should be able to start gathering + // candidates. + transport_controller_->MaybeStartGathering(); + EXPECT_EQ_WAIT(cricket::kIceGatheringGathering, gathering_state_, kTimeout); + EXPECT_EQ(1, gathering_state_signal_count_); +} + +TEST_F(TransportControllerTest, TestSetRemoteTransportDescription) { + FakeTransportChannel* channel = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel); + TransportDescription remote_desc( + std::vector<std::string>(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_FULL, + cricket::CONNECTIONROLE_ACTPASS, nullptr, Candidates()); + std::string err; + EXPECT_TRUE(transport_controller_->SetRemoteTransportDescription( + "audio", remote_desc, cricket::CA_OFFER, &err)); + // Check that ICE ufrag and pwd were propagated to channel. + EXPECT_EQ(kIceUfrag1, channel->remote_ice_ufrag()); + EXPECT_EQ(kIcePwd1, channel->remote_ice_pwd()); +} + +TEST_F(TransportControllerTest, TestAddRemoteCandidates) { + FakeTransportChannel* channel = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel); + Candidates candidates; + candidates.push_back(CreateCandidate(1)); + std::string err; + EXPECT_TRUE( + transport_controller_->AddRemoteCandidates("audio", candidates, &err)); + EXPECT_EQ(1U, channel->remote_candidates().size()); +} + +TEST_F(TransportControllerTest, TestReadyForRemoteCandidates) { + FakeTransportChannel* channel = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel); + // We expect to be ready for remote candidates only after local and remote + // descriptions are set. + EXPECT_FALSE(transport_controller_->ReadyForRemoteCandidates("audio")); + + std::string err; + TransportDescription remote_desc( + std::vector<std::string>(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_FULL, + cricket::CONNECTIONROLE_ACTPASS, nullptr, Candidates()); + EXPECT_TRUE(transport_controller_->SetRemoteTransportDescription( + "audio", remote_desc, cricket::CA_OFFER, &err)); + EXPECT_FALSE(transport_controller_->ReadyForRemoteCandidates("audio")); + + TransportDescription local_desc( + std::vector<std::string>(), kIceUfrag2, kIcePwd2, cricket::ICEMODE_FULL, + cricket::CONNECTIONROLE_ACTPASS, nullptr, Candidates()); + EXPECT_TRUE(transport_controller_->SetLocalTransportDescription( + "audio", local_desc, cricket::CA_ANSWER, &err)); + EXPECT_TRUE(transport_controller_->ReadyForRemoteCandidates("audio")); +} + +TEST_F(TransportControllerTest, TestGetStats) { + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("audio", 2); + ASSERT_NE(nullptr, channel2); + FakeTransportChannel* channel3 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel3); + + TransportStats stats; + EXPECT_TRUE(transport_controller_->GetStats("audio", &stats)); + EXPECT_EQ("audio", stats.transport_name); + EXPECT_EQ(2U, stats.channel_stats.size()); +} + +// Test that transport gets destroyed when it has no more channels. +TEST_F(TransportControllerTest, TestCreateAndDestroyChannel) { + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel2); + ASSERT_EQ(channel1, channel2); + FakeTransportChannel* channel3 = CreateChannel("audio", 2); + ASSERT_NE(nullptr, channel3); + + // Using GetStats to check if transport is destroyed from an outside class's + // perspective. + TransportStats stats; + EXPECT_TRUE(transport_controller_->GetStats("audio", &stats)); + DestroyChannel("audio", 2); + DestroyChannel("audio", 1); + EXPECT_TRUE(transport_controller_->GetStats("audio", &stats)); + DestroyChannel("audio", 1); + EXPECT_FALSE(transport_controller_->GetStats("audio", &stats)); +} + +TEST_F(TransportControllerTest, TestSignalConnectionStateFailed) { + // Need controlling ICE role to get in failed state. + transport_controller_->SetIceRole(cricket::ICEROLE_CONTROLLING); + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + + // Should signal "failed" if any channel failed; channel is considered failed + // if it previously had a connection but now has none, and gathering is + // complete. + channel1->SetCandidatesGatheringComplete(); + channel1->SetConnectionCount(1); + channel1->SetConnectionCount(0); + EXPECT_EQ_WAIT(cricket::kIceConnectionFailed, connection_state_, kTimeout); + EXPECT_EQ(1, connection_state_signal_count_); +} + +TEST_F(TransportControllerTest, TestSignalConnectionStateConnected) { + transport_controller_->SetIceRole(cricket::ICEROLE_CONTROLLING); + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + FakeTransportChannel* channel3 = CreateChannel("video", 2); + ASSERT_NE(nullptr, channel3); + + // First, have one channel connect, and another fail, to ensure that + // the first channel connecting didn't trigger a "connected" state signal. + // We should only get a signal when all are connected. + channel1->SetConnectionCount(2); + channel1->SetWritable(true); + channel3->SetCandidatesGatheringComplete(); + channel3->SetConnectionCount(1); + channel3->SetConnectionCount(0); + EXPECT_EQ_WAIT(cricket::kIceConnectionFailed, connection_state_, kTimeout); + // Signal count of 1 means that the only signal emitted was "failed". + EXPECT_EQ(1, connection_state_signal_count_); + + // Destroy the failed channel to return to "connecting" state. + DestroyChannel("video", 2); + EXPECT_EQ_WAIT(cricket::kIceConnectionConnecting, connection_state_, + kTimeout); + EXPECT_EQ(2, connection_state_signal_count_); + + // Make the remaining channel reach a connected state. + channel2->SetConnectionCount(2); + channel2->SetWritable(true); + EXPECT_EQ_WAIT(cricket::kIceConnectionConnected, connection_state_, kTimeout); + EXPECT_EQ(3, connection_state_signal_count_); +} + +TEST_F(TransportControllerTest, TestSignalConnectionStateComplete) { + transport_controller_->SetIceRole(cricket::ICEROLE_CONTROLLING); + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + FakeTransportChannel* channel3 = CreateChannel("video", 2); + ASSERT_NE(nullptr, channel3); + + // Similar to above test, but we're now reaching the completed state, which + // means only one connection per FakeTransportChannel. + channel1->SetCandidatesGatheringComplete(); + channel1->SetConnectionCount(1); + channel1->SetWritable(true); + channel3->SetCandidatesGatheringComplete(); + channel3->SetConnectionCount(1); + channel3->SetConnectionCount(0); + EXPECT_EQ_WAIT(cricket::kIceConnectionFailed, connection_state_, kTimeout); + // Signal count of 1 means that the only signal emitted was "failed". + EXPECT_EQ(1, connection_state_signal_count_); + + // Destroy the failed channel to return to "connecting" state. + DestroyChannel("video", 2); + EXPECT_EQ_WAIT(cricket::kIceConnectionConnecting, connection_state_, + kTimeout); + EXPECT_EQ(2, connection_state_signal_count_); + + // Make the remaining channel reach a connected state. + channel2->SetCandidatesGatheringComplete(); + channel2->SetConnectionCount(2); + channel2->SetWritable(true); + EXPECT_EQ_WAIT(cricket::kIceConnectionConnected, connection_state_, kTimeout); + EXPECT_EQ(3, connection_state_signal_count_); + + // Finally, transition to completed state. + channel2->SetConnectionCount(1); + EXPECT_EQ_WAIT(cricket::kIceConnectionCompleted, connection_state_, kTimeout); + EXPECT_EQ(4, connection_state_signal_count_); +} + +// Make sure that if we're "connected" and remove a transport, we stay in the +// "connected" state. +TEST_F(TransportControllerTest, TestDestroyTransportAndStayConnected) { + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + + channel1->SetCandidatesGatheringComplete(); + channel1->SetConnectionCount(2); + channel1->SetWritable(true); + channel2->SetCandidatesGatheringComplete(); + channel2->SetConnectionCount(2); + channel2->SetWritable(true); + EXPECT_EQ_WAIT(cricket::kIceConnectionConnected, connection_state_, kTimeout); + EXPECT_EQ(1, connection_state_signal_count_); + + // Destroy one channel, then "complete" the other one, so we reach + // a known state. + DestroyChannel("video", 1); + channel1->SetConnectionCount(1); + EXPECT_EQ_WAIT(cricket::kIceConnectionCompleted, connection_state_, kTimeout); + // Signal count of 2 means the deletion didn't cause any unexpected signals + EXPECT_EQ(2, connection_state_signal_count_); +} + +// If we destroy the last/only transport, we should simply transition to +// "connecting". +TEST_F(TransportControllerTest, TestDestroyLastTransportWhileConnected) { + FakeTransportChannel* channel = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel); + + channel->SetCandidatesGatheringComplete(); + channel->SetConnectionCount(2); + channel->SetWritable(true); + EXPECT_EQ_WAIT(cricket::kIceConnectionConnected, connection_state_, kTimeout); + EXPECT_EQ(1, connection_state_signal_count_); + + DestroyChannel("audio", 1); + EXPECT_EQ_WAIT(cricket::kIceConnectionConnecting, connection_state_, + kTimeout); + // Signal count of 2 means the deletion didn't cause any unexpected signals + EXPECT_EQ(2, connection_state_signal_count_); +} + +TEST_F(TransportControllerTest, TestSignalReceiving) { + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + + // Should signal receiving as soon as any channel is receiving. + channel1->SetReceiving(true); + EXPECT_TRUE_WAIT(receiving_, kTimeout); + EXPECT_EQ(1, receiving_signal_count_); + + channel2->SetReceiving(true); + channel1->SetReceiving(false); + channel2->SetReceiving(false); + EXPECT_TRUE_WAIT(!receiving_, kTimeout); + EXPECT_EQ(2, receiving_signal_count_); +} + +TEST_F(TransportControllerTest, TestSignalGatheringStateGathering) { + FakeTransportChannel* channel = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel); + channel->Connect(); + channel->MaybeStartGathering(); + // Should be in the gathering state as soon as any transport starts gathering. + EXPECT_EQ_WAIT(cricket::kIceGatheringGathering, gathering_state_, kTimeout); + EXPECT_EQ(1, gathering_state_signal_count_); +} + +TEST_F(TransportControllerTest, TestSignalGatheringStateComplete) { + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + FakeTransportChannel* channel3 = CreateChannel("data", 1); + ASSERT_NE(nullptr, channel3); + + channel3->Connect(); + channel3->MaybeStartGathering(); + EXPECT_EQ_WAIT(cricket::kIceGatheringGathering, gathering_state_, kTimeout); + EXPECT_EQ(1, gathering_state_signal_count_); + + // Have one channel finish gathering, then destroy it, to make sure gathering + // completion wasn't signalled if only one transport finished gathering. + channel3->SetCandidatesGatheringComplete(); + DestroyChannel("data", 1); + EXPECT_EQ_WAIT(cricket::kIceGatheringNew, gathering_state_, kTimeout); + EXPECT_EQ(2, gathering_state_signal_count_); + + // Make remaining channels start and then finish gathering. + channel1->Connect(); + channel1->MaybeStartGathering(); + channel2->Connect(); + channel2->MaybeStartGathering(); + EXPECT_EQ_WAIT(cricket::kIceGatheringGathering, gathering_state_, kTimeout); + EXPECT_EQ(3, gathering_state_signal_count_); + + channel1->SetCandidatesGatheringComplete(); + channel2->SetCandidatesGatheringComplete(); + EXPECT_EQ_WAIT(cricket::kIceGatheringComplete, gathering_state_, kTimeout); + EXPECT_EQ(4, gathering_state_signal_count_); +} + +// Test that when the last transport that hasn't finished connecting and/or +// gathering is destroyed, the aggregate state jumps to "completed". This can +// happen if, for example, we have an audio and video transport, the audio +// transport completes, then we start bundling video on the audio transport. +TEST_F(TransportControllerTest, + TestSignalingWhenLastIncompleteTransportDestroyed) { + transport_controller_->SetIceRole(cricket::ICEROLE_CONTROLLING); + FakeTransportChannel* channel1 = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel1); + FakeTransportChannel* channel2 = CreateChannel("video", 1); + ASSERT_NE(nullptr, channel2); + + channel1->SetCandidatesGatheringComplete(); + EXPECT_EQ_WAIT(cricket::kIceGatheringGathering, gathering_state_, kTimeout); + EXPECT_EQ(1, gathering_state_signal_count_); + + channel1->SetConnectionCount(1); + channel1->SetWritable(true); + DestroyChannel("video", 1); + EXPECT_EQ_WAIT(cricket::kIceConnectionCompleted, connection_state_, kTimeout); + EXPECT_EQ(1, connection_state_signal_count_); + EXPECT_EQ_WAIT(cricket::kIceGatheringComplete, gathering_state_, kTimeout); + EXPECT_EQ(2, gathering_state_signal_count_); +} + +TEST_F(TransportControllerTest, TestSignalCandidatesGathered) { + FakeTransportChannel* channel = CreateChannel("audio", 1); + ASSERT_NE(nullptr, channel); + + // Transport won't signal candidates until it has a local description. + TransportDescription local_desc( + std::vector<std::string>(), kIceUfrag1, kIcePwd1, cricket::ICEMODE_FULL, + cricket::CONNECTIONROLE_ACTPASS, nullptr, Candidates()); + std::string err; + EXPECT_TRUE(transport_controller_->SetLocalTransportDescription( + "audio", local_desc, cricket::CA_OFFER, &err)); + transport_controller_->MaybeStartGathering(); + + channel->SignalCandidateGathered(channel, CreateCandidate(1)); + EXPECT_EQ_WAIT(1, candidates_signal_count_, kTimeout); + EXPECT_EQ(1U, candidates_["audio"].size()); +} + +TEST_F(TransportControllerTest, TestSignalingOccursOnSignalingThread) { + CreateTransportControllerWithWorkerThread(); + CreateChannelsAndCompleteConnectionOnWorkerThread(); + + // connecting --> connected --> completed + EXPECT_EQ_WAIT(cricket::kIceConnectionCompleted, connection_state_, kTimeout); + EXPECT_EQ(2, connection_state_signal_count_); + + EXPECT_TRUE_WAIT(receiving_, kTimeout); + EXPECT_EQ(1, receiving_signal_count_); + + // new --> gathering --> complete + EXPECT_EQ_WAIT(cricket::kIceGatheringComplete, gathering_state_, kTimeout); + EXPECT_EQ(2, gathering_state_signal_count_); + + EXPECT_EQ_WAIT(1U, candidates_["audio"].size(), kTimeout); + EXPECT_EQ_WAIT(1U, candidates_["video"].size(), kTimeout); + EXPECT_EQ(2, candidates_signal_count_); + + EXPECT_TRUE(!signaled_on_non_signaling_thread_); +} diff --git a/webrtc/p2p/base/transportdescription.cc b/webrtc/p2p/base/transportdescription.cc new file mode 100644 index 0000000000..52033ec9c3 --- /dev/null +++ b/webrtc/p2p/base/transportdescription.cc @@ -0,0 +1,56 @@ +/* + * Copyright 2013 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/transportdescription.h" + +#include "webrtc/base/basicdefs.h" +#include "webrtc/base/stringutils.h" +#include "webrtc/p2p/base/constants.h" + +namespace cricket { + +bool StringToConnectionRole(const std::string& role_str, ConnectionRole* role) { + const char* const roles[] = { + CONNECTIONROLE_ACTIVE_STR, + CONNECTIONROLE_PASSIVE_STR, + CONNECTIONROLE_ACTPASS_STR, + CONNECTIONROLE_HOLDCONN_STR + }; + + for (size_t i = 0; i < ARRAY_SIZE(roles); ++i) { + if (_stricmp(roles[i], role_str.c_str()) == 0) { + *role = static_cast<ConnectionRole>(CONNECTIONROLE_ACTIVE + i); + return true; + } + } + return false; +} + +bool ConnectionRoleToString(const ConnectionRole& role, std::string* role_str) { + switch (role) { + case cricket::CONNECTIONROLE_ACTIVE: + *role_str = cricket::CONNECTIONROLE_ACTIVE_STR; + break; + case cricket::CONNECTIONROLE_ACTPASS: + *role_str = cricket::CONNECTIONROLE_ACTPASS_STR; + break; + case cricket::CONNECTIONROLE_PASSIVE: + *role_str = cricket::CONNECTIONROLE_PASSIVE_STR; + break; + case cricket::CONNECTIONROLE_HOLDCONN: + *role_str = cricket::CONNECTIONROLE_HOLDCONN_STR; + break; + default: + return false; + } + return true; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/transportdescription.h b/webrtc/p2p/base/transportdescription.h new file mode 100644 index 0000000000..8ea1f4bc2e --- /dev/null +++ b/webrtc/p2p/base/transportdescription.h @@ -0,0 +1,154 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TRANSPORTDESCRIPTION_H_ +#define WEBRTC_P2P_BASE_TRANSPORTDESCRIPTION_H_ + +#include <algorithm> +#include <string> +#include <vector> + +#include "webrtc/p2p/base/candidate.h" +#include "webrtc/p2p/base/constants.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/sslfingerprint.h" + +namespace cricket { + +// SEC_ENABLED and SEC_REQUIRED should only be used if the session +// was negotiated over TLS, to protect the inline crypto material +// exchange. +// SEC_DISABLED: No crypto in outgoing offer, ignore any supplied crypto. +// SEC_ENABLED: Crypto in outgoing offer and answer (if supplied in offer). +// SEC_REQUIRED: Crypto in outgoing offer and answer. Fail any offer with absent +// or unsupported crypto. +enum SecurePolicy { + SEC_DISABLED, + SEC_ENABLED, + SEC_REQUIRED +}; + +// Whether our side of the call is driving the negotiation, or the other side. +enum IceRole { + ICEROLE_CONTROLLING = 0, + ICEROLE_CONTROLLED, + ICEROLE_UNKNOWN +}; + +// ICE RFC 5245 implementation type. +enum IceMode { + ICEMODE_FULL, // As defined in http://tools.ietf.org/html/rfc5245#section-4.1 + ICEMODE_LITE // As defined in http://tools.ietf.org/html/rfc5245#section-4.2 +}; + +// RFC 4145 - http://tools.ietf.org/html/rfc4145#section-4 +// 'active': The endpoint will initiate an outgoing connection. +// 'passive': The endpoint will accept an incoming connection. +// 'actpass': The endpoint is willing to accept an incoming +// connection or to initiate an outgoing connection. +enum ConnectionRole { + CONNECTIONROLE_NONE = 0, + CONNECTIONROLE_ACTIVE, + CONNECTIONROLE_PASSIVE, + CONNECTIONROLE_ACTPASS, + CONNECTIONROLE_HOLDCONN, +}; + +extern const char CONNECTIONROLE_ACTIVE_STR[]; +extern const char CONNECTIONROLE_PASSIVE_STR[]; +extern const char CONNECTIONROLE_ACTPASS_STR[]; +extern const char CONNECTIONROLE_HOLDCONN_STR[]; + +bool StringToConnectionRole(const std::string& role_str, ConnectionRole* role); +bool ConnectionRoleToString(const ConnectionRole& role, std::string* role_str); + +typedef std::vector<Candidate> Candidates; + +struct TransportDescription { + TransportDescription() + : ice_mode(ICEMODE_FULL), + connection_role(CONNECTIONROLE_NONE) {} + + TransportDescription(const std::vector<std::string>& transport_options, + const std::string& ice_ufrag, + const std::string& ice_pwd, + IceMode ice_mode, + ConnectionRole role, + const rtc::SSLFingerprint* identity_fingerprint, + const Candidates& candidates) + : transport_options(transport_options), + ice_ufrag(ice_ufrag), + ice_pwd(ice_pwd), + ice_mode(ice_mode), + connection_role(role), + identity_fingerprint(CopyFingerprint(identity_fingerprint)), + candidates(candidates) {} + TransportDescription(const std::string& ice_ufrag, + const std::string& ice_pwd) + : ice_ufrag(ice_ufrag), + ice_pwd(ice_pwd), + ice_mode(ICEMODE_FULL), + connection_role(CONNECTIONROLE_NONE) {} + TransportDescription(const TransportDescription& from) + : transport_options(from.transport_options), + ice_ufrag(from.ice_ufrag), + ice_pwd(from.ice_pwd), + ice_mode(from.ice_mode), + connection_role(from.connection_role), + identity_fingerprint(CopyFingerprint(from.identity_fingerprint.get())), + candidates(from.candidates) {} + + TransportDescription& operator=(const TransportDescription& from) { + // Self-assignment + if (this == &from) + return *this; + + transport_options = from.transport_options; + ice_ufrag = from.ice_ufrag; + ice_pwd = from.ice_pwd; + ice_mode = from.ice_mode; + connection_role = from.connection_role; + + identity_fingerprint.reset(CopyFingerprint( + from.identity_fingerprint.get())); + candidates = from.candidates; + return *this; + } + + bool HasOption(const std::string& option) const { + return (std::find(transport_options.begin(), transport_options.end(), + option) != transport_options.end()); + } + void AddOption(const std::string& option) { + transport_options.push_back(option); + } + bool secure() const { return identity_fingerprint != NULL; } + + static rtc::SSLFingerprint* CopyFingerprint( + const rtc::SSLFingerprint* from) { + if (!from) + return NULL; + + return new rtc::SSLFingerprint(*from); + } + + std::vector<std::string> transport_options; + std::string ice_ufrag; + std::string ice_pwd; + IceMode ice_mode; + ConnectionRole connection_role; + + rtc::scoped_ptr<rtc::SSLFingerprint> identity_fingerprint; + Candidates candidates; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TRANSPORTDESCRIPTION_H_ diff --git a/webrtc/p2p/base/transportdescriptionfactory.cc b/webrtc/p2p/base/transportdescriptionfactory.cc new file mode 100644 index 0000000000..1ddf55d4a1 --- /dev/null +++ b/webrtc/p2p/base/transportdescriptionfactory.cc @@ -0,0 +1,126 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/transportdescriptionfactory.h" + +#include "webrtc/p2p/base/transportdescription.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/messagedigest.h" +#include "webrtc/base/sslfingerprint.h" + +namespace cricket { + +TransportDescriptionFactory::TransportDescriptionFactory() + : secure_(SEC_DISABLED) { +} + +TransportDescription* TransportDescriptionFactory::CreateOffer( + const TransportOptions& options, + const TransportDescription* current_description) const { + rtc::scoped_ptr<TransportDescription> desc(new TransportDescription()); + + // Generate the ICE credentials if we don't already have them. + if (!current_description || options.ice_restart) { + desc->ice_ufrag = rtc::CreateRandomString(ICE_UFRAG_LENGTH); + desc->ice_pwd = rtc::CreateRandomString(ICE_PWD_LENGTH); + } else { + desc->ice_ufrag = current_description->ice_ufrag; + desc->ice_pwd = current_description->ice_pwd; + } + + // If we are trying to establish a secure transport, add a fingerprint. + if (secure_ == SEC_ENABLED || secure_ == SEC_REQUIRED) { + // Fail if we can't create the fingerprint. + // If we are the initiator set role to "actpass". + if (!SetSecurityInfo(desc.get(), CONNECTIONROLE_ACTPASS)) { + return NULL; + } + } + + return desc.release(); +} + +TransportDescription* TransportDescriptionFactory::CreateAnswer( + const TransportDescription* offer, + const TransportOptions& options, + const TransportDescription* current_description) const { + // TODO(juberti): Figure out why we get NULL offers, and fix this upstream. + if (!offer) { + LOG(LS_WARNING) << "Failed to create TransportDescription answer " << + "because offer is NULL"; + return NULL; + } + + rtc::scoped_ptr<TransportDescription> desc(new TransportDescription()); + // Generate the ICE credentials if we don't already have them or ice is + // being restarted. + if (!current_description || options.ice_restart) { + desc->ice_ufrag = rtc::CreateRandomString(ICE_UFRAG_LENGTH); + desc->ice_pwd = rtc::CreateRandomString(ICE_PWD_LENGTH); + } else { + desc->ice_ufrag = current_description->ice_ufrag; + desc->ice_pwd = current_description->ice_pwd; + } + + // Negotiate security params. + if (offer && offer->identity_fingerprint.get()) { + // The offer supports DTLS, so answer with DTLS, as long as we support it. + if (secure_ == SEC_ENABLED || secure_ == SEC_REQUIRED) { + // Fail if we can't create the fingerprint. + // Setting DTLS role to active. + ConnectionRole role = (options.prefer_passive_role) ? + CONNECTIONROLE_PASSIVE : CONNECTIONROLE_ACTIVE; + + if (!SetSecurityInfo(desc.get(), role)) { + return NULL; + } + } + } else if (secure_ == SEC_REQUIRED) { + // We require DTLS, but the other side didn't offer it. Fail. + LOG(LS_WARNING) << "Failed to create TransportDescription answer " + "because of incompatible security settings"; + return NULL; + } + + return desc.release(); +} + +bool TransportDescriptionFactory::SetSecurityInfo( + TransportDescription* desc, ConnectionRole role) const { + if (!certificate_) { + LOG(LS_ERROR) << "Cannot create identity digest with no certificate"; + return false; + } + + // This digest algorithm is used to produce the a=fingerprint lines in SDP. + // RFC 4572 Section 5 requires that those lines use the same hash function as + // the certificate's signature. + std::string digest_alg; + if (!certificate_->ssl_certificate().GetSignatureDigestAlgorithm( + &digest_alg)) { + LOG(LS_ERROR) << "Failed to retrieve the certificate's digest algorithm"; + return false; + } + + desc->identity_fingerprint.reset( + rtc::SSLFingerprint::Create(digest_alg, certificate_->identity())); + if (!desc->identity_fingerprint.get()) { + LOG(LS_ERROR) << "Failed to create identity fingerprint, alg=" + << digest_alg; + return false; + } + + // Assign security role. + desc->connection_role = role; + return true; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/transportdescriptionfactory.h b/webrtc/p2p/base/transportdescriptionfactory.h new file mode 100644 index 0000000000..828aa6d22c --- /dev/null +++ b/webrtc/p2p/base/transportdescriptionfactory.h @@ -0,0 +1,69 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TRANSPORTDESCRIPTIONFACTORY_H_ +#define WEBRTC_P2P_BASE_TRANSPORTDESCRIPTIONFACTORY_H_ + +#include "webrtc/base/rtccertificate.h" +#include "webrtc/p2p/base/transportdescription.h" + +namespace rtc { +class SSLIdentity; +} + +namespace cricket { + +struct TransportOptions { + TransportOptions() : ice_restart(false), prefer_passive_role(false) {} + bool ice_restart; + bool prefer_passive_role; +}; + +// Creates transport descriptions according to the supplied configuration. +// When creating answers, performs the appropriate negotiation +// of the various fields to determine the proper result. +class TransportDescriptionFactory { + public: + // Default ctor; use methods below to set configuration. + TransportDescriptionFactory(); + SecurePolicy secure() const { return secure_; } + // The certificate to use when setting up DTLS. + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate() const { + return certificate_; + } + + // Specifies the transport security policy to use. + void set_secure(SecurePolicy s) { secure_ = s; } + // Specifies the certificate to use (only used when secure != SEC_DISABLED). + void set_certificate( + const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) { + certificate_ = certificate; + } + + // Creates a transport description suitable for use in an offer. + TransportDescription* CreateOffer(const TransportOptions& options, + const TransportDescription* current_description) const; + // Create a transport description that is a response to an offer. + TransportDescription* CreateAnswer( + const TransportDescription* offer, + const TransportOptions& options, + const TransportDescription* current_description) const; + + private: + bool SetSecurityInfo(TransportDescription* description, + ConnectionRole role) const; + + SecurePolicy secure_; + rtc::scoped_refptr<rtc::RTCCertificate> certificate_; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TRANSPORTDESCRIPTIONFACTORY_H_ diff --git a/webrtc/p2p/base/transportdescriptionfactory_unittest.cc b/webrtc/p2p/base/transportdescriptionfactory_unittest.cc new file mode 100644 index 0000000000..e3992dfdd3 --- /dev/null +++ b/webrtc/p2p/base/transportdescriptionfactory_unittest.cc @@ -0,0 +1,262 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <string> +#include <vector> + +#include "webrtc/p2p/base/constants.h" +#include "webrtc/p2p/base/transportdescription.h" +#include "webrtc/p2p/base/transportdescriptionfactory.h" +#include "webrtc/base/fakesslidentity.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/ssladapter.h" + +using rtc::scoped_ptr; +using cricket::TransportDescriptionFactory; +using cricket::TransportDescription; +using cricket::TransportOptions; + +class TransportDescriptionFactoryTest : public testing::Test { + public: + TransportDescriptionFactoryTest() + : cert1_(rtc::RTCCertificate::Create(scoped_ptr<rtc::SSLIdentity>( + new rtc::FakeSSLIdentity("User1")).Pass())), + cert2_(rtc::RTCCertificate::Create(scoped_ptr<rtc::SSLIdentity>( + new rtc::FakeSSLIdentity("User2")).Pass())) { + } + + void CheckDesc(const TransportDescription* desc, + const std::string& opt, const std::string& ice_ufrag, + const std::string& ice_pwd, const std::string& dtls_alg) { + ASSERT_TRUE(desc != NULL); + EXPECT_EQ(!opt.empty(), desc->HasOption(opt)); + if (ice_ufrag.empty() && ice_pwd.empty()) { + EXPECT_EQ(static_cast<size_t>(cricket::ICE_UFRAG_LENGTH), + desc->ice_ufrag.size()); + EXPECT_EQ(static_cast<size_t>(cricket::ICE_PWD_LENGTH), + desc->ice_pwd.size()); + } else { + EXPECT_EQ(ice_ufrag, desc->ice_ufrag); + EXPECT_EQ(ice_pwd, desc->ice_pwd); + } + if (dtls_alg.empty()) { + EXPECT_TRUE(desc->identity_fingerprint.get() == NULL); + } else { + ASSERT_TRUE(desc->identity_fingerprint.get() != NULL); + EXPECT_EQ(desc->identity_fingerprint->algorithm, dtls_alg); + EXPECT_GT(desc->identity_fingerprint->digest.size(), 0U); + } + } + + // This test ice restart by doing two offer answer exchanges. On the second + // exchange ice is restarted. The test verifies that the ufrag and password + // in the offer and answer is changed. + // If |dtls| is true, the test verifies that the finger print is not changed. + void TestIceRestart(bool dtls) { + if (dtls) { + f1_.set_secure(cricket::SEC_ENABLED); + f2_.set_secure(cricket::SEC_ENABLED); + f1_.set_certificate(cert1_); + f2_.set_certificate(cert2_); + } else { + f1_.set_secure(cricket::SEC_DISABLED); + f2_.set_secure(cricket::SEC_DISABLED); + } + + cricket::TransportOptions options; + // The initial offer / answer exchange. + rtc::scoped_ptr<TransportDescription> offer(f1_.CreateOffer( + options, NULL)); + rtc::scoped_ptr<TransportDescription> answer( + f2_.CreateAnswer(offer.get(), + options, NULL)); + + // Create an updated offer where we restart ice. + options.ice_restart = true; + rtc::scoped_ptr<TransportDescription> restart_offer(f1_.CreateOffer( + options, offer.get())); + + VerifyUfragAndPasswordChanged(dtls, offer.get(), restart_offer.get()); + + // Create a new answer. The transport ufrag and password is changed since + // |options.ice_restart == true| + rtc::scoped_ptr<TransportDescription> restart_answer( + f2_.CreateAnswer(restart_offer.get(), options, answer.get())); + ASSERT_TRUE(restart_answer.get() != NULL); + + VerifyUfragAndPasswordChanged(dtls, answer.get(), restart_answer.get()); + } + + void VerifyUfragAndPasswordChanged(bool dtls, + const TransportDescription* org_desc, + const TransportDescription* restart_desc) { + EXPECT_NE(org_desc->ice_pwd, restart_desc->ice_pwd); + EXPECT_NE(org_desc->ice_ufrag, restart_desc->ice_ufrag); + EXPECT_EQ(static_cast<size_t>(cricket::ICE_UFRAG_LENGTH), + restart_desc->ice_ufrag.size()); + EXPECT_EQ(static_cast<size_t>(cricket::ICE_PWD_LENGTH), + restart_desc->ice_pwd.size()); + // If DTLS is enabled, make sure the finger print is unchanged. + if (dtls) { + EXPECT_FALSE( + org_desc->identity_fingerprint->GetRfc4572Fingerprint().empty()); + EXPECT_EQ(org_desc->identity_fingerprint->GetRfc4572Fingerprint(), + restart_desc->identity_fingerprint->GetRfc4572Fingerprint()); + } + } + + protected: + TransportDescriptionFactory f1_; + TransportDescriptionFactory f2_; + + rtc::scoped_refptr<rtc::RTCCertificate> cert1_; + rtc::scoped_refptr<rtc::RTCCertificate> cert2_; +}; + +TEST_F(TransportDescriptionFactoryTest, TestOfferDefault) { + scoped_ptr<TransportDescription> desc(f1_.CreateOffer( + TransportOptions(), NULL)); + CheckDesc(desc.get(), "", "", "", ""); +} + +TEST_F(TransportDescriptionFactoryTest, TestOfferDtls) { + f1_.set_secure(cricket::SEC_ENABLED); + f1_.set_certificate(cert1_); + std::string digest_alg; + ASSERT_TRUE(cert1_->ssl_certificate().GetSignatureDigestAlgorithm( + &digest_alg)); + scoped_ptr<TransportDescription> desc(f1_.CreateOffer( + TransportOptions(), NULL)); + CheckDesc(desc.get(), "", "", "", digest_alg); + // Ensure it also works with SEC_REQUIRED. + f1_.set_secure(cricket::SEC_REQUIRED); + desc.reset(f1_.CreateOffer(TransportOptions(), NULL)); + CheckDesc(desc.get(), "", "", "", digest_alg); +} + +// Test generating an offer with DTLS fails with no identity. +TEST_F(TransportDescriptionFactoryTest, TestOfferDtlsWithNoIdentity) { + f1_.set_secure(cricket::SEC_ENABLED); + scoped_ptr<TransportDescription> desc(f1_.CreateOffer( + TransportOptions(), NULL)); + ASSERT_TRUE(desc.get() == NULL); +} + +// Test updating an offer with DTLS to pick ICE. +// The ICE credentials should stay the same in the new offer. +TEST_F(TransportDescriptionFactoryTest, TestOfferDtlsReofferDtls) { + f1_.set_secure(cricket::SEC_ENABLED); + f1_.set_certificate(cert1_); + std::string digest_alg; + ASSERT_TRUE(cert1_->ssl_certificate().GetSignatureDigestAlgorithm( + &digest_alg)); + scoped_ptr<TransportDescription> old_desc(f1_.CreateOffer( + TransportOptions(), NULL)); + ASSERT_TRUE(old_desc.get() != NULL); + scoped_ptr<TransportDescription> desc( + f1_.CreateOffer(TransportOptions(), old_desc.get())); + CheckDesc(desc.get(), "", + old_desc->ice_ufrag, old_desc->ice_pwd, digest_alg); +} + +TEST_F(TransportDescriptionFactoryTest, TestAnswerDefault) { + scoped_ptr<TransportDescription> offer(f1_.CreateOffer( + TransportOptions(), NULL)); + ASSERT_TRUE(offer.get() != NULL); + scoped_ptr<TransportDescription> desc(f2_.CreateAnswer( + offer.get(), TransportOptions(), NULL)); + CheckDesc(desc.get(), "", "", "", ""); + desc.reset(f2_.CreateAnswer(offer.get(), TransportOptions(), + NULL)); + CheckDesc(desc.get(), "", "", "", ""); +} + +// Test that we can update an answer properly; ICE credentials shouldn't change. +TEST_F(TransportDescriptionFactoryTest, TestReanswer) { + scoped_ptr<TransportDescription> offer( + f1_.CreateOffer(TransportOptions(), NULL)); + ASSERT_TRUE(offer.get() != NULL); + scoped_ptr<TransportDescription> old_desc( + f2_.CreateAnswer(offer.get(), TransportOptions(), NULL)); + ASSERT_TRUE(old_desc.get() != NULL); + scoped_ptr<TransportDescription> desc( + f2_.CreateAnswer(offer.get(), TransportOptions(), + old_desc.get())); + ASSERT_TRUE(desc.get() != NULL); + CheckDesc(desc.get(), "", + old_desc->ice_ufrag, old_desc->ice_pwd, ""); +} + +// Test that we handle answering an offer with DTLS with no DTLS. +TEST_F(TransportDescriptionFactoryTest, TestAnswerDtlsToNoDtls) { + f1_.set_secure(cricket::SEC_ENABLED); + f1_.set_certificate(cert1_); + scoped_ptr<TransportDescription> offer( + f1_.CreateOffer(TransportOptions(), NULL)); + ASSERT_TRUE(offer.get() != NULL); + scoped_ptr<TransportDescription> desc( + f2_.CreateAnswer(offer.get(), TransportOptions(), NULL)); + CheckDesc(desc.get(), "", "", "", ""); +} + +// Test that we handle answering an offer without DTLS if we have DTLS enabled, +// but fail if we require DTLS. +TEST_F(TransportDescriptionFactoryTest, TestAnswerNoDtlsToDtls) { + f2_.set_secure(cricket::SEC_ENABLED); + f2_.set_certificate(cert2_); + scoped_ptr<TransportDescription> offer( + f1_.CreateOffer(TransportOptions(), NULL)); + ASSERT_TRUE(offer.get() != NULL); + scoped_ptr<TransportDescription> desc( + f2_.CreateAnswer(offer.get(), TransportOptions(), NULL)); + CheckDesc(desc.get(), "", "", "", ""); + f2_.set_secure(cricket::SEC_REQUIRED); + desc.reset(f2_.CreateAnswer(offer.get(), TransportOptions(), + NULL)); + ASSERT_TRUE(desc.get() == NULL); +} + +// Test that we handle answering an DTLS offer with DTLS, both if we have +// DTLS enabled and required. +TEST_F(TransportDescriptionFactoryTest, TestAnswerDtlsToDtls) { + f1_.set_secure(cricket::SEC_ENABLED); + f1_.set_certificate(cert1_); + + f2_.set_secure(cricket::SEC_ENABLED); + f2_.set_certificate(cert2_); + // f2_ produces the answer that is being checked in this test, so the + // answer must contain fingerprint lines with cert2_'s digest algorithm. + std::string digest_alg2; + ASSERT_TRUE(cert2_->ssl_certificate().GetSignatureDigestAlgorithm( + &digest_alg2)); + + scoped_ptr<TransportDescription> offer( + f1_.CreateOffer(TransportOptions(), NULL)); + ASSERT_TRUE(offer.get() != NULL); + scoped_ptr<TransportDescription> desc( + f2_.CreateAnswer(offer.get(), TransportOptions(), NULL)); + CheckDesc(desc.get(), "", "", "", digest_alg2); + f2_.set_secure(cricket::SEC_REQUIRED); + desc.reset(f2_.CreateAnswer(offer.get(), TransportOptions(), + NULL)); + CheckDesc(desc.get(), "", "", "", digest_alg2); +} + +// Test that ice ufrag and password is changed in an updated offer and answer +// if |TransportDescriptionOptions::ice_restart| is true. +TEST_F(TransportDescriptionFactoryTest, TestIceRestart) { + TestIceRestart(false); +} + +// Test that ice ufrag and password is changed in an updated offer and answer +// if |TransportDescriptionOptions::ice_restart| is true and DTLS is enabled. +TEST_F(TransportDescriptionFactoryTest, TestIceRestartWithDtls) { + TestIceRestart(true); +} diff --git a/webrtc/p2p/base/transportinfo.h b/webrtc/p2p/base/transportinfo.h new file mode 100644 index 0000000000..3fbf204d98 --- /dev/null +++ b/webrtc/p2p/base/transportinfo.h @@ -0,0 +1,43 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TRANSPORTINFO_H_ +#define WEBRTC_P2P_BASE_TRANSPORTINFO_H_ + +#include <string> +#include <vector> + +#include "webrtc/p2p/base/candidate.h" +#include "webrtc/p2p/base/constants.h" +#include "webrtc/p2p/base/transportdescription.h" +#include "webrtc/base/helpers.h" + +namespace cricket { + +// A TransportInfo is NOT a transport-info message. It is comparable +// to a "ContentInfo". A transport-infos message is basically just a +// collection of TransportInfos. +struct TransportInfo { + TransportInfo() {} + + TransportInfo(const std::string& content_name, + const TransportDescription& description) + : content_name(content_name), + description(description) {} + + std::string content_name; + TransportDescription description; +}; + +typedef std::vector<TransportInfo> TransportInfos; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TRANSPORTINFO_H_ diff --git a/webrtc/p2p/base/turnport.cc b/webrtc/p2p/base/turnport.cc new file mode 100644 index 0000000000..3fdcac5f31 --- /dev/null +++ b/webrtc/p2p/base/turnport.cc @@ -0,0 +1,1371 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/turnport.h" + +#include <functional> + +#include "webrtc/p2p/base/common.h" +#include "webrtc/p2p/base/stun.h" +#include "webrtc/base/asyncpacketsocket.h" +#include "webrtc/base/byteorder.h" +#include "webrtc/base/common.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/nethelpers.h" +#include "webrtc/base/socketaddress.h" +#include "webrtc/base/stringencode.h" + +namespace cricket { + +// TODO(juberti): Move to stun.h when relay messages have been renamed. +static const int TURN_ALLOCATE_REQUEST = STUN_ALLOCATE_REQUEST; + +// TODO(juberti): Extract to turnmessage.h +static const int TURN_DEFAULT_PORT = 3478; +static const int TURN_CHANNEL_NUMBER_START = 0x4000; +static const int TURN_PERMISSION_TIMEOUT = 5 * 60 * 1000; // 5 minutes + +static const size_t TURN_CHANNEL_HEADER_SIZE = 4U; + +// Retry at most twice (i.e. three different ALLOCATE requests) on +// STUN_ERROR_ALLOCATION_MISMATCH error per rfc5766. +static const size_t MAX_ALLOCATE_MISMATCH_RETRIES = 2; + +inline bool IsTurnChannelData(uint16_t msg_type) { + return ((msg_type & 0xC000) == 0x4000); // MSB are 0b01 +} + +static int GetRelayPreference(cricket::ProtocolType proto, bool secure) { + int relay_preference = ICE_TYPE_PREFERENCE_RELAY; + if (proto == cricket::PROTO_TCP) { + relay_preference -= 1; + if (secure) + relay_preference -= 1; + } + + ASSERT(relay_preference >= 0); + return relay_preference; +} + +class TurnAllocateRequest : public StunRequest { + public: + explicit TurnAllocateRequest(TurnPort* port); + void Prepare(StunMessage* request) override; + void OnSent() override; + void OnResponse(StunMessage* response) override; + void OnErrorResponse(StunMessage* response) override; + void OnTimeout() override; + + private: + // Handles authentication challenge from the server. + void OnAuthChallenge(StunMessage* response, int code); + void OnTryAlternate(StunMessage* response, int code); + void OnUnknownAttribute(StunMessage* response); + + TurnPort* port_; +}; + +class TurnRefreshRequest : public StunRequest { + public: + explicit TurnRefreshRequest(TurnPort* port); + void Prepare(StunMessage* request) override; + void OnSent() override; + void OnResponse(StunMessage* response) override; + void OnErrorResponse(StunMessage* response) override; + void OnTimeout() override; + void set_lifetime(int lifetime) { lifetime_ = lifetime; } + + private: + TurnPort* port_; + int lifetime_; +}; + +class TurnCreatePermissionRequest : public StunRequest, + public sigslot::has_slots<> { + public: + TurnCreatePermissionRequest(TurnPort* port, TurnEntry* entry, + const rtc::SocketAddress& ext_addr); + void Prepare(StunMessage* request) override; + void OnSent() override; + void OnResponse(StunMessage* response) override; + void OnErrorResponse(StunMessage* response) override; + void OnTimeout() override; + + private: + void OnEntryDestroyed(TurnEntry* entry); + + TurnPort* port_; + TurnEntry* entry_; + rtc::SocketAddress ext_addr_; +}; + +class TurnChannelBindRequest : public StunRequest, + public sigslot::has_slots<> { + public: + TurnChannelBindRequest(TurnPort* port, TurnEntry* entry, int channel_id, + const rtc::SocketAddress& ext_addr); + void Prepare(StunMessage* request) override; + void OnSent() override; + void OnResponse(StunMessage* response) override; + void OnErrorResponse(StunMessage* response) override; + void OnTimeout() override; + + private: + void OnEntryDestroyed(TurnEntry* entry); + + TurnPort* port_; + TurnEntry* entry_; + int channel_id_; + rtc::SocketAddress ext_addr_; +}; + +// Manages a "connection" to a remote destination. We will attempt to bring up +// a channel for this remote destination to reduce the overhead of sending data. +class TurnEntry : public sigslot::has_slots<> { + public: + enum BindState { STATE_UNBOUND, STATE_BINDING, STATE_BOUND }; + TurnEntry(TurnPort* port, int channel_id, + const rtc::SocketAddress& ext_addr); + + TurnPort* port() { return port_; } + + int channel_id() const { return channel_id_; } + const rtc::SocketAddress& address() const { return ext_addr_; } + BindState state() const { return state_; } + + // Helper methods to send permission and channel bind requests. + void SendCreatePermissionRequest(); + void SendChannelBindRequest(int delay); + // Sends a packet to the given destination address. + // This will wrap the packet in STUN if necessary. + int Send(const void* data, size_t size, bool payload, + const rtc::PacketOptions& options); + + void OnCreatePermissionSuccess(); + void OnCreatePermissionError(StunMessage* response, int code); + void OnChannelBindSuccess(); + void OnChannelBindError(StunMessage* response, int code); + // Signal sent when TurnEntry is destroyed. + sigslot::signal1<TurnEntry*> SignalDestroyed; + + private: + TurnPort* port_; + int channel_id_; + rtc::SocketAddress ext_addr_; + BindState state_; +}; + +TurnPort::TurnPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + rtc::AsyncPacketSocket* socket, + const std::string& username, + const std::string& password, + const ProtocolAddress& server_address, + const RelayCredentials& credentials, + int server_priority, + const std::string& origin) + : Port(thread, + factory, + network, + socket->GetLocalAddress().ipaddr(), + username, + password), + server_address_(server_address), + credentials_(credentials), + socket_(socket), + resolver_(NULL), + error_(0), + request_manager_(thread), + next_channel_number_(TURN_CHANNEL_NUMBER_START), + state_(STATE_CONNECTING), + server_priority_(server_priority), + allocate_mismatch_retries_(0) { + request_manager_.SignalSendPacket.connect(this, &TurnPort::OnSendStunPacket); + request_manager_.set_origin(origin); +} + +TurnPort::TurnPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password, + const ProtocolAddress& server_address, + const RelayCredentials& credentials, + int server_priority, + const std::string& origin) + : Port(thread, + RELAY_PORT_TYPE, + factory, + network, + ip, + min_port, + max_port, + username, + password), + server_address_(server_address), + credentials_(credentials), + socket_(NULL), + resolver_(NULL), + error_(0), + request_manager_(thread), + next_channel_number_(TURN_CHANNEL_NUMBER_START), + state_(STATE_CONNECTING), + server_priority_(server_priority), + allocate_mismatch_retries_(0) { + request_manager_.SignalSendPacket.connect(this, &TurnPort::OnSendStunPacket); + request_manager_.set_origin(origin); +} + +TurnPort::~TurnPort() { + // TODO(juberti): Should this even be necessary? + + // release the allocation by sending a refresh with + // lifetime 0. + if (ready()) { + TurnRefreshRequest bye(this); + bye.set_lifetime(0); + SendRequest(&bye, 0); + } + + while (!entries_.empty()) { + DestroyEntry(entries_.front()->address()); + } + if (resolver_) { + resolver_->Destroy(false); + } + if (!SharedSocket()) { + delete socket_; + } +} + +rtc::SocketAddress TurnPort::GetLocalAddress() const { + return socket_ ? socket_->GetLocalAddress() : rtc::SocketAddress(); +} + +void TurnPort::PrepareAddress() { + if (credentials_.username.empty() || + credentials_.password.empty()) { + LOG(LS_ERROR) << "Allocation can't be started without setting the" + << " TURN server credentials for the user."; + OnAllocateError(); + return; + } + + if (!server_address_.address.port()) { + // We will set default TURN port, if no port is set in the address. + server_address_.address.SetPort(TURN_DEFAULT_PORT); + } + + if (server_address_.address.IsUnresolved()) { + ResolveTurnAddress(server_address_.address); + } else { + // If protocol family of server address doesn't match with local, return. + if (!IsCompatibleAddress(server_address_.address)) { + LOG(LS_ERROR) << "IP address family does not match: " + << "server: " << server_address_.address.family() + << "local: " << ip().family(); + OnAllocateError(); + return; + } + + // Insert the current address to prevent redirection pingpong. + attempted_server_addresses_.insert(server_address_.address); + + LOG_J(LS_INFO, this) << "Trying to connect to TURN server via " + << ProtoToString(server_address_.proto) << " @ " + << server_address_.address.ToSensitiveString(); + if (!CreateTurnClientSocket()) { + LOG(LS_ERROR) << "Failed to create TURN client socket"; + OnAllocateError(); + return; + } + if (server_address_.proto == PROTO_UDP) { + // If its UDP, send AllocateRequest now. + // For TCP and TLS AllcateRequest will be sent by OnSocketConnect. + SendRequest(new TurnAllocateRequest(this), 0); + } + } +} + +bool TurnPort::CreateTurnClientSocket() { + ASSERT(!socket_ || SharedSocket()); + + if (server_address_.proto == PROTO_UDP && !SharedSocket()) { + socket_ = socket_factory()->CreateUdpSocket( + rtc::SocketAddress(ip(), 0), min_port(), max_port()); + } else if (server_address_.proto == PROTO_TCP) { + ASSERT(!SharedSocket()); + int opts = rtc::PacketSocketFactory::OPT_STUN; + // If secure bit is enabled in server address, use TLS over TCP. + if (server_address_.secure) { + opts |= rtc::PacketSocketFactory::OPT_TLS; + } + socket_ = socket_factory()->CreateClientTcpSocket( + rtc::SocketAddress(ip(), 0), server_address_.address, + proxy(), user_agent(), opts); + } + + if (!socket_) { + error_ = SOCKET_ERROR; + return false; + } + + // Apply options if any. + for (SocketOptionsMap::iterator iter = socket_options_.begin(); + iter != socket_options_.end(); ++iter) { + socket_->SetOption(iter->first, iter->second); + } + + if (!SharedSocket()) { + // If socket is shared, AllocationSequence will receive the packet. + socket_->SignalReadPacket.connect(this, &TurnPort::OnReadPacket); + } + + socket_->SignalReadyToSend.connect(this, &TurnPort::OnReadyToSend); + + // TCP port is ready to send stun requests after the socket is connected, + // while UDP port is ready to do so once the socket is created. + if (server_address_.proto == PROTO_TCP) { + socket_->SignalConnect.connect(this, &TurnPort::OnSocketConnect); + socket_->SignalClose.connect(this, &TurnPort::OnSocketClose); + } else { + state_ = STATE_CONNECTED; + } + return true; +} + +void TurnPort::OnSocketConnect(rtc::AsyncPacketSocket* socket) { + ASSERT(server_address_.proto == PROTO_TCP); + // Do not use this port if the socket bound to a different address than + // the one we asked for. This is seen in Chrome, where TCP sockets cannot be + // given a binding address, and the platform is expected to pick the + // correct local address. + + // However, there are two situations in which we allow the bound address to + // differ from the requested address: 1. The bound address is the loopback + // address. This happens when a proxy forces TCP to bind to only the + // localhost address (see issue 3927). 2. The bound address is the "any + // address". This happens when multiple_routes is disabled (see issue 4780). + if (socket->GetLocalAddress().ipaddr() != ip()) { + if (socket->GetLocalAddress().IsLoopbackIP()) { + LOG(LS_WARNING) << "Socket is bound to a different address:" + << socket->GetLocalAddress().ipaddr().ToString() + << ", rather then the local port:" << ip().ToString() + << ". Still allowing it since it's localhost."; + } else if (IPIsAny(ip())) { + LOG(LS_WARNING) << "Socket is bound to a different address:" + << socket->GetLocalAddress().ipaddr().ToString() + << ", rather then the local port:" << ip().ToString() + << ". Still allowing it since it's any address" + << ", possibly caused by multiple_routes being disabled."; + } else { + LOG(LS_WARNING) << "Socket is bound to a different address:" + << socket->GetLocalAddress().ipaddr().ToString() + << ", rather then the local port:" << ip().ToString() + << ". Discarding TURN port."; + OnAllocateError(); + return; + } + } + + state_ = STATE_CONNECTED; // It is ready to send stun requests. + if (server_address_.address.IsUnresolved()) { + server_address_.address = socket_->GetRemoteAddress(); + } + + LOG(LS_INFO) << "TurnPort connected to " << socket->GetRemoteAddress() + << " using tcp."; + SendRequest(new TurnAllocateRequest(this), 0); +} + +void TurnPort::OnSocketClose(rtc::AsyncPacketSocket* socket, int error) { + LOG_J(LS_WARNING, this) << "Connection with server failed, error=" << error; + ASSERT(socket == socket_); + if (!ready()) { + OnAllocateError(); + } + request_manager_.Clear(); + state_ = STATE_DISCONNECTED; +} + +void TurnPort::OnAllocateMismatch() { + if (allocate_mismatch_retries_ >= MAX_ALLOCATE_MISMATCH_RETRIES) { + LOG_J(LS_WARNING, this) << "Giving up on the port after " + << allocate_mismatch_retries_ + << " retries for STUN_ERROR_ALLOCATION_MISMATCH"; + OnAllocateError(); + return; + } + + LOG_J(LS_INFO, this) << "Allocating a new socket after " + << "STUN_ERROR_ALLOCATION_MISMATCH, retry = " + << allocate_mismatch_retries_ + 1; + if (SharedSocket()) { + ResetSharedSocket(); + } else { + delete socket_; + } + socket_ = NULL; + + PrepareAddress(); + ++allocate_mismatch_retries_; +} + +Connection* TurnPort::CreateConnection(const Candidate& address, + CandidateOrigin origin) { + // TURN-UDP can only connect to UDP candidates. + if (address.protocol() != UDP_PROTOCOL_NAME) { + return NULL; + } + + if (!IsCompatibleAddress(address.address())) { + return NULL; + } + + if (state_ == STATE_DISCONNECTED) { + return NULL; + } + + // Create an entry, if needed, so we can get our permissions set up correctly. + CreateEntry(address.address()); + + // A TURN port will have two candiates, STUN and TURN. STUN may not + // present in all cases. If present stun candidate will be added first + // and TURN candidate later. + for (size_t index = 0; index < Candidates().size(); ++index) { + if (Candidates()[index].type() == RELAY_PORT_TYPE) { + ProxyConnection* conn = new ProxyConnection(this, index, address); + conn->SignalDestroyed.connect(this, &TurnPort::OnConnectionDestroyed); + AddConnection(conn); + return conn; + } + } + return NULL; +} + +int TurnPort::SetOption(rtc::Socket::Option opt, int value) { + if (!socket_) { + // If socket is not created yet, these options will be applied during socket + // creation. + socket_options_[opt] = value; + return 0; + } + return socket_->SetOption(opt, value); +} + +int TurnPort::GetOption(rtc::Socket::Option opt, int* value) { + if (!socket_) { + SocketOptionsMap::const_iterator it = socket_options_.find(opt); + if (it == socket_options_.end()) { + return -1; + } + *value = it->second; + return 0; + } + + return socket_->GetOption(opt, value); +} + +int TurnPort::GetError() { + return error_; +} + +int TurnPort::SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options, + bool payload) { + // Try to find an entry for this specific address; we should have one. + TurnEntry* entry = FindEntry(addr); + if (!entry) { + LOG(LS_ERROR) << "Did not find the TurnEntry for address " << addr; + return 0; + } + + if (!ready()) { + error_ = EWOULDBLOCK; + return SOCKET_ERROR; + } + + // Send the actual contents to the server using the usual mechanism. + int sent = entry->Send(data, size, payload, options); + if (sent <= 0) { + return SOCKET_ERROR; + } + + // The caller of the function is expecting the number of user data bytes, + // rather than the size of the packet. + return static_cast<int>(size); +} + +void TurnPort::OnReadPacket( + rtc::AsyncPacketSocket* socket, const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + ASSERT(socket == socket_); + + // This is to guard against a STUN response from previous server after + // alternative server redirection. TODO(guoweis): add a unit test for this + // race condition. + if (remote_addr != server_address_.address) { + LOG_J(LS_WARNING, this) << "Discarding TURN message from unknown address:" + << remote_addr.ToString() + << ", server_address_:" + << server_address_.address.ToString(); + return; + } + + // The message must be at least the size of a channel header. + if (size < TURN_CHANNEL_HEADER_SIZE) { + LOG_J(LS_WARNING, this) << "Received TURN message that was too short"; + return; + } + + // Check the message type, to see if is a Channel Data message. + // The message will either be channel data, a TURN data indication, or + // a response to a previous request. + uint16_t msg_type = rtc::GetBE16(data); + if (IsTurnChannelData(msg_type)) { + HandleChannelData(msg_type, data, size, packet_time); + } else if (msg_type == TURN_DATA_INDICATION) { + HandleDataIndication(data, size, packet_time); + } else { + if (SharedSocket() && + (msg_type == STUN_BINDING_RESPONSE || + msg_type == STUN_BINDING_ERROR_RESPONSE)) { + LOG_J(LS_VERBOSE, this) << + "Ignoring STUN binding response message on shared socket."; + return; + } + + // This must be a response for one of our requests. + // Check success responses, but not errors, for MESSAGE-INTEGRITY. + if (IsStunSuccessResponseType(msg_type) && + !StunMessage::ValidateMessageIntegrity(data, size, hash())) { + LOG_J(LS_WARNING, this) << "Received TURN message with invalid " + << "message integrity, msg_type=" << msg_type; + return; + } + request_manager_.CheckResponse(data, size); + } +} + +void TurnPort::OnReadyToSend(rtc::AsyncPacketSocket* socket) { + if (ready()) { + Port::OnReadyToSend(); + } +} + + +// Update current server address port with the alternate server address port. +bool TurnPort::SetAlternateServer(const rtc::SocketAddress& address) { + // Check if we have seen this address before and reject if we did. + AttemptedServerSet::iterator iter = attempted_server_addresses_.find(address); + if (iter != attempted_server_addresses_.end()) { + LOG_J(LS_WARNING, this) << "Redirection to [" + << address.ToSensitiveString() + << "] ignored, allocation failed."; + return false; + } + + // If protocol family of server address doesn't match with local, return. + if (!IsCompatibleAddress(address)) { + LOG(LS_WARNING) << "Server IP address family does not match with " + << "local host address family type"; + return false; + } + + LOG_J(LS_INFO, this) << "Redirecting from TURN server [" + << server_address_.address.ToSensitiveString() + << "] to TURN server [" + << address.ToSensitiveString() + << "]"; + server_address_ = ProtocolAddress(address, server_address_.proto, + server_address_.secure); + + // Insert the current address to prevent redirection pingpong. + attempted_server_addresses_.insert(server_address_.address); + return true; +} + +void TurnPort::ResolveTurnAddress(const rtc::SocketAddress& address) { + if (resolver_) + return; + + resolver_ = socket_factory()->CreateAsyncResolver(); + resolver_->SignalDone.connect(this, &TurnPort::OnResolveResult); + resolver_->Start(address); +} + +void TurnPort::OnResolveResult(rtc::AsyncResolverInterface* resolver) { + ASSERT(resolver == resolver_); + // If DNS resolve is failed when trying to connect to the server using TCP, + // one of the reason could be due to DNS queries blocked by firewall. + // In such cases we will try to connect to the server with hostname, assuming + // socket layer will resolve the hostname through a HTTP proxy (if any). + if (resolver_->GetError() != 0 && server_address_.proto == PROTO_TCP) { + if (!CreateTurnClientSocket()) { + OnAllocateError(); + } + return; + } + + // Copy the original server address in |resolved_address|. For TLS based + // sockets we need hostname along with resolved address. + rtc::SocketAddress resolved_address = server_address_.address; + if (resolver_->GetError() != 0 || + !resolver_->GetResolvedAddress(ip().family(), &resolved_address)) { + LOG_J(LS_WARNING, this) << "TURN host lookup received error " + << resolver_->GetError(); + error_ = resolver_->GetError(); + OnAllocateError(); + return; + } + // Signal needs both resolved and unresolved address. After signal is sent + // we can copy resolved address back into |server_address_|. + SignalResolvedServerAddress(this, server_address_.address, + resolved_address); + server_address_.address = resolved_address; + PrepareAddress(); +} + +void TurnPort::OnSendStunPacket(const void* data, size_t size, + StunRequest* request) { + ASSERT(connected()); + rtc::PacketOptions options(DefaultDscpValue()); + if (Send(data, size, options) < 0) { + LOG_J(LS_ERROR, this) << "Failed to send TURN message, err=" + << socket_->GetError(); + } +} + +void TurnPort::OnStunAddress(const rtc::SocketAddress& address) { + // STUN Port will discover STUN candidate, as it's supplied with first TURN + // server address. + // Why not using this address? - P2PTransportChannel will start creating + // connections after first candidate, which means it could start creating the + // connections before TURN candidate added. For that to handle, we need to + // supply STUN candidate from this port to UDPPort, and TurnPort should have + // handle to UDPPort to pass back the address. +} + +void TurnPort::OnAllocateSuccess(const rtc::SocketAddress& address, + const rtc::SocketAddress& stun_address) { + state_ = STATE_READY; + + rtc::SocketAddress related_address = stun_address; + if (!(candidate_filter() & CF_REFLEXIVE)) { + // If candidate filter only allows relay type of address, empty raddr to + // avoid local address leakage. + related_address = rtc::EmptySocketAddressWithFamily(stun_address.family()); + } + + // For relayed candidate, Base is the candidate itself. + AddAddress(address, // Candidate address. + address, // Base address. + related_address, // Related address. + UDP_PROTOCOL_NAME, + ProtoToString(server_address_.proto), // The first hop protocol. + "", // TCP canddiate type, empty for turn candidates. + RELAY_PORT_TYPE, + GetRelayPreference(server_address_.proto, server_address_.secure), + server_priority_, true); +} + +void TurnPort::OnAllocateError() { + // We will send SignalPortError asynchronously as this can be sent during + // port initialization. This way it will not be blocking other port + // creation. + thread()->Post(this, MSG_ERROR); +} + +void TurnPort::OnMessage(rtc::Message* message) { + if (message->message_id == MSG_ERROR) { + SignalPortError(this); + return; + } else if (message->message_id == MSG_ALLOCATE_MISMATCH) { + OnAllocateMismatch(); + return; + } else if (message->message_id == MSG_TRY_ALTERNATE_SERVER) { + if (server_address().proto == PROTO_UDP) { + // Send another allocate request to alternate server, with the received + // realm and nonce values. + SendRequest(new TurnAllocateRequest(this), 0); + } else { + // Since it's TCP, we have to delete the connected socket and reconnect + // with the alternate server. PrepareAddress will send stun binding once + // the new socket is connected. + ASSERT(server_address().proto == PROTO_TCP); + ASSERT(!SharedSocket()); + delete socket_; + socket_ = NULL; + PrepareAddress(); + } + return; + } + + Port::OnMessage(message); +} + +void TurnPort::OnAllocateRequestTimeout() { + OnAllocateError(); +} + +void TurnPort::HandleDataIndication(const char* data, size_t size, + const rtc::PacketTime& packet_time) { + // Read in the message, and process according to RFC5766, Section 10.4. + rtc::ByteBuffer buf(data, size); + TurnMessage msg; + if (!msg.Read(&buf)) { + LOG_J(LS_WARNING, this) << "Received invalid TURN data indication"; + return; + } + + // Check mandatory attributes. + const StunAddressAttribute* addr_attr = + msg.GetAddress(STUN_ATTR_XOR_PEER_ADDRESS); + if (!addr_attr) { + LOG_J(LS_WARNING, this) << "Missing STUN_ATTR_XOR_PEER_ADDRESS attribute " + << "in data indication."; + return; + } + + const StunByteStringAttribute* data_attr = + msg.GetByteString(STUN_ATTR_DATA); + if (!data_attr) { + LOG_J(LS_WARNING, this) << "Missing STUN_ATTR_DATA attribute in " + << "data indication."; + return; + } + + // Verify that the data came from somewhere we think we have a permission for. + rtc::SocketAddress ext_addr(addr_attr->GetAddress()); + if (!HasPermission(ext_addr.ipaddr())) { + LOG_J(LS_WARNING, this) << "Received TURN data indication with invalid " + << "peer address, addr=" + << ext_addr.ToSensitiveString(); + return; + } + + DispatchPacket(data_attr->bytes(), data_attr->length(), ext_addr, + PROTO_UDP, packet_time); +} + +void TurnPort::HandleChannelData(int channel_id, const char* data, + size_t size, + const rtc::PacketTime& packet_time) { + // Read the message, and process according to RFC5766, Section 11.6. + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Channel Number | Length | + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | | + // / Application Data / + // / / + // | | + // | +-------------------------------+ + // | | + // +-------------------------------+ + + // Extract header fields from the message. + uint16_t len = rtc::GetBE16(data + 2); + if (len > size - TURN_CHANNEL_HEADER_SIZE) { + LOG_J(LS_WARNING, this) << "Received TURN channel data message with " + << "incorrect length, len=" << len; + return; + } + // Allowing messages larger than |len|, as ChannelData can be padded. + + TurnEntry* entry = FindEntry(channel_id); + if (!entry) { + LOG_J(LS_WARNING, this) << "Received TURN channel data message for invalid " + << "channel, channel_id=" << channel_id; + return; + } + + DispatchPacket(data + TURN_CHANNEL_HEADER_SIZE, len, entry->address(), + PROTO_UDP, packet_time); +} + +void TurnPort::DispatchPacket(const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + ProtocolType proto, const rtc::PacketTime& packet_time) { + if (Connection* conn = GetConnection(remote_addr)) { + conn->OnReadPacket(data, size, packet_time); + } else { + Port::OnReadPacket(data, size, remote_addr, proto); + } +} + +bool TurnPort::ScheduleRefresh(int lifetime) { + // Lifetime is in seconds; we schedule a refresh for one minute less. + if (lifetime < 2 * 60) { + LOG_J(LS_WARNING, this) << "Received response with lifetime that was " + << "too short, lifetime=" << lifetime; + return false; + } + + int delay = (lifetime - 60) * 1000; + SendRequest(new TurnRefreshRequest(this), delay); + LOG_J(LS_INFO, this) << "Scheduled refresh in " << delay << "ms."; + return true; +} + +void TurnPort::SendRequest(StunRequest* req, int delay) { + request_manager_.SendDelayed(req, delay); +} + +void TurnPort::AddRequestAuthInfo(StunMessage* msg) { + // If we've gotten the necessary data from the server, add it to our request. + VERIFY(!hash_.empty()); + VERIFY(msg->AddAttribute(new StunByteStringAttribute( + STUN_ATTR_USERNAME, credentials_.username))); + VERIFY(msg->AddAttribute(new StunByteStringAttribute( + STUN_ATTR_REALM, realm_))); + VERIFY(msg->AddAttribute(new StunByteStringAttribute( + STUN_ATTR_NONCE, nonce_))); + VERIFY(msg->AddMessageIntegrity(hash())); +} + +int TurnPort::Send(const void* data, size_t len, + const rtc::PacketOptions& options) { + return socket_->SendTo(data, len, server_address_.address, options); +} + +void TurnPort::UpdateHash() { + VERIFY(ComputeStunCredentialHash(credentials_.username, realm_, + credentials_.password, &hash_)); +} + +bool TurnPort::UpdateNonce(StunMessage* response) { + // When stale nonce error received, we should update + // hash and store realm and nonce. + // Check the mandatory attributes. + const StunByteStringAttribute* realm_attr = + response->GetByteString(STUN_ATTR_REALM); + if (!realm_attr) { + LOG(LS_ERROR) << "Missing STUN_ATTR_REALM attribute in " + << "stale nonce error response."; + return false; + } + set_realm(realm_attr->GetString()); + + const StunByteStringAttribute* nonce_attr = + response->GetByteString(STUN_ATTR_NONCE); + if (!nonce_attr) { + LOG(LS_ERROR) << "Missing STUN_ATTR_NONCE attribute in " + << "stale nonce error response."; + return false; + } + set_nonce(nonce_attr->GetString()); + return true; +} + +static bool MatchesIP(TurnEntry* e, rtc::IPAddress ipaddr) { + return e->address().ipaddr() == ipaddr; +} +bool TurnPort::HasPermission(const rtc::IPAddress& ipaddr) const { + return (std::find_if(entries_.begin(), entries_.end(), + std::bind2nd(std::ptr_fun(MatchesIP), ipaddr)) != entries_.end()); +} + +static bool MatchesAddress(TurnEntry* e, rtc::SocketAddress addr) { + return e->address() == addr; +} +TurnEntry* TurnPort::FindEntry(const rtc::SocketAddress& addr) const { + EntryList::const_iterator it = std::find_if(entries_.begin(), entries_.end(), + std::bind2nd(std::ptr_fun(MatchesAddress), addr)); + return (it != entries_.end()) ? *it : NULL; +} + +static bool MatchesChannelId(TurnEntry* e, int id) { + return e->channel_id() == id; +} +TurnEntry* TurnPort::FindEntry(int channel_id) const { + EntryList::const_iterator it = std::find_if(entries_.begin(), entries_.end(), + std::bind2nd(std::ptr_fun(MatchesChannelId), channel_id)); + return (it != entries_.end()) ? *it : NULL; +} + +TurnEntry* TurnPort::CreateEntry(const rtc::SocketAddress& addr) { + ASSERT(FindEntry(addr) == NULL); + TurnEntry* entry = new TurnEntry(this, next_channel_number_++, addr); + entries_.push_back(entry); + return entry; +} + +void TurnPort::DestroyEntry(const rtc::SocketAddress& addr) { + TurnEntry* entry = FindEntry(addr); + ASSERT(entry != NULL); + entry->SignalDestroyed(entry); + entries_.remove(entry); + delete entry; +} + +void TurnPort::OnConnectionDestroyed(Connection* conn) { + // Destroying TurnEntry for the connection, which is already destroyed. + DestroyEntry(conn->remote_candidate().address()); +} + +TurnAllocateRequest::TurnAllocateRequest(TurnPort* port) + : StunRequest(new TurnMessage()), + port_(port) { +} + +void TurnAllocateRequest::Prepare(StunMessage* request) { + // Create the request as indicated in RFC 5766, Section 6.1. + request->SetType(TURN_ALLOCATE_REQUEST); + StunUInt32Attribute* transport_attr = StunAttribute::CreateUInt32( + STUN_ATTR_REQUESTED_TRANSPORT); + transport_attr->SetValue(IPPROTO_UDP << 24); + VERIFY(request->AddAttribute(transport_attr)); + if (!port_->hash().empty()) { + port_->AddRequestAuthInfo(request); + } +} + +void TurnAllocateRequest::OnSent() { + LOG_J(LS_INFO, port_) << "TURN allocate request sent" + << ", id=" << rtc::hex_encode(id()); + StunRequest::OnSent(); +} + +void TurnAllocateRequest::OnResponse(StunMessage* response) { + LOG_J(LS_INFO, port_) << "TURN allocate requested successfully" + << ", id=" << rtc::hex_encode(id()) + << ", code=0" // Makes logging easier to parse. + << ", rtt=" << Elapsed(); + + // Check mandatory attributes as indicated in RFC5766, Section 6.3. + const StunAddressAttribute* mapped_attr = + response->GetAddress(STUN_ATTR_XOR_MAPPED_ADDRESS); + if (!mapped_attr) { + LOG_J(LS_WARNING, port_) << "Missing STUN_ATTR_XOR_MAPPED_ADDRESS " + << "attribute in allocate success response"; + return; + } + // Using XOR-Mapped-Address for stun. + port_->OnStunAddress(mapped_attr->GetAddress()); + + const StunAddressAttribute* relayed_attr = + response->GetAddress(STUN_ATTR_XOR_RELAYED_ADDRESS); + if (!relayed_attr) { + LOG_J(LS_WARNING, port_) << "Missing STUN_ATTR_XOR_RELAYED_ADDRESS " + << "attribute in allocate success response"; + return; + } + + const StunUInt32Attribute* lifetime_attr = + response->GetUInt32(STUN_ATTR_TURN_LIFETIME); + if (!lifetime_attr) { + LOG_J(LS_WARNING, port_) << "Missing STUN_ATTR_TURN_LIFETIME attribute in " + << "allocate success response"; + return; + } + // Notify the port the allocate succeeded, and schedule a refresh request. + port_->OnAllocateSuccess(relayed_attr->GetAddress(), + mapped_attr->GetAddress()); + port_->ScheduleRefresh(lifetime_attr->value()); +} + +void TurnAllocateRequest::OnErrorResponse(StunMessage* response) { + // Process error response according to RFC5766, Section 6.4. + const StunErrorCodeAttribute* error_code = response->GetErrorCode(); + + LOG_J(LS_INFO, port_) << "Received TURN allocate error response" + << ", id=" << rtc::hex_encode(id()) + << ", code=" << error_code->code() + << ", rtt=" << Elapsed(); + + switch (error_code->code()) { + case STUN_ERROR_UNAUTHORIZED: // Unauthrorized. + OnAuthChallenge(response, error_code->code()); + break; + case STUN_ERROR_TRY_ALTERNATE: + OnTryAlternate(response, error_code->code()); + break; + case STUN_ERROR_ALLOCATION_MISMATCH: + // We must handle this error async because trying to delete the socket in + // OnErrorResponse will cause a deadlock on the socket. + port_->thread()->Post(port_, TurnPort::MSG_ALLOCATE_MISMATCH); + break; + default: + LOG_J(LS_WARNING, port_) << "Received TURN allocate error response" + << ", id=" << rtc::hex_encode(id()) + << ", code=" << error_code->code() + << ", rtt=" << Elapsed(); + port_->OnAllocateError(); + } +} + +void TurnAllocateRequest::OnTimeout() { + LOG_J(LS_WARNING, port_) << "TURN allocate request " + << rtc::hex_encode(id()) << " timout"; + port_->OnAllocateRequestTimeout(); +} + +void TurnAllocateRequest::OnAuthChallenge(StunMessage* response, int code) { + // If we failed to authenticate even after we sent our credentials, fail hard. + if (code == STUN_ERROR_UNAUTHORIZED && !port_->hash().empty()) { + LOG_J(LS_WARNING, port_) << "Failed to authenticate with the server " + << "after challenge."; + port_->OnAllocateError(); + return; + } + + // Check the mandatory attributes. + const StunByteStringAttribute* realm_attr = + response->GetByteString(STUN_ATTR_REALM); + if (!realm_attr) { + LOG_J(LS_WARNING, port_) << "Missing STUN_ATTR_REALM attribute in " + << "allocate unauthorized response."; + return; + } + port_->set_realm(realm_attr->GetString()); + + const StunByteStringAttribute* nonce_attr = + response->GetByteString(STUN_ATTR_NONCE); + if (!nonce_attr) { + LOG_J(LS_WARNING, port_) << "Missing STUN_ATTR_NONCE attribute in " + << "allocate unauthorized response."; + return; + } + port_->set_nonce(nonce_attr->GetString()); + + // Send another allocate request, with the received realm and nonce values. + port_->SendRequest(new TurnAllocateRequest(port_), 0); +} + +void TurnAllocateRequest::OnTryAlternate(StunMessage* response, int code) { + + // According to RFC 5389 section 11, there are use cases where + // authentication of response is not possible, we're not validating + // message integrity. + + // Get the alternate server address attribute value. + const StunAddressAttribute* alternate_server_attr = + response->GetAddress(STUN_ATTR_ALTERNATE_SERVER); + if (!alternate_server_attr) { + LOG_J(LS_WARNING, port_) << "Missing STUN_ATTR_ALTERNATE_SERVER " + << "attribute in try alternate error response"; + port_->OnAllocateError(); + return; + } + if (!port_->SetAlternateServer(alternate_server_attr->GetAddress())) { + port_->OnAllocateError(); + return; + } + + // Check the attributes. + const StunByteStringAttribute* realm_attr = + response->GetByteString(STUN_ATTR_REALM); + if (realm_attr) { + LOG_J(LS_INFO, port_) << "Applying STUN_ATTR_REALM attribute in " + << "try alternate error response."; + port_->set_realm(realm_attr->GetString()); + } + + const StunByteStringAttribute* nonce_attr = + response->GetByteString(STUN_ATTR_NONCE); + if (nonce_attr) { + LOG_J(LS_INFO, port_) << "Applying STUN_ATTR_NONCE attribute in " + << "try alternate error response."; + port_->set_nonce(nonce_attr->GetString()); + } + + // For TCP, we can't close the original Tcp socket during handling a 300 as + // we're still inside that socket's event handler. Doing so will cause + // deadlock. + port_->thread()->Post(port_, TurnPort::MSG_TRY_ALTERNATE_SERVER); +} + +TurnRefreshRequest::TurnRefreshRequest(TurnPort* port) + : StunRequest(new TurnMessage()), + port_(port), + lifetime_(-1) { +} + +void TurnRefreshRequest::Prepare(StunMessage* request) { + // Create the request as indicated in RFC 5766, Section 7.1. + // No attributes need to be included. + request->SetType(TURN_REFRESH_REQUEST); + if (lifetime_ > -1) { + VERIFY(request->AddAttribute(new StunUInt32Attribute( + STUN_ATTR_LIFETIME, lifetime_))); + } + + port_->AddRequestAuthInfo(request); +} + +void TurnRefreshRequest::OnSent() { + LOG_J(LS_INFO, port_) << "TURN refresh request sent" + << ", id=" << rtc::hex_encode(id()); + StunRequest::OnSent(); +} + +void TurnRefreshRequest::OnResponse(StunMessage* response) { + LOG_J(LS_INFO, port_) << "TURN refresh requested successfully" + << ", id=" << rtc::hex_encode(id()) + << ", code=0" // Makes logging easier to parse. + << ", rtt=" << Elapsed(); + + // Check mandatory attributes as indicated in RFC5766, Section 7.3. + const StunUInt32Attribute* lifetime_attr = + response->GetUInt32(STUN_ATTR_TURN_LIFETIME); + if (!lifetime_attr) { + LOG_J(LS_WARNING, port_) << "Missing STUN_ATTR_TURN_LIFETIME attribute in " + << "refresh success response."; + return; + } + + // Schedule a refresh based on the returned lifetime value. + port_->ScheduleRefresh(lifetime_attr->value()); +} + +void TurnRefreshRequest::OnErrorResponse(StunMessage* response) { + const StunErrorCodeAttribute* error_code = response->GetErrorCode(); + + LOG_J(LS_INFO, port_) << "Received TURN refresh error response" + << ", id=" << rtc::hex_encode(id()) + << ", code=" << error_code->code() + << ", rtt=" << Elapsed(); + + if (error_code->code() == STUN_ERROR_STALE_NONCE) { + if (port_->UpdateNonce(response)) { + // Send RefreshRequest immediately. + port_->SendRequest(new TurnRefreshRequest(port_), 0); + } + } else { + LOG_J(LS_WARNING, port_) << "Received TURN refresh error response" + << ", id=" << rtc::hex_encode(id()) + << ", code=" << error_code->code() + << ", rtt=" << Elapsed(); + } +} + +void TurnRefreshRequest::OnTimeout() { + LOG_J(LS_WARNING, port_) << "TURN refresh timeout " << rtc::hex_encode(id()); +} + +TurnCreatePermissionRequest::TurnCreatePermissionRequest( + TurnPort* port, TurnEntry* entry, + const rtc::SocketAddress& ext_addr) + : StunRequest(new TurnMessage()), + port_(port), + entry_(entry), + ext_addr_(ext_addr) { + entry_->SignalDestroyed.connect( + this, &TurnCreatePermissionRequest::OnEntryDestroyed); +} + +void TurnCreatePermissionRequest::Prepare(StunMessage* request) { + // Create the request as indicated in RFC5766, Section 9.1. + request->SetType(TURN_CREATE_PERMISSION_REQUEST); + VERIFY(request->AddAttribute(new StunXorAddressAttribute( + STUN_ATTR_XOR_PEER_ADDRESS, ext_addr_))); + port_->AddRequestAuthInfo(request); +} + +void TurnCreatePermissionRequest::OnSent() { + LOG_J(LS_INFO, port_) << "TURN create permission request sent" + << ", id=" << rtc::hex_encode(id()); + StunRequest::OnSent(); +} + +void TurnCreatePermissionRequest::OnResponse(StunMessage* response) { + LOG_J(LS_INFO, port_) << "TURN permission requested successfully" + << ", id=" << rtc::hex_encode(id()) + << ", code=0" // Makes logging easier to parse. + << ", rtt=" << Elapsed(); + + if (entry_) { + entry_->OnCreatePermissionSuccess(); + } +} + +void TurnCreatePermissionRequest::OnErrorResponse(StunMessage* response) { + const StunErrorCodeAttribute* error_code = response->GetErrorCode(); + LOG_J(LS_WARNING, port_) << "Received TURN create permission error response" + << ", id=" << rtc::hex_encode(id()) + << ", code=" << error_code->code() + << ", rtt=" << Elapsed(); + if (entry_) { + entry_->OnCreatePermissionError(response, error_code->code()); + } +} + +void TurnCreatePermissionRequest::OnTimeout() { + LOG_J(LS_WARNING, port_) << "TURN create permission timeout " + << rtc::hex_encode(id()); +} + +void TurnCreatePermissionRequest::OnEntryDestroyed(TurnEntry* entry) { + ASSERT(entry_ == entry); + entry_ = NULL; +} + +TurnChannelBindRequest::TurnChannelBindRequest( + TurnPort* port, TurnEntry* entry, + int channel_id, const rtc::SocketAddress& ext_addr) + : StunRequest(new TurnMessage()), + port_(port), + entry_(entry), + channel_id_(channel_id), + ext_addr_(ext_addr) { + entry_->SignalDestroyed.connect( + this, &TurnChannelBindRequest::OnEntryDestroyed); +} + +void TurnChannelBindRequest::Prepare(StunMessage* request) { + // Create the request as indicated in RFC5766, Section 11.1. + request->SetType(TURN_CHANNEL_BIND_REQUEST); + VERIFY(request->AddAttribute(new StunUInt32Attribute( + STUN_ATTR_CHANNEL_NUMBER, channel_id_ << 16))); + VERIFY(request->AddAttribute(new StunXorAddressAttribute( + STUN_ATTR_XOR_PEER_ADDRESS, ext_addr_))); + port_->AddRequestAuthInfo(request); +} + +void TurnChannelBindRequest::OnSent() { + LOG_J(LS_INFO, port_) << "TURN channel bind request sent" + << ", id=" << rtc::hex_encode(id()); + StunRequest::OnSent(); +} + +void TurnChannelBindRequest::OnResponse(StunMessage* response) { + LOG_J(LS_INFO, port_) << "TURN channel bind requested successfully" + << ", id=" << rtc::hex_encode(id()) + << ", code=0" // Makes logging easier to parse. + << ", rtt=" << Elapsed(); + + if (entry_) { + entry_->OnChannelBindSuccess(); + // Refresh the channel binding just under the permission timeout + // threshold. The channel binding has a longer lifetime, but + // this is the easiest way to keep both the channel and the + // permission from expiring. + int delay = TURN_PERMISSION_TIMEOUT - 60000; + entry_->SendChannelBindRequest(delay); + LOG_J(LS_INFO, port_) << "Scheduled channel bind in " << delay << "ms."; + } +} + +void TurnChannelBindRequest::OnErrorResponse(StunMessage* response) { + const StunErrorCodeAttribute* error_code = response->GetErrorCode(); + LOG_J(LS_WARNING, port_) << "Received TURN channel bind error response" + << ", id=" << rtc::hex_encode(id()) + << ", code=" << error_code->code() + << ", rtt=" << Elapsed(); + if (entry_) { + entry_->OnChannelBindError(response, error_code->code()); + } +} + +void TurnChannelBindRequest::OnTimeout() { + LOG_J(LS_WARNING, port_) << "TURN channel bind timeout " + << rtc::hex_encode(id()); +} + +void TurnChannelBindRequest::OnEntryDestroyed(TurnEntry* entry) { + ASSERT(entry_ == entry); + entry_ = NULL; +} + +TurnEntry::TurnEntry(TurnPort* port, int channel_id, + const rtc::SocketAddress& ext_addr) + : port_(port), + channel_id_(channel_id), + ext_addr_(ext_addr), + state_(STATE_UNBOUND) { + // Creating permission for |ext_addr_|. + SendCreatePermissionRequest(); +} + +void TurnEntry::SendCreatePermissionRequest() { + port_->SendRequest(new TurnCreatePermissionRequest( + port_, this, ext_addr_), 0); +} + +void TurnEntry::SendChannelBindRequest(int delay) { + port_->SendRequest(new TurnChannelBindRequest( + port_, this, channel_id_, ext_addr_), delay); +} + +int TurnEntry::Send(const void* data, size_t size, bool payload, + const rtc::PacketOptions& options) { + rtc::ByteBuffer buf; + if (state_ != STATE_BOUND) { + // If we haven't bound the channel yet, we have to use a Send Indication. + TurnMessage msg; + msg.SetType(TURN_SEND_INDICATION); + msg.SetTransactionID( + rtc::CreateRandomString(kStunTransactionIdLength)); + VERIFY(msg.AddAttribute(new StunXorAddressAttribute( + STUN_ATTR_XOR_PEER_ADDRESS, ext_addr_))); + VERIFY(msg.AddAttribute(new StunByteStringAttribute( + STUN_ATTR_DATA, data, size))); + VERIFY(msg.Write(&buf)); + + // If we're sending real data, request a channel bind that we can use later. + if (state_ == STATE_UNBOUND && payload) { + SendChannelBindRequest(0); + state_ = STATE_BINDING; + } + } else { + // If the channel is bound, we can send the data as a Channel Message. + buf.WriteUInt16(channel_id_); + buf.WriteUInt16(static_cast<uint16_t>(size)); + buf.WriteBytes(reinterpret_cast<const char*>(data), size); + } + return port_->Send(buf.Data(), buf.Length(), options); +} + +void TurnEntry::OnCreatePermissionSuccess() { + LOG_J(LS_INFO, port_) << "Create permission for " + << ext_addr_.ToSensitiveString() + << " succeeded"; + // For success result code will be 0. + port_->SignalCreatePermissionResult(port_, ext_addr_, 0); +} + +void TurnEntry::OnCreatePermissionError(StunMessage* response, int code) { + if (code == STUN_ERROR_STALE_NONCE) { + if (port_->UpdateNonce(response)) { + SendCreatePermissionRequest(); + } + } else { + // Send signal with error code. + port_->SignalCreatePermissionResult(port_, ext_addr_, code); + } +} + +void TurnEntry::OnChannelBindSuccess() { + LOG_J(LS_INFO, port_) << "Channel bind for " << ext_addr_.ToSensitiveString() + << " succeeded"; + ASSERT(state_ == STATE_BINDING || state_ == STATE_BOUND); + state_ = STATE_BOUND; +} + +void TurnEntry::OnChannelBindError(StunMessage* response, int code) { + // TODO(mallinath) - Implement handling of error response for channel + // bind request as per http://tools.ietf.org/html/rfc5766#section-11.3 + if (code == STUN_ERROR_STALE_NONCE) { + if (port_->UpdateNonce(response)) { + // Send channel bind request with fresh nonce. + SendChannelBindRequest(0); + } + } +} + +} // namespace cricket diff --git a/webrtc/p2p/base/turnport.h b/webrtc/p2p/base/turnport.h new file mode 100644 index 0000000000..3bca727346 --- /dev/null +++ b/webrtc/p2p/base/turnport.h @@ -0,0 +1,254 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TURNPORT_H_ +#define WEBRTC_P2P_BASE_TURNPORT_H_ + +#include <stdio.h> +#include <list> +#include <set> +#include <string> + +#include "webrtc/p2p/base/port.h" +#include "webrtc/p2p/client/basicportallocator.h" +#include "webrtc/base/asyncpacketsocket.h" + +namespace rtc { +class AsyncResolver; +class SignalThread; +} + +namespace cricket { + +extern const char TURN_PORT_TYPE[]; +class TurnAllocateRequest; +class TurnEntry; + +class TurnPort : public Port { + public: + enum PortState { + STATE_CONNECTING, // Initial state, cannot send any packets. + STATE_CONNECTED, // Socket connected, ready to send stun requests. + STATE_READY, // Received allocate success, can send any packets. + STATE_DISCONNECTED, // TCP connection died, cannot send any packets. + }; + static TurnPort* Create(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + rtc::AsyncPacketSocket* socket, + const std::string& username, // ice username. + const std::string& password, // ice password. + const ProtocolAddress& server_address, + const RelayCredentials& credentials, + int server_priority, + const std::string& origin) { + return new TurnPort(thread, factory, network, socket, username, password, + server_address, credentials, server_priority, origin); + } + + static TurnPort* Create(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, // ice username. + const std::string& password, // ice password. + const ProtocolAddress& server_address, + const RelayCredentials& credentials, + int server_priority, + const std::string& origin) { + return new TurnPort(thread, factory, network, ip, min_port, max_port, + username, password, server_address, credentials, + server_priority, origin); + } + + virtual ~TurnPort(); + + const ProtocolAddress& server_address() const { return server_address_; } + // Returns an empty address if the local address has not been assigned. + rtc::SocketAddress GetLocalAddress() const; + + bool ready() const { return state_ == STATE_READY; } + bool connected() const { + return state_ == STATE_READY || state_ == STATE_CONNECTED; + } + const RelayCredentials& credentials() const { return credentials_; } + + virtual void PrepareAddress(); + virtual Connection* CreateConnection( + const Candidate& c, PortInterface::CandidateOrigin origin); + virtual int SendTo(const void* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketOptions& options, + bool payload); + virtual int SetOption(rtc::Socket::Option opt, int value); + virtual int GetOption(rtc::Socket::Option opt, int* value); + virtual int GetError(); + + virtual bool HandleIncomingPacket( + rtc::AsyncPacketSocket* socket, const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + OnReadPacket(socket, data, size, remote_addr, packet_time); + return true; + } + virtual void OnReadPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time); + + virtual void OnReadyToSend(rtc::AsyncPacketSocket* socket); + + void OnSocketConnect(rtc::AsyncPacketSocket* socket); + void OnSocketClose(rtc::AsyncPacketSocket* socket, int error); + + + const std::string& hash() const { return hash_; } + const std::string& nonce() const { return nonce_; } + + int error() const { return error_; } + + void OnAllocateMismatch(); + + rtc::AsyncPacketSocket* socket() const { + return socket_; + } + + // Signal with resolved server address. + // Parameters are port, server address and resolved server address. + // This signal will be sent only if server address is resolved successfully. + sigslot::signal3<TurnPort*, + const rtc::SocketAddress&, + const rtc::SocketAddress&> SignalResolvedServerAddress; + + // This signal is only for testing purpose. + sigslot::signal3<TurnPort*, const rtc::SocketAddress&, int> + SignalCreatePermissionResult; + + protected: + TurnPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + rtc::AsyncPacketSocket* socket, + const std::string& username, + const std::string& password, + const ProtocolAddress& server_address, + const RelayCredentials& credentials, + int server_priority, + const std::string& origin); + + TurnPort(rtc::Thread* thread, + rtc::PacketSocketFactory* factory, + rtc::Network* network, + const rtc::IPAddress& ip, + uint16_t min_port, + uint16_t max_port, + const std::string& username, + const std::string& password, + const ProtocolAddress& server_address, + const RelayCredentials& credentials, + int server_priority, + const std::string& origin); + + private: + enum { + MSG_ERROR = MSG_FIRST_AVAILABLE, + MSG_ALLOCATE_MISMATCH, + MSG_TRY_ALTERNATE_SERVER + }; + + typedef std::list<TurnEntry*> EntryList; + typedef std::map<rtc::Socket::Option, int> SocketOptionsMap; + typedef std::set<rtc::SocketAddress> AttemptedServerSet; + + virtual void OnMessage(rtc::Message* pmsg); + + bool CreateTurnClientSocket(); + + void set_nonce(const std::string& nonce) { nonce_ = nonce; } + void set_realm(const std::string& realm) { + if (realm != realm_) { + realm_ = realm; + UpdateHash(); + } + } + + bool SetAlternateServer(const rtc::SocketAddress& address); + void ResolveTurnAddress(const rtc::SocketAddress& address); + void OnResolveResult(rtc::AsyncResolverInterface* resolver); + + void AddRequestAuthInfo(StunMessage* msg); + void OnSendStunPacket(const void* data, size_t size, StunRequest* request); + // Stun address from allocate success response. + // Currently used only for testing. + void OnStunAddress(const rtc::SocketAddress& address); + void OnAllocateSuccess(const rtc::SocketAddress& address, + const rtc::SocketAddress& stun_address); + void OnAllocateError(); + void OnAllocateRequestTimeout(); + + void HandleDataIndication(const char* data, size_t size, + const rtc::PacketTime& packet_time); + void HandleChannelData(int channel_id, const char* data, size_t size, + const rtc::PacketTime& packet_time); + void DispatchPacket(const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + ProtocolType proto, const rtc::PacketTime& packet_time); + + bool ScheduleRefresh(int lifetime); + void SendRequest(StunRequest* request, int delay); + int Send(const void* data, size_t size, + const rtc::PacketOptions& options); + void UpdateHash(); + bool UpdateNonce(StunMessage* response); + + bool HasPermission(const rtc::IPAddress& ipaddr) const; + TurnEntry* FindEntry(const rtc::SocketAddress& address) const; + TurnEntry* FindEntry(int channel_id) const; + TurnEntry* CreateEntry(const rtc::SocketAddress& address); + void DestroyEntry(const rtc::SocketAddress& address); + void OnConnectionDestroyed(Connection* conn); + + ProtocolAddress server_address_; + RelayCredentials credentials_; + AttemptedServerSet attempted_server_addresses_; + + rtc::AsyncPacketSocket* socket_; + SocketOptionsMap socket_options_; + rtc::AsyncResolverInterface* resolver_; + int error_; + + StunRequestManager request_manager_; + std::string realm_; // From 401/438 response message. + std::string nonce_; // From 401/438 response message. + std::string hash_; // Digest of username:realm:password + + int next_channel_number_; + EntryList entries_; + + PortState state_; + // By default the value will be set to 0. This value will be used in + // calculating the candidate priority. + int server_priority_; + + // The number of retries made due to allocate mismatch error. + size_t allocate_mismatch_retries_; + + friend class TurnEntry; + friend class TurnAllocateRequest; + friend class TurnRefreshRequest; + friend class TurnCreatePermissionRequest; + friend class TurnChannelBindRequest; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TURNPORT_H_ diff --git a/webrtc/p2p/base/turnport_unittest.cc b/webrtc/p2p/base/turnport_unittest.cc new file mode 100644 index 0000000000..724485ddde --- /dev/null +++ b/webrtc/p2p/base/turnport_unittest.cc @@ -0,0 +1,816 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#if defined(WEBRTC_POSIX) +#include <dirent.h> +#endif + +#include "webrtc/p2p/base/basicpacketsocketfactory.h" +#include "webrtc/p2p/base/constants.h" +#include "webrtc/p2p/base/tcpport.h" +#include "webrtc/p2p/base/testturnserver.h" +#include "webrtc/p2p/base/turnport.h" +#include "webrtc/p2p/base/udpport.h" +#include "webrtc/base/asynctcpsocket.h" +#include "webrtc/base/buffer.h" +#include "webrtc/base/dscp.h" +#include "webrtc/base/firewallsocketserver.h" +#include "webrtc/base/gunit.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/physicalsocketserver.h" +#include "webrtc/base/scoped_ptr.h" +#include "webrtc/base/socketaddress.h" +#include "webrtc/base/ssladapter.h" +#include "webrtc/base/thread.h" +#include "webrtc/base/virtualsocketserver.h" + +using rtc::SocketAddress; +using cricket::Connection; +using cricket::Port; +using cricket::PortInterface; +using cricket::TurnPort; +using cricket::UDPPort; + +static const SocketAddress kLocalAddr1("11.11.11.11", 0); +static const SocketAddress kLocalAddr2("22.22.22.22", 0); +static const SocketAddress kLocalIPv6Addr( + "2401:fa00:4:1000:be30:5bff:fee5:c3", 0); +static const SocketAddress kTurnUdpIntAddr("99.99.99.3", + cricket::TURN_SERVER_PORT); +static const SocketAddress kTurnTcpIntAddr("99.99.99.4", + cricket::TURN_SERVER_PORT); +static const SocketAddress kTurnUdpExtAddr("99.99.99.5", 0); +static const SocketAddress kTurnAlternateIntAddr("99.99.99.6", + cricket::TURN_SERVER_PORT); +static const SocketAddress kTurnIntAddr("99.99.99.7", + cricket::TURN_SERVER_PORT); +static const SocketAddress kTurnIPv6IntAddr( + "2400:4030:2:2c00:be30:abcd:efab:cdef", + cricket::TURN_SERVER_PORT); +static const SocketAddress kTurnUdpIPv6IntAddr( + "2400:4030:1:2c00:be30:abcd:efab:cdef", cricket::TURN_SERVER_PORT); +static const SocketAddress kTurnUdpIPv6ExtAddr( + "2620:0:1000:1b03:2e41:38ff:fea6:f2a4", 0); + +static const char kIceUfrag1[] = "TESTICEUFRAG0001"; +static const char kIceUfrag2[] = "TESTICEUFRAG0002"; +static const char kIcePwd1[] = "TESTICEPWD00000000000001"; +static const char kIcePwd2[] = "TESTICEPWD00000000000002"; +static const char kTurnUsername[] = "test"; +static const char kTurnPassword[] = "test"; +static const char kTestOrigin[] = "http://example.com"; +static const unsigned int kTimeout = 1000; + +static const cricket::ProtocolAddress kTurnUdpProtoAddr( + kTurnUdpIntAddr, cricket::PROTO_UDP); +static const cricket::ProtocolAddress kTurnTcpProtoAddr( + kTurnTcpIntAddr, cricket::PROTO_TCP); +static const cricket::ProtocolAddress kTurnUdpIPv6ProtoAddr( + kTurnUdpIPv6IntAddr, cricket::PROTO_UDP); + +static const unsigned int MSG_TESTFINISH = 0; + +#if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID) +static int GetFDCount() { + struct dirent *dp; + int fd_count = 0; + DIR *dir = opendir("/proc/self/fd/"); + while ((dp = readdir(dir)) != NULL) { + if (dp->d_name[0] == '.') + continue; + ++fd_count; + } + closedir(dir); + return fd_count; +} +#endif + +class TurnPortTestVirtualSocketServer : public rtc::VirtualSocketServer { + public: + explicit TurnPortTestVirtualSocketServer(SocketServer* ss) + : VirtualSocketServer(ss) {} + + using rtc::VirtualSocketServer::LookupBinding; +}; + +class TurnPortTest : public testing::Test, + public sigslot::has_slots<>, + public rtc::MessageHandler { + public: + TurnPortTest() + : main_(rtc::Thread::Current()), + pss_(new rtc::PhysicalSocketServer), + ss_(new TurnPortTestVirtualSocketServer(pss_.get())), + ss_scope_(ss_.get()), + network_("unittest", "unittest", rtc::IPAddress(INADDR_ANY), 32), + socket_factory_(rtc::Thread::Current()), + turn_server_(main_, kTurnUdpIntAddr, kTurnUdpExtAddr), + turn_ready_(false), + turn_error_(false), + turn_unknown_address_(false), + turn_create_permission_success_(false), + udp_ready_(false), + test_finish_(false) { + network_.AddIP(rtc::IPAddress(INADDR_ANY)); + } + + virtual void OnMessage(rtc::Message* msg) { + ASSERT(msg->message_id == MSG_TESTFINISH); + if (msg->message_id == MSG_TESTFINISH) + test_finish_ = true; + } + + void ConnectSignalAddressReadyToSetLocalhostAsAltenertativeLocalAddress() { + rtc::AsyncPacketSocket* socket = turn_port_->socket(); + rtc::VirtualSocket* virtual_socket = + ss_->LookupBinding(socket->GetLocalAddress()); + virtual_socket->SignalAddressReady.connect( + this, &TurnPortTest::SetLocalhostAsAltenertativeLocalAddress); + } + + void SetLocalhostAsAltenertativeLocalAddress( + rtc::VirtualSocket* socket, + const rtc::SocketAddress& address) { + SocketAddress local_address("127.0.0.1", 2000); + socket->SetAlternativeLocalAddress(local_address); + } + + void OnTurnPortComplete(Port* port) { + turn_ready_ = true; + } + void OnTurnPortError(Port* port) { + turn_error_ = true; + } + void OnTurnUnknownAddress(PortInterface* port, const SocketAddress& addr, + cricket::ProtocolType proto, + cricket::IceMessage* msg, const std::string& rf, + bool /*port_muxed*/) { + turn_unknown_address_ = true; + } + void OnTurnCreatePermissionResult(TurnPort* port, const SocketAddress& addr, + int code) { + // Ignoring the address. + if (code == 0) { + turn_create_permission_success_ = true; + } + } + void OnTurnReadPacket(Connection* conn, const char* data, size_t size, + const rtc::PacketTime& packet_time) { + turn_packets_.push_back(rtc::Buffer(data, size)); + } + void OnUdpPortComplete(Port* port) { + udp_ready_ = true; + } + void OnUdpReadPacket(Connection* conn, const char* data, size_t size, + const rtc::PacketTime& packet_time) { + udp_packets_.push_back(rtc::Buffer(data, size)); + } + void OnSocketReadPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + turn_port_->HandleIncomingPacket(socket, data, size, remote_addr, + packet_time); + } + rtc::AsyncSocket* CreateServerSocket(const SocketAddress addr) { + rtc::AsyncSocket* socket = ss_->CreateAsyncSocket(SOCK_STREAM); + EXPECT_GE(socket->Bind(addr), 0); + EXPECT_GE(socket->Listen(5), 0); + return socket; + } + + void CreateTurnPort(const std::string& username, + const std::string& password, + const cricket::ProtocolAddress& server_address) { + CreateTurnPort(kLocalAddr1, username, password, server_address); + } + void CreateTurnPort(const rtc::SocketAddress& local_address, + const std::string& username, + const std::string& password, + const cricket::ProtocolAddress& server_address) { + cricket::RelayCredentials credentials(username, password); + turn_port_.reset(TurnPort::Create(main_, &socket_factory_, &network_, + local_address.ipaddr(), 0, 0, + kIceUfrag1, kIcePwd1, + server_address, credentials, 0, + std::string())); + // This TURN port will be the controlling. + turn_port_->SetIceRole(cricket::ICEROLE_CONTROLLING); + ConnectSignals(); + } + + // Should be identical to CreateTurnPort but specifies an origin value + // when creating the instance of TurnPort. + void CreateTurnPortWithOrigin(const rtc::SocketAddress& local_address, + const std::string& username, + const std::string& password, + const cricket::ProtocolAddress& server_address, + const std::string& origin) { + cricket::RelayCredentials credentials(username, password); + turn_port_.reset(TurnPort::Create(main_, &socket_factory_, &network_, + local_address.ipaddr(), 0, 0, + kIceUfrag1, kIcePwd1, + server_address, credentials, 0, + origin)); + // This TURN port will be the controlling. + turn_port_->SetIceRole(cricket::ICEROLE_CONTROLLING); + ConnectSignals(); + } + + void CreateSharedTurnPort(const std::string& username, + const std::string& password, + const cricket::ProtocolAddress& server_address) { + ASSERT(server_address.proto == cricket::PROTO_UDP); + + if (!socket_) { + socket_.reset(socket_factory_.CreateUdpSocket( + rtc::SocketAddress(kLocalAddr1.ipaddr(), 0), 0, 0)); + ASSERT_TRUE(socket_ != NULL); + socket_->SignalReadPacket.connect( + this, &TurnPortTest::OnSocketReadPacket); + } + + cricket::RelayCredentials credentials(username, password); + turn_port_.reset(cricket::TurnPort::Create( + main_, &socket_factory_, &network_, socket_.get(), + kIceUfrag1, kIcePwd1, server_address, credentials, 0, std::string())); + // This TURN port will be the controlling. + turn_port_->SetIceRole(cricket::ICEROLE_CONTROLLING); + ConnectSignals(); + } + + void ConnectSignals() { + turn_port_->SignalPortComplete.connect(this, + &TurnPortTest::OnTurnPortComplete); + turn_port_->SignalPortError.connect(this, + &TurnPortTest::OnTurnPortError); + turn_port_->SignalUnknownAddress.connect(this, + &TurnPortTest::OnTurnUnknownAddress); + turn_port_->SignalCreatePermissionResult.connect(this, + &TurnPortTest::OnTurnCreatePermissionResult); + } + void CreateUdpPort() { + udp_port_.reset(UDPPort::Create(main_, &socket_factory_, &network_, + kLocalAddr2.ipaddr(), 0, 0, + kIceUfrag2, kIcePwd2, + std::string(), false)); + // UDP port will be controlled. + udp_port_->SetIceRole(cricket::ICEROLE_CONTROLLED); + udp_port_->SignalPortComplete.connect( + this, &TurnPortTest::OnUdpPortComplete); + } + + void TestTurnAlternateServer(cricket::ProtocolType protocol_type) { + std::vector<rtc::SocketAddress> redirect_addresses; + redirect_addresses.push_back(kTurnAlternateIntAddr); + + cricket::TestTurnRedirector redirector(redirect_addresses); + + turn_server_.AddInternalSocket(kTurnIntAddr, protocol_type); + turn_server_.AddInternalSocket(kTurnAlternateIntAddr, protocol_type); + turn_server_.set_redirect_hook(&redirector); + CreateTurnPort(kTurnUsername, kTurnPassword, + cricket::ProtocolAddress(kTurnIntAddr, protocol_type)); + + // Retrieve the address before we run the state machine. + const SocketAddress old_addr = turn_port_->server_address().address; + + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout * 100); + // Retrieve the address again, the turn port's address should be + // changed. + const SocketAddress new_addr = turn_port_->server_address().address; + EXPECT_NE(old_addr, new_addr); + ASSERT_EQ(1U, turn_port_->Candidates().size()); + EXPECT_EQ(kTurnUdpExtAddr.ipaddr(), + turn_port_->Candidates()[0].address().ipaddr()); + EXPECT_NE(0, turn_port_->Candidates()[0].address().port()); + } + + void TestTurnAlternateServerV4toV6(cricket::ProtocolType protocol_type) { + std::vector<rtc::SocketAddress> redirect_addresses; + redirect_addresses.push_back(kTurnIPv6IntAddr); + + cricket::TestTurnRedirector redirector(redirect_addresses); + turn_server_.AddInternalSocket(kTurnIntAddr, protocol_type); + turn_server_.set_redirect_hook(&redirector); + CreateTurnPort(kTurnUsername, kTurnPassword, + cricket::ProtocolAddress(kTurnIntAddr, protocol_type)); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_error_, kTimeout); + } + + void TestTurnAlternateServerPingPong(cricket::ProtocolType protocol_type) { + std::vector<rtc::SocketAddress> redirect_addresses; + redirect_addresses.push_back(kTurnAlternateIntAddr); + redirect_addresses.push_back(kTurnIntAddr); + + cricket::TestTurnRedirector redirector(redirect_addresses); + + turn_server_.AddInternalSocket(kTurnIntAddr, protocol_type); + turn_server_.AddInternalSocket(kTurnAlternateIntAddr, protocol_type); + turn_server_.set_redirect_hook(&redirector); + CreateTurnPort(kTurnUsername, kTurnPassword, + cricket::ProtocolAddress(kTurnIntAddr, protocol_type)); + + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_error_, kTimeout); + ASSERT_EQ(0U, turn_port_->Candidates().size()); + rtc::SocketAddress address; + // Verify that we have exhausted all alternate servers instead of + // failure caused by other errors. + EXPECT_FALSE(redirector.ShouldRedirect(address, &address)); + } + + void TestTurnAlternateServerDetectRepetition( + cricket::ProtocolType protocol_type) { + std::vector<rtc::SocketAddress> redirect_addresses; + redirect_addresses.push_back(kTurnAlternateIntAddr); + redirect_addresses.push_back(kTurnAlternateIntAddr); + + cricket::TestTurnRedirector redirector(redirect_addresses); + + turn_server_.AddInternalSocket(kTurnIntAddr, protocol_type); + turn_server_.AddInternalSocket(kTurnAlternateIntAddr, protocol_type); + turn_server_.set_redirect_hook(&redirector); + CreateTurnPort(kTurnUsername, kTurnPassword, + cricket::ProtocolAddress(kTurnIntAddr, protocol_type)); + + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_error_, kTimeout); + ASSERT_EQ(0U, turn_port_->Candidates().size()); + } + + void TestTurnConnection() { + // Create ports and prepare addresses. + ASSERT_TRUE(turn_port_ != NULL); + turn_port_->PrepareAddress(); + ASSERT_TRUE_WAIT(turn_ready_, kTimeout); + CreateUdpPort(); + udp_port_->PrepareAddress(); + ASSERT_TRUE_WAIT(udp_ready_, kTimeout); + + // Send ping from UDP to TURN. + Connection* conn1 = udp_port_->CreateConnection( + turn_port_->Candidates()[0], Port::ORIGIN_MESSAGE); + ASSERT_TRUE(conn1 != NULL); + conn1->Ping(0); + WAIT(!turn_unknown_address_, kTimeout); + EXPECT_FALSE(turn_unknown_address_); + EXPECT_FALSE(conn1->receiving()); + EXPECT_EQ(Connection::STATE_WRITE_INIT, conn1->write_state()); + + // Send ping from TURN to UDP. + Connection* conn2 = turn_port_->CreateConnection( + udp_port_->Candidates()[0], Port::ORIGIN_MESSAGE); + ASSERT_TRUE(conn2 != NULL); + ASSERT_TRUE_WAIT(turn_create_permission_success_, kTimeout); + conn2->Ping(0); + + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, conn2->write_state(), kTimeout); + EXPECT_TRUE(conn1->receiving()); + EXPECT_TRUE(conn2->receiving()); + EXPECT_EQ(Connection::STATE_WRITE_INIT, conn1->write_state()); + + // Send another ping from UDP to TURN. + conn1->Ping(0); + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, conn1->write_state(), kTimeout); + EXPECT_TRUE(conn2->receiving()); + } + + void TestTurnSendData() { + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + CreateUdpPort(); + udp_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(udp_ready_, kTimeout); + // Create connections and send pings. + Connection* conn1 = turn_port_->CreateConnection( + udp_port_->Candidates()[0], Port::ORIGIN_MESSAGE); + Connection* conn2 = udp_port_->CreateConnection( + turn_port_->Candidates()[0], Port::ORIGIN_MESSAGE); + ASSERT_TRUE(conn1 != NULL); + ASSERT_TRUE(conn2 != NULL); + conn1->SignalReadPacket.connect(static_cast<TurnPortTest*>(this), + &TurnPortTest::OnTurnReadPacket); + conn2->SignalReadPacket.connect(static_cast<TurnPortTest*>(this), + &TurnPortTest::OnUdpReadPacket); + conn1->Ping(0); + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, conn1->write_state(), kTimeout); + conn2->Ping(0); + EXPECT_EQ_WAIT(Connection::STATE_WRITABLE, conn2->write_state(), kTimeout); + + // Send some data. + size_t num_packets = 256; + for (size_t i = 0; i < num_packets; ++i) { + unsigned char buf[256] = { 0 }; + for (size_t j = 0; j < i + 1; ++j) { + buf[j] = 0xFF - static_cast<unsigned char>(j); + } + conn1->Send(buf, i + 1, options); + conn2->Send(buf, i + 1, options); + main_->ProcessMessages(0); + } + + // Check the data. + ASSERT_EQ_WAIT(num_packets, turn_packets_.size(), kTimeout); + ASSERT_EQ_WAIT(num_packets, udp_packets_.size(), kTimeout); + for (size_t i = 0; i < num_packets; ++i) { + EXPECT_EQ(i + 1, turn_packets_[i].size()); + EXPECT_EQ(i + 1, udp_packets_[i].size()); + EXPECT_EQ(turn_packets_[i], udp_packets_[i]); + } + } + + protected: + rtc::Thread* main_; + rtc::scoped_ptr<rtc::PhysicalSocketServer> pss_; + rtc::scoped_ptr<TurnPortTestVirtualSocketServer> ss_; + rtc::SocketServerScope ss_scope_; + rtc::Network network_; + rtc::BasicPacketSocketFactory socket_factory_; + rtc::scoped_ptr<rtc::AsyncPacketSocket> socket_; + cricket::TestTurnServer turn_server_; + rtc::scoped_ptr<TurnPort> turn_port_; + rtc::scoped_ptr<UDPPort> udp_port_; + bool turn_ready_; + bool turn_error_; + bool turn_unknown_address_; + bool turn_create_permission_success_; + bool udp_ready_; + bool test_finish_; + std::vector<rtc::Buffer> turn_packets_; + std::vector<rtc::Buffer> udp_packets_; + rtc::PacketOptions options; +}; + +// Do a normal TURN allocation. +TEST_F(TurnPortTest, TestTurnAllocate) { + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr); + EXPECT_EQ(0, turn_port_->SetOption(rtc::Socket::OPT_SNDBUF, 10*1024)); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + ASSERT_EQ(1U, turn_port_->Candidates().size()); + EXPECT_EQ(kTurnUdpExtAddr.ipaddr(), + turn_port_->Candidates()[0].address().ipaddr()); + EXPECT_NE(0, turn_port_->Candidates()[0].address().port()); +} + +// Testing a normal UDP allocation using TCP connection. +TEST_F(TurnPortTest, TestTurnTcpAllocate) { + turn_server_.AddInternalSocket(kTurnTcpIntAddr, cricket::PROTO_TCP); + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTcpProtoAddr); + EXPECT_EQ(0, turn_port_->SetOption(rtc::Socket::OPT_SNDBUF, 10*1024)); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + ASSERT_EQ(1U, turn_port_->Candidates().size()); + EXPECT_EQ(kTurnUdpExtAddr.ipaddr(), + turn_port_->Candidates()[0].address().ipaddr()); + EXPECT_NE(0, turn_port_->Candidates()[0].address().port()); +} + +// Test case for WebRTC issue 3927 where a proxy binds to the local host address +// instead the address that TurnPort originally bound to. The candidate pair +// impacted by this behavior should still be used. +TEST_F(TurnPortTest, TestTurnTcpAllocationWhenProxyChangesAddressToLocalHost) { + turn_server_.AddInternalSocket(kTurnTcpIntAddr, cricket::PROTO_TCP); + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTcpProtoAddr); + EXPECT_EQ(0, turn_port_->SetOption(rtc::Socket::OPT_SNDBUF, 10 * 1024)); + turn_port_->PrepareAddress(); + ConnectSignalAddressReadyToSetLocalhostAsAltenertativeLocalAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + ASSERT_EQ(1U, turn_port_->Candidates().size()); + EXPECT_EQ(kTurnUdpExtAddr.ipaddr(), + turn_port_->Candidates()[0].address().ipaddr()); + EXPECT_NE(0, turn_port_->Candidates()[0].address().port()); +} + +// Testing turn port will attempt to create TCP socket on address resolution +// failure. +TEST_F(TurnPortTest, DISABLED_TestTurnTcpOnAddressResolveFailure) { + turn_server_.AddInternalSocket(kTurnTcpIntAddr, cricket::PROTO_TCP); + CreateTurnPort(kTurnUsername, kTurnPassword, cricket::ProtocolAddress( + rtc::SocketAddress("www.webrtc-blah-blah.com", 3478), + cricket::PROTO_TCP)); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_error_, kTimeout); + // As VSS doesn't provide a DNS resolution, name resolve will fail. TurnPort + // will proceed in creating a TCP socket which will fail as there is no + // server on the above domain and error will be set to SOCKET_ERROR. + EXPECT_EQ(SOCKET_ERROR, turn_port_->error()); +} + +// In case of UDP on address resolve failure, TurnPort will not create socket +// and return allocate failure. +TEST_F(TurnPortTest, DISABLED_TestTurnUdpOnAdressResolveFailure) { + CreateTurnPort(kTurnUsername, kTurnPassword, cricket::ProtocolAddress( + rtc::SocketAddress("www.webrtc-blah-blah.com", 3478), + cricket::PROTO_UDP)); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_error_, kTimeout); + // Error from turn port will not be socket error. + EXPECT_NE(SOCKET_ERROR, turn_port_->error()); +} + +// Try to do a TURN allocation with an invalid password. +TEST_F(TurnPortTest, TestTurnAllocateBadPassword) { + CreateTurnPort(kTurnUsername, "bad", kTurnUdpProtoAddr); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_error_, kTimeout); + ASSERT_EQ(0U, turn_port_->Candidates().size()); +} + +// Tests that a new local address is created after +// STUN_ERROR_ALLOCATION_MISMATCH. +TEST_F(TurnPortTest, TestTurnAllocateMismatch) { + // Do a normal allocation first. + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + rtc::SocketAddress first_addr(turn_port_->socket()->GetLocalAddress()); + + // Clear connected_ flag on turnport to suppress the release of + // the allocation. + turn_port_->OnSocketClose(turn_port_->socket(), 0); + + // Forces the socket server to assign the same port. + ss_->SetNextPortForTesting(first_addr.port()); + + turn_ready_ = false; + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr); + turn_port_->PrepareAddress(); + + // Verifies that the new port has the same address. + EXPECT_EQ(first_addr, turn_port_->socket()->GetLocalAddress()); + + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + + // Verifies that the new port has a different address now. + EXPECT_NE(first_addr, turn_port_->socket()->GetLocalAddress()); +} + +// Tests that a shared-socket-TurnPort creates its own socket after +// STUN_ERROR_ALLOCATION_MISMATCH. +TEST_F(TurnPortTest, TestSharedSocketAllocateMismatch) { + // Do a normal allocation first. + CreateSharedTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + rtc::SocketAddress first_addr(turn_port_->socket()->GetLocalAddress()); + + // Clear connected_ flag on turnport to suppress the release of + // the allocation. + turn_port_->OnSocketClose(turn_port_->socket(), 0); + + turn_ready_ = false; + CreateSharedTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr); + + // Verifies that the new port has the same address. + EXPECT_EQ(first_addr, turn_port_->socket()->GetLocalAddress()); + EXPECT_TRUE(turn_port_->SharedSocket()); + + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + + // Verifies that the new port has a different address now. + EXPECT_NE(first_addr, turn_port_->socket()->GetLocalAddress()); + EXPECT_FALSE(turn_port_->SharedSocket()); +} + +TEST_F(TurnPortTest, TestTurnTcpAllocateMismatch) { + turn_server_.AddInternalSocket(kTurnTcpIntAddr, cricket::PROTO_TCP); + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTcpProtoAddr); + + // Do a normal allocation first. + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + rtc::SocketAddress first_addr(turn_port_->socket()->GetLocalAddress()); + + // Clear connected_ flag on turnport to suppress the release of + // the allocation. + turn_port_->OnSocketClose(turn_port_->socket(), 0); + + // Forces the socket server to assign the same port. + ss_->SetNextPortForTesting(first_addr.port()); + + turn_ready_ = false; + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTcpProtoAddr); + turn_port_->PrepareAddress(); + + // Verifies that the new port has the same address. + EXPECT_EQ(first_addr, turn_port_->socket()->GetLocalAddress()); + + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + + // Verifies that the new port has a different address now. + EXPECT_NE(first_addr, turn_port_->socket()->GetLocalAddress()); +} + +// Test that CreateConnection will return null if port becomes disconnected. +TEST_F(TurnPortTest, TestCreateConnectionWhenSocketClosed) { + turn_server_.AddInternalSocket(kTurnTcpIntAddr, cricket::PROTO_TCP); + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTcpProtoAddr); + turn_port_->PrepareAddress(); + ASSERT_TRUE_WAIT(turn_ready_, kTimeout); + + CreateUdpPort(); + udp_port_->PrepareAddress(); + ASSERT_TRUE_WAIT(udp_ready_, kTimeout); + // Create a connection. + Connection* conn1 = turn_port_->CreateConnection(udp_port_->Candidates()[0], + Port::ORIGIN_MESSAGE); + ASSERT_TRUE(conn1 != NULL); + + // Close the socket and create a connection again. + turn_port_->OnSocketClose(turn_port_->socket(), 1); + conn1 = turn_port_->CreateConnection(udp_port_->Candidates()[0], + Port::ORIGIN_MESSAGE); + ASSERT_TRUE(conn1 == NULL); +} + +// Test try-alternate-server feature. +TEST_F(TurnPortTest, TestTurnAlternateServerUDP) { + TestTurnAlternateServer(cricket::PROTO_UDP); +} + +TEST_F(TurnPortTest, TestTurnAlternateServerTCP) { + TestTurnAlternateServer(cricket::PROTO_TCP); +} + +// Test that we fail when we redirect to an address different from +// current IP family. +TEST_F(TurnPortTest, TestTurnAlternateServerV4toV6UDP) { + TestTurnAlternateServerV4toV6(cricket::PROTO_UDP); +} + +TEST_F(TurnPortTest, TestTurnAlternateServerV4toV6TCP) { + TestTurnAlternateServerV4toV6(cricket::PROTO_TCP); +} + +// Test try-alternate-server catches the case of pingpong. +TEST_F(TurnPortTest, TestTurnAlternateServerPingPongUDP) { + TestTurnAlternateServerPingPong(cricket::PROTO_UDP); +} + +TEST_F(TurnPortTest, TestTurnAlternateServerPingPongTCP) { + TestTurnAlternateServerPingPong(cricket::PROTO_TCP); +} + +// Test try-alternate-server catch the case of repeated server. +TEST_F(TurnPortTest, TestTurnAlternateServerDetectRepetitionUDP) { + TestTurnAlternateServerDetectRepetition(cricket::PROTO_UDP); +} + +TEST_F(TurnPortTest, TestTurnAlternateServerDetectRepetitionTCP) { + TestTurnAlternateServerDetectRepetition(cricket::PROTO_TCP); +} + +// Do a TURN allocation and try to send a packet to it from the outside. +// The packet should be dropped. Then, try to send a packet from TURN to the +// outside. It should reach its destination. Finally, try again from the +// outside. It should now work as well. +TEST_F(TurnPortTest, TestTurnConnection) { + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr); + TestTurnConnection(); +} + +// Similar to above, except that this test will use the shared socket. +TEST_F(TurnPortTest, TestTurnConnectionUsingSharedSocket) { + CreateSharedTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr); + TestTurnConnection(); +} + +// Test that we can establish a TCP connection with TURN server. +TEST_F(TurnPortTest, TestTurnTcpConnection) { + turn_server_.AddInternalSocket(kTurnTcpIntAddr, cricket::PROTO_TCP); + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTcpProtoAddr); + TestTurnConnection(); +} + +// Test that we fail to create a connection when we want to use TLS over TCP. +// This test should be removed once we have TLS support. +TEST_F(TurnPortTest, TestTurnTlsTcpConnectionFails) { + cricket::ProtocolAddress secure_addr(kTurnTcpProtoAddr.address, + kTurnTcpProtoAddr.proto, + true); + CreateTurnPort(kTurnUsername, kTurnPassword, secure_addr); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_error_, kTimeout); + ASSERT_EQ(0U, turn_port_->Candidates().size()); +} + +// Run TurnConnectionTest with one-time-use nonce feature. +// Here server will send a 438 STALE_NONCE error message for +// every TURN transaction. +TEST_F(TurnPortTest, TestTurnConnectionUsingOTUNonce) { + turn_server_.set_enable_otu_nonce(true); + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr); + TestTurnConnection(); +} + +// Do a TURN allocation, establish a UDP connection, and send some data. +TEST_F(TurnPortTest, TestTurnSendDataTurnUdpToUdp) { + // Create ports and prepare addresses. + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr); + TestTurnSendData(); + EXPECT_EQ(cricket::UDP_PROTOCOL_NAME, + turn_port_->Candidates()[0].relay_protocol()); +} + +// Do a TURN allocation, establish a TCP connection, and send some data. +TEST_F(TurnPortTest, TestTurnSendDataTurnTcpToUdp) { + turn_server_.AddInternalSocket(kTurnTcpIntAddr, cricket::PROTO_TCP); + // Create ports and prepare addresses. + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTcpProtoAddr); + TestTurnSendData(); + EXPECT_EQ(cricket::TCP_PROTOCOL_NAME, + turn_port_->Candidates()[0].relay_protocol()); +} + +// Test TURN fails to make a connection from IPv6 address to a server which has +// IPv4 address. +TEST_F(TurnPortTest, TestTurnLocalIPv6AddressServerIPv4) { + turn_server_.AddInternalSocket(kTurnUdpIPv6IntAddr, cricket::PROTO_UDP); + CreateTurnPort(kLocalIPv6Addr, kTurnUsername, kTurnPassword, + kTurnUdpProtoAddr); + turn_port_->PrepareAddress(); + ASSERT_TRUE_WAIT(turn_error_, kTimeout); + EXPECT_TRUE(turn_port_->Candidates().empty()); +} + +// Test TURN make a connection from IPv6 address to a server which has +// IPv6 intenal address. But in this test external address is a IPv4 address, +// hence allocated address will be a IPv4 address. +TEST_F(TurnPortTest, TestTurnLocalIPv6AddressServerIPv6ExtenalIPv4) { + turn_server_.AddInternalSocket(kTurnUdpIPv6IntAddr, cricket::PROTO_UDP); + CreateTurnPort(kLocalIPv6Addr, kTurnUsername, kTurnPassword, + kTurnUdpIPv6ProtoAddr); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + ASSERT_EQ(1U, turn_port_->Candidates().size()); + EXPECT_EQ(kTurnUdpExtAddr.ipaddr(), + turn_port_->Candidates()[0].address().ipaddr()); + EXPECT_NE(0, turn_port_->Candidates()[0].address().port()); +} + +TEST_F(TurnPortTest, TestOriginHeader) { + CreateTurnPortWithOrigin(kLocalAddr1, kTurnUsername, kTurnPassword, + kTurnUdpProtoAddr, kTestOrigin); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + ASSERT_GT(turn_server_.server()->allocations().size(), 0U); + SocketAddress local_address = turn_port_->GetLocalAddress(); + ASSERT_TRUE(turn_server_.FindAllocation(local_address) != NULL); + EXPECT_EQ(kTestOrigin, turn_server_.FindAllocation(local_address)->origin()); +} + +// Test that a TURN allocation is released when the port is closed. +TEST_F(TurnPortTest, TestTurnReleaseAllocation) { + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + + ASSERT_GT(turn_server_.server()->allocations().size(), 0U); + turn_port_.reset(); + EXPECT_EQ_WAIT(0U, turn_server_.server()->allocations().size(), kTimeout); +} + +// Test that a TURN TCP allocation is released when the port is closed. +TEST_F(TurnPortTest, DISABLED_TestTurnTCPReleaseAllocation) { + turn_server_.AddInternalSocket(kTurnTcpIntAddr, cricket::PROTO_TCP); + CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTcpProtoAddr); + turn_port_->PrepareAddress(); + EXPECT_TRUE_WAIT(turn_ready_, kTimeout); + + ASSERT_GT(turn_server_.server()->allocations().size(), 0U); + turn_port_.reset(); + EXPECT_EQ_WAIT(0U, turn_server_.server()->allocations().size(), kTimeout); +} + +// This test verifies any FD's are not leaked after TurnPort is destroyed. +// https://code.google.com/p/webrtc/issues/detail?id=2651 +#if defined(WEBRTC_LINUX) && !defined(WEBRTC_ANDROID) +TEST_F(TurnPortTest, TestResolverShutdown) { + turn_server_.AddInternalSocket(kTurnUdpIPv6IntAddr, cricket::PROTO_UDP); + int last_fd_count = GetFDCount(); + // Need to supply unresolved address to kick off resolver. + CreateTurnPort(kLocalIPv6Addr, kTurnUsername, kTurnPassword, + cricket::ProtocolAddress(rtc::SocketAddress( + "www.google.invalid", 3478), cricket::PROTO_UDP)); + turn_port_->PrepareAddress(); + ASSERT_TRUE_WAIT(turn_error_, kTimeout); + EXPECT_TRUE(turn_port_->Candidates().empty()); + turn_port_.reset(); + rtc::Thread::Current()->Post(this, MSG_TESTFINISH); + // Waiting for above message to be processed. + ASSERT_TRUE_WAIT(test_finish_, kTimeout); + EXPECT_EQ(last_fd_count, GetFDCount()); +} +#endif diff --git a/webrtc/p2p/base/turnserver.cc b/webrtc/p2p/base/turnserver.cc new file mode 100644 index 0000000000..8d40a9030c --- /dev/null +++ b/webrtc/p2p/base/turnserver.cc @@ -0,0 +1,945 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/p2p/base/turnserver.h" + +#include "webrtc/p2p/base/asyncstuntcpsocket.h" +#include "webrtc/p2p/base/common.h" +#include "webrtc/p2p/base/packetsocketfactory.h" +#include "webrtc/p2p/base/stun.h" +#include "webrtc/base/bytebuffer.h" +#include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" +#include "webrtc/base/messagedigest.h" +#include "webrtc/base/socketadapters.h" +#include "webrtc/base/stringencode.h" +#include "webrtc/base/thread.h" + +namespace cricket { + +// TODO(juberti): Move this all to a future turnmessage.h +//static const int IPPROTO_UDP = 17; +static const int kNonceTimeout = 60 * 60 * 1000; // 60 minutes +static const int kDefaultAllocationTimeout = 10 * 60 * 1000; // 10 minutes +static const int kPermissionTimeout = 5 * 60 * 1000; // 5 minutes +static const int kChannelTimeout = 10 * 60 * 1000; // 10 minutes + +static const int kMinChannelNumber = 0x4000; +static const int kMaxChannelNumber = 0x7FFF; + +static const size_t kNonceKeySize = 16; +static const size_t kNonceSize = 40; + +static const size_t TURN_CHANNEL_HEADER_SIZE = 4U; + +// TODO(mallinath) - Move these to a common place. +inline bool IsTurnChannelData(uint16_t msg_type) { + // The first two bits of a channel data message are 0b01. + return ((msg_type & 0xC000) == 0x4000); +} + +// IDs used for posted messages for TurnServerAllocation. +enum { + MSG_ALLOCATION_TIMEOUT, +}; + +// Encapsulates a TURN permission. +// The object is created when a create permission request is received by an +// allocation, and self-deletes when its lifetime timer expires. +class TurnServerAllocation::Permission : public rtc::MessageHandler { + public: + Permission(rtc::Thread* thread, const rtc::IPAddress& peer); + ~Permission(); + + const rtc::IPAddress& peer() const { return peer_; } + void Refresh(); + + sigslot::signal1<Permission*> SignalDestroyed; + + private: + virtual void OnMessage(rtc::Message* msg); + + rtc::Thread* thread_; + rtc::IPAddress peer_; +}; + +// Encapsulates a TURN channel binding. +// The object is created when a channel bind request is received by an +// allocation, and self-deletes when its lifetime timer expires. +class TurnServerAllocation::Channel : public rtc::MessageHandler { + public: + Channel(rtc::Thread* thread, int id, + const rtc::SocketAddress& peer); + ~Channel(); + + int id() const { return id_; } + const rtc::SocketAddress& peer() const { return peer_; } + void Refresh(); + + sigslot::signal1<Channel*> SignalDestroyed; + + private: + virtual void OnMessage(rtc::Message* msg); + + rtc::Thread* thread_; + int id_; + rtc::SocketAddress peer_; +}; + +static bool InitResponse(const StunMessage* req, StunMessage* resp) { + int resp_type = (req) ? GetStunSuccessResponseType(req->type()) : -1; + if (resp_type == -1) + return false; + resp->SetType(resp_type); + resp->SetTransactionID(req->transaction_id()); + return true; +} + +static bool InitErrorResponse(const StunMessage* req, int code, + const std::string& reason, StunMessage* resp) { + int resp_type = (req) ? GetStunErrorResponseType(req->type()) : -1; + if (resp_type == -1) + return false; + resp->SetType(resp_type); + resp->SetTransactionID(req->transaction_id()); + VERIFY(resp->AddAttribute(new cricket::StunErrorCodeAttribute( + STUN_ATTR_ERROR_CODE, code, reason))); + return true; +} + + +TurnServer::TurnServer(rtc::Thread* thread) + : thread_(thread), + nonce_key_(rtc::CreateRandomString(kNonceKeySize)), + auth_hook_(NULL), + redirect_hook_(NULL), + enable_otu_nonce_(false) { +} + +TurnServer::~TurnServer() { + for (AllocationMap::iterator it = allocations_.begin(); + it != allocations_.end(); ++it) { + delete it->second; + } + + for (InternalSocketMap::iterator it = server_sockets_.begin(); + it != server_sockets_.end(); ++it) { + rtc::AsyncPacketSocket* socket = it->first; + delete socket; + } + + for (ServerSocketMap::iterator it = server_listen_sockets_.begin(); + it != server_listen_sockets_.end(); ++it) { + rtc::AsyncSocket* socket = it->first; + delete socket; + } +} + +void TurnServer::AddInternalSocket(rtc::AsyncPacketSocket* socket, + ProtocolType proto) { + ASSERT(server_sockets_.end() == server_sockets_.find(socket)); + server_sockets_[socket] = proto; + socket->SignalReadPacket.connect(this, &TurnServer::OnInternalPacket); +} + +void TurnServer::AddInternalServerSocket(rtc::AsyncSocket* socket, + ProtocolType proto) { + ASSERT(server_listen_sockets_.end() == + server_listen_sockets_.find(socket)); + server_listen_sockets_[socket] = proto; + socket->SignalReadEvent.connect(this, &TurnServer::OnNewInternalConnection); +} + +void TurnServer::SetExternalSocketFactory( + rtc::PacketSocketFactory* factory, + const rtc::SocketAddress& external_addr) { + external_socket_factory_.reset(factory); + external_addr_ = external_addr; +} + +void TurnServer::OnNewInternalConnection(rtc::AsyncSocket* socket) { + ASSERT(server_listen_sockets_.find(socket) != server_listen_sockets_.end()); + AcceptConnection(socket); +} + +void TurnServer::AcceptConnection(rtc::AsyncSocket* server_socket) { + // Check if someone is trying to connect to us. + rtc::SocketAddress accept_addr; + rtc::AsyncSocket* accepted_socket = server_socket->Accept(&accept_addr); + if (accepted_socket != NULL) { + ProtocolType proto = server_listen_sockets_[server_socket]; + cricket::AsyncStunTCPSocket* tcp_socket = + new cricket::AsyncStunTCPSocket(accepted_socket, false); + + tcp_socket->SignalClose.connect(this, &TurnServer::OnInternalSocketClose); + // Finally add the socket so it can start communicating with the client. + AddInternalSocket(tcp_socket, proto); + } +} + +void TurnServer::OnInternalSocketClose(rtc::AsyncPacketSocket* socket, + int err) { + DestroyInternalSocket(socket); +} + +void TurnServer::OnInternalPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketTime& packet_time) { + // Fail if the packet is too small to even contain a channel header. + if (size < TURN_CHANNEL_HEADER_SIZE) { + return; + } + InternalSocketMap::iterator iter = server_sockets_.find(socket); + ASSERT(iter != server_sockets_.end()); + TurnServerConnection conn(addr, iter->second, socket); + uint16_t msg_type = rtc::GetBE16(data); + if (!IsTurnChannelData(msg_type)) { + // This is a STUN message. + HandleStunMessage(&conn, data, size); + } else { + // This is a channel message; let the allocation handle it. + TurnServerAllocation* allocation = FindAllocation(&conn); + if (allocation) { + allocation->HandleChannelData(data, size); + } + } +} + +void TurnServer::HandleStunMessage(TurnServerConnection* conn, const char* data, + size_t size) { + TurnMessage msg; + rtc::ByteBuffer buf(data, size); + if (!msg.Read(&buf) || (buf.Length() > 0)) { + LOG(LS_WARNING) << "Received invalid STUN message"; + return; + } + + // If it's a STUN binding request, handle that specially. + if (msg.type() == STUN_BINDING_REQUEST) { + HandleBindingRequest(conn, &msg); + return; + } + + if (redirect_hook_ != NULL && msg.type() == STUN_ALLOCATE_REQUEST) { + rtc::SocketAddress address; + if (redirect_hook_->ShouldRedirect(conn->src(), &address)) { + SendErrorResponseWithAlternateServer( + conn, &msg, address); + return; + } + } + + // Look up the key that we'll use to validate the M-I. If we have an + // existing allocation, the key will already be cached. + TurnServerAllocation* allocation = FindAllocation(conn); + std::string key; + if (!allocation) { + GetKey(&msg, &key); + } else { + key = allocation->key(); + } + + // Ensure the message is authorized; only needed for requests. + if (IsStunRequestType(msg.type())) { + if (!CheckAuthorization(conn, &msg, data, size, key)) { + return; + } + } + + if (!allocation && msg.type() == STUN_ALLOCATE_REQUEST) { + HandleAllocateRequest(conn, &msg, key); + } else if (allocation && + (msg.type() != STUN_ALLOCATE_REQUEST || + msg.transaction_id() == allocation->transaction_id())) { + // This is a non-allocate request, or a retransmit of an allocate. + // Check that the username matches the previous username used. + if (IsStunRequestType(msg.type()) && + msg.GetByteString(STUN_ATTR_USERNAME)->GetString() != + allocation->username()) { + SendErrorResponse(conn, &msg, STUN_ERROR_WRONG_CREDENTIALS, + STUN_ERROR_REASON_WRONG_CREDENTIALS); + return; + } + allocation->HandleTurnMessage(&msg); + } else { + // Allocation mismatch. + SendErrorResponse(conn, &msg, STUN_ERROR_ALLOCATION_MISMATCH, + STUN_ERROR_REASON_ALLOCATION_MISMATCH); + } +} + +bool TurnServer::GetKey(const StunMessage* msg, std::string* key) { + const StunByteStringAttribute* username_attr = + msg->GetByteString(STUN_ATTR_USERNAME); + if (!username_attr) { + return false; + } + + std::string username = username_attr->GetString(); + return (auth_hook_ != NULL && auth_hook_->GetKey(username, realm_, key)); +} + +bool TurnServer::CheckAuthorization(TurnServerConnection* conn, + const StunMessage* msg, + const char* data, size_t size, + const std::string& key) { + // RFC 5389, 10.2.2. + ASSERT(IsStunRequestType(msg->type())); + const StunByteStringAttribute* mi_attr = + msg->GetByteString(STUN_ATTR_MESSAGE_INTEGRITY); + const StunByteStringAttribute* username_attr = + msg->GetByteString(STUN_ATTR_USERNAME); + const StunByteStringAttribute* realm_attr = + msg->GetByteString(STUN_ATTR_REALM); + const StunByteStringAttribute* nonce_attr = + msg->GetByteString(STUN_ATTR_NONCE); + + // Fail if no M-I. + if (!mi_attr) { + SendErrorResponseWithRealmAndNonce(conn, msg, STUN_ERROR_UNAUTHORIZED, + STUN_ERROR_REASON_UNAUTHORIZED); + return false; + } + + // Fail if there is M-I but no username, nonce, or realm. + if (!username_attr || !realm_attr || !nonce_attr) { + SendErrorResponse(conn, msg, STUN_ERROR_BAD_REQUEST, + STUN_ERROR_REASON_BAD_REQUEST); + return false; + } + + // Fail if bad nonce. + if (!ValidateNonce(nonce_attr->GetString())) { + SendErrorResponseWithRealmAndNonce(conn, msg, STUN_ERROR_STALE_NONCE, + STUN_ERROR_REASON_STALE_NONCE); + return false; + } + + // Fail if bad username or M-I. + // We need |data| and |size| for the call to ValidateMessageIntegrity. + if (key.empty() || !StunMessage::ValidateMessageIntegrity(data, size, key)) { + SendErrorResponseWithRealmAndNonce(conn, msg, STUN_ERROR_UNAUTHORIZED, + STUN_ERROR_REASON_UNAUTHORIZED); + return false; + } + + // Fail if one-time-use nonce feature is enabled. + TurnServerAllocation* allocation = FindAllocation(conn); + if (enable_otu_nonce_ && allocation && + allocation->last_nonce() == nonce_attr->GetString()) { + SendErrorResponseWithRealmAndNonce(conn, msg, STUN_ERROR_STALE_NONCE, + STUN_ERROR_REASON_STALE_NONCE); + return false; + } + + if (allocation) { + allocation->set_last_nonce(nonce_attr->GetString()); + } + // Success. + return true; +} + +void TurnServer::HandleBindingRequest(TurnServerConnection* conn, + const StunMessage* req) { + StunMessage response; + InitResponse(req, &response); + + // Tell the user the address that we received their request from. + StunAddressAttribute* mapped_addr_attr; + mapped_addr_attr = new StunXorAddressAttribute( + STUN_ATTR_XOR_MAPPED_ADDRESS, conn->src()); + VERIFY(response.AddAttribute(mapped_addr_attr)); + + SendStun(conn, &response); +} + +void TurnServer::HandleAllocateRequest(TurnServerConnection* conn, + const TurnMessage* msg, + const std::string& key) { + // Check the parameters in the request. + const StunUInt32Attribute* transport_attr = + msg->GetUInt32(STUN_ATTR_REQUESTED_TRANSPORT); + if (!transport_attr) { + SendErrorResponse(conn, msg, STUN_ERROR_BAD_REQUEST, + STUN_ERROR_REASON_BAD_REQUEST); + return; + } + + // Only UDP is supported right now. + int proto = transport_attr->value() >> 24; + if (proto != IPPROTO_UDP) { + SendErrorResponse(conn, msg, STUN_ERROR_UNSUPPORTED_PROTOCOL, + STUN_ERROR_REASON_UNSUPPORTED_PROTOCOL); + return; + } + + // Create the allocation and let it send the success response. + // If the actual socket allocation fails, send an internal error. + TurnServerAllocation* alloc = CreateAllocation(conn, proto, key); + if (alloc) { + alloc->HandleTurnMessage(msg); + } else { + SendErrorResponse(conn, msg, STUN_ERROR_SERVER_ERROR, + "Failed to allocate socket"); + } +} + +std::string TurnServer::GenerateNonce() const { + // Generate a nonce of the form hex(now + HMAC-MD5(nonce_key_, now)) + uint32_t now = rtc::Time(); + std::string input(reinterpret_cast<const char*>(&now), sizeof(now)); + std::string nonce = rtc::hex_encode(input.c_str(), input.size()); + nonce += rtc::ComputeHmac(rtc::DIGEST_MD5, nonce_key_, input); + ASSERT(nonce.size() == kNonceSize); + return nonce; +} + +bool TurnServer::ValidateNonce(const std::string& nonce) const { + // Check the size. + if (nonce.size() != kNonceSize) { + return false; + } + + // Decode the timestamp. + uint32_t then; + char* p = reinterpret_cast<char*>(&then); + size_t len = rtc::hex_decode(p, sizeof(then), + nonce.substr(0, sizeof(then) * 2)); + if (len != sizeof(then)) { + return false; + } + + // Verify the HMAC. + if (nonce.substr(sizeof(then) * 2) != rtc::ComputeHmac( + rtc::DIGEST_MD5, nonce_key_, std::string(p, sizeof(then)))) { + return false; + } + + // Validate the timestamp. + return rtc::TimeSince(then) < kNonceTimeout; +} + +TurnServerAllocation* TurnServer::FindAllocation(TurnServerConnection* conn) { + AllocationMap::const_iterator it = allocations_.find(*conn); + return (it != allocations_.end()) ? it->second : NULL; +} + +TurnServerAllocation* TurnServer::CreateAllocation(TurnServerConnection* conn, + int proto, + const std::string& key) { + rtc::AsyncPacketSocket* external_socket = (external_socket_factory_) ? + external_socket_factory_->CreateUdpSocket(external_addr_, 0, 0) : NULL; + if (!external_socket) { + return NULL; + } + + // The Allocation takes ownership of the socket. + TurnServerAllocation* allocation = new TurnServerAllocation(this, + thread_, *conn, external_socket, key); + allocation->SignalDestroyed.connect(this, &TurnServer::OnAllocationDestroyed); + allocations_[*conn] = allocation; + return allocation; +} + +void TurnServer::SendErrorResponse(TurnServerConnection* conn, + const StunMessage* req, + int code, const std::string& reason) { + TurnMessage resp; + InitErrorResponse(req, code, reason, &resp); + LOG(LS_INFO) << "Sending error response, type=" << resp.type() + << ", code=" << code << ", reason=" << reason; + SendStun(conn, &resp); +} + +void TurnServer::SendErrorResponseWithRealmAndNonce( + TurnServerConnection* conn, const StunMessage* msg, + int code, const std::string& reason) { + TurnMessage resp; + InitErrorResponse(msg, code, reason, &resp); + VERIFY(resp.AddAttribute(new StunByteStringAttribute( + STUN_ATTR_NONCE, GenerateNonce()))); + VERIFY(resp.AddAttribute(new StunByteStringAttribute( + STUN_ATTR_REALM, realm_))); + SendStun(conn, &resp); +} + +void TurnServer::SendErrorResponseWithAlternateServer( + TurnServerConnection* conn, const StunMessage* msg, + const rtc::SocketAddress& addr) { + TurnMessage resp; + InitErrorResponse(msg, STUN_ERROR_TRY_ALTERNATE, + STUN_ERROR_REASON_TRY_ALTERNATE_SERVER, &resp); + VERIFY(resp.AddAttribute(new StunAddressAttribute( + STUN_ATTR_ALTERNATE_SERVER, addr))); + SendStun(conn, &resp); +} + +void TurnServer::SendStun(TurnServerConnection* conn, StunMessage* msg) { + rtc::ByteBuffer buf; + // Add a SOFTWARE attribute if one is set. + if (!software_.empty()) { + VERIFY(msg->AddAttribute( + new StunByteStringAttribute(STUN_ATTR_SOFTWARE, software_))); + } + msg->Write(&buf); + Send(conn, buf); +} + +void TurnServer::Send(TurnServerConnection* conn, + const rtc::ByteBuffer& buf) { + rtc::PacketOptions options; + conn->socket()->SendTo(buf.Data(), buf.Length(), conn->src(), options); +} + +void TurnServer::OnAllocationDestroyed(TurnServerAllocation* allocation) { + // Removing the internal socket if the connection is not udp. + rtc::AsyncPacketSocket* socket = allocation->conn()->socket(); + InternalSocketMap::iterator iter = server_sockets_.find(socket); + ASSERT(iter != server_sockets_.end()); + // Skip if the socket serving this allocation is UDP, as this will be shared + // by all allocations. + if (iter->second != cricket::PROTO_UDP) { + DestroyInternalSocket(socket); + } + + AllocationMap::iterator it = allocations_.find(*(allocation->conn())); + if (it != allocations_.end()) + allocations_.erase(it); +} + +void TurnServer::DestroyInternalSocket(rtc::AsyncPacketSocket* socket) { + InternalSocketMap::iterator iter = server_sockets_.find(socket); + if (iter != server_sockets_.end()) { + rtc::AsyncPacketSocket* socket = iter->first; + // We must destroy the socket async to avoid invalidating the sigslot + // callback list iterator inside a sigslot callback. + rtc::Thread::Current()->Dispose(socket); + server_sockets_.erase(iter); + } +} + +TurnServerConnection::TurnServerConnection(const rtc::SocketAddress& src, + ProtocolType proto, + rtc::AsyncPacketSocket* socket) + : src_(src), + dst_(socket->GetRemoteAddress()), + proto_(proto), + socket_(socket) { +} + +bool TurnServerConnection::operator==(const TurnServerConnection& c) const { + return src_ == c.src_ && dst_ == c.dst_ && proto_ == c.proto_; +} + +bool TurnServerConnection::operator<(const TurnServerConnection& c) const { + return src_ < c.src_ || dst_ < c.dst_ || proto_ < c.proto_; +} + +std::string TurnServerConnection::ToString() const { + const char* const kProtos[] = { + "unknown", "udp", "tcp", "ssltcp" + }; + std::ostringstream ost; + ost << src_.ToString() << "-" << dst_.ToString() << ":"<< kProtos[proto_]; + return ost.str(); +} + +TurnServerAllocation::TurnServerAllocation(TurnServer* server, + rtc::Thread* thread, + const TurnServerConnection& conn, + rtc::AsyncPacketSocket* socket, + const std::string& key) + : server_(server), + thread_(thread), + conn_(conn), + external_socket_(socket), + key_(key) { + external_socket_->SignalReadPacket.connect( + this, &TurnServerAllocation::OnExternalPacket); +} + +TurnServerAllocation::~TurnServerAllocation() { + for (ChannelList::iterator it = channels_.begin(); + it != channels_.end(); ++it) { + delete *it; + } + for (PermissionList::iterator it = perms_.begin(); + it != perms_.end(); ++it) { + delete *it; + } + thread_->Clear(this, MSG_ALLOCATION_TIMEOUT); + LOG_J(LS_INFO, this) << "Allocation destroyed"; +} + +std::string TurnServerAllocation::ToString() const { + std::ostringstream ost; + ost << "Alloc[" << conn_.ToString() << "]"; + return ost.str(); +} + +void TurnServerAllocation::HandleTurnMessage(const TurnMessage* msg) { + ASSERT(msg != NULL); + switch (msg->type()) { + case STUN_ALLOCATE_REQUEST: + HandleAllocateRequest(msg); + break; + case TURN_REFRESH_REQUEST: + HandleRefreshRequest(msg); + break; + case TURN_SEND_INDICATION: + HandleSendIndication(msg); + break; + case TURN_CREATE_PERMISSION_REQUEST: + HandleCreatePermissionRequest(msg); + break; + case TURN_CHANNEL_BIND_REQUEST: + HandleChannelBindRequest(msg); + break; + default: + // Not sure what to do with this, just eat it. + LOG_J(LS_WARNING, this) << "Invalid TURN message type received: " + << msg->type(); + } +} + +void TurnServerAllocation::HandleAllocateRequest(const TurnMessage* msg) { + // Copy the important info from the allocate request. + transaction_id_ = msg->transaction_id(); + const StunByteStringAttribute* username_attr = + msg->GetByteString(STUN_ATTR_USERNAME); + ASSERT(username_attr != NULL); + username_ = username_attr->GetString(); + const StunByteStringAttribute* origin_attr = + msg->GetByteString(STUN_ATTR_ORIGIN); + if (origin_attr) { + origin_ = origin_attr->GetString(); + } + + // Figure out the lifetime and start the allocation timer. + int lifetime_secs = ComputeLifetime(msg); + thread_->PostDelayed(lifetime_secs * 1000, this, MSG_ALLOCATION_TIMEOUT); + + LOG_J(LS_INFO, this) << "Created allocation, lifetime=" << lifetime_secs; + + // We've already validated all the important bits; just send a response here. + TurnMessage response; + InitResponse(msg, &response); + + StunAddressAttribute* mapped_addr_attr = + new StunXorAddressAttribute(STUN_ATTR_XOR_MAPPED_ADDRESS, conn_.src()); + StunAddressAttribute* relayed_addr_attr = + new StunXorAddressAttribute(STUN_ATTR_XOR_RELAYED_ADDRESS, + external_socket_->GetLocalAddress()); + StunUInt32Attribute* lifetime_attr = + new StunUInt32Attribute(STUN_ATTR_LIFETIME, lifetime_secs); + VERIFY(response.AddAttribute(mapped_addr_attr)); + VERIFY(response.AddAttribute(relayed_addr_attr)); + VERIFY(response.AddAttribute(lifetime_attr)); + + SendResponse(&response); +} + +void TurnServerAllocation::HandleRefreshRequest(const TurnMessage* msg) { + // Figure out the new lifetime. + int lifetime_secs = ComputeLifetime(msg); + + // Reset the expiration timer. + thread_->Clear(this, MSG_ALLOCATION_TIMEOUT); + thread_->PostDelayed(lifetime_secs * 1000, this, MSG_ALLOCATION_TIMEOUT); + + LOG_J(LS_INFO, this) << "Refreshed allocation, lifetime=" << lifetime_secs; + + // Send a success response with a LIFETIME attribute. + TurnMessage response; + InitResponse(msg, &response); + + StunUInt32Attribute* lifetime_attr = + new StunUInt32Attribute(STUN_ATTR_LIFETIME, lifetime_secs); + VERIFY(response.AddAttribute(lifetime_attr)); + + SendResponse(&response); +} + +void TurnServerAllocation::HandleSendIndication(const TurnMessage* msg) { + // Check mandatory attributes. + const StunByteStringAttribute* data_attr = msg->GetByteString(STUN_ATTR_DATA); + const StunAddressAttribute* peer_attr = + msg->GetAddress(STUN_ATTR_XOR_PEER_ADDRESS); + if (!data_attr || !peer_attr) { + LOG_J(LS_WARNING, this) << "Received invalid send indication"; + return; + } + + // If a permission exists, send the data on to the peer. + if (HasPermission(peer_attr->GetAddress().ipaddr())) { + SendExternal(data_attr->bytes(), data_attr->length(), + peer_attr->GetAddress()); + } else { + LOG_J(LS_WARNING, this) << "Received send indication without permission" + << "peer=" << peer_attr->GetAddress(); + } +} + +void TurnServerAllocation::HandleCreatePermissionRequest( + const TurnMessage* msg) { + // Check mandatory attributes. + const StunAddressAttribute* peer_attr = + msg->GetAddress(STUN_ATTR_XOR_PEER_ADDRESS); + if (!peer_attr) { + SendBadRequestResponse(msg); + return; + } + + // Add this permission. + AddPermission(peer_attr->GetAddress().ipaddr()); + + LOG_J(LS_INFO, this) << "Created permission, peer=" + << peer_attr->GetAddress(); + + // Send a success response. + TurnMessage response; + InitResponse(msg, &response); + SendResponse(&response); +} + +void TurnServerAllocation::HandleChannelBindRequest(const TurnMessage* msg) { + // Check mandatory attributes. + const StunUInt32Attribute* channel_attr = + msg->GetUInt32(STUN_ATTR_CHANNEL_NUMBER); + const StunAddressAttribute* peer_attr = + msg->GetAddress(STUN_ATTR_XOR_PEER_ADDRESS); + if (!channel_attr || !peer_attr) { + SendBadRequestResponse(msg); + return; + } + + // Check that channel id is valid. + int channel_id = channel_attr->value() >> 16; + if (channel_id < kMinChannelNumber || channel_id > kMaxChannelNumber) { + SendBadRequestResponse(msg); + return; + } + + // Check that this channel id isn't bound to another transport address, and + // that this transport address isn't bound to another channel id. + Channel* channel1 = FindChannel(channel_id); + Channel* channel2 = FindChannel(peer_attr->GetAddress()); + if (channel1 != channel2) { + SendBadRequestResponse(msg); + return; + } + + // Add or refresh this channel. + if (!channel1) { + channel1 = new Channel(thread_, channel_id, peer_attr->GetAddress()); + channel1->SignalDestroyed.connect(this, + &TurnServerAllocation::OnChannelDestroyed); + channels_.push_back(channel1); + } else { + channel1->Refresh(); + } + + // Channel binds also refresh permissions. + AddPermission(peer_attr->GetAddress().ipaddr()); + + LOG_J(LS_INFO, this) << "Bound channel, id=" << channel_id + << ", peer=" << peer_attr->GetAddress(); + + // Send a success response. + TurnMessage response; + InitResponse(msg, &response); + SendResponse(&response); +} + +void TurnServerAllocation::HandleChannelData(const char* data, size_t size) { + // Extract the channel number from the data. + uint16_t channel_id = rtc::GetBE16(data); + Channel* channel = FindChannel(channel_id); + if (channel) { + // Send the data to the peer address. + SendExternal(data + TURN_CHANNEL_HEADER_SIZE, + size - TURN_CHANNEL_HEADER_SIZE, channel->peer()); + } else { + LOG_J(LS_WARNING, this) << "Received channel data for invalid channel, id=" + << channel_id; + } +} + +void TurnServerAllocation::OnExternalPacket( + rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketTime& packet_time) { + ASSERT(external_socket_.get() == socket); + Channel* channel = FindChannel(addr); + if (channel) { + // There is a channel bound to this address. Send as a channel message. + rtc::ByteBuffer buf; + buf.WriteUInt16(channel->id()); + buf.WriteUInt16(static_cast<uint16_t>(size)); + buf.WriteBytes(data, size); + server_->Send(&conn_, buf); + } else if (HasPermission(addr.ipaddr())) { + // No channel, but a permission exists. Send as a data indication. + TurnMessage msg; + msg.SetType(TURN_DATA_INDICATION); + msg.SetTransactionID( + rtc::CreateRandomString(kStunTransactionIdLength)); + VERIFY(msg.AddAttribute(new StunXorAddressAttribute( + STUN_ATTR_XOR_PEER_ADDRESS, addr))); + VERIFY(msg.AddAttribute(new StunByteStringAttribute( + STUN_ATTR_DATA, data, size))); + server_->SendStun(&conn_, &msg); + } else { + LOG_J(LS_WARNING, this) << "Received external packet without permission, " + << "peer=" << addr; + } +} + +int TurnServerAllocation::ComputeLifetime(const TurnMessage* msg) { + // Return the smaller of our default lifetime and the requested lifetime. + uint32_t lifetime = kDefaultAllocationTimeout / 1000; // convert to seconds + const StunUInt32Attribute* lifetime_attr = msg->GetUInt32(STUN_ATTR_LIFETIME); + if (lifetime_attr && lifetime_attr->value() < lifetime) { + lifetime = lifetime_attr->value(); + } + return lifetime; +} + +bool TurnServerAllocation::HasPermission(const rtc::IPAddress& addr) { + return (FindPermission(addr) != NULL); +} + +void TurnServerAllocation::AddPermission(const rtc::IPAddress& addr) { + Permission* perm = FindPermission(addr); + if (!perm) { + perm = new Permission(thread_, addr); + perm->SignalDestroyed.connect( + this, &TurnServerAllocation::OnPermissionDestroyed); + perms_.push_back(perm); + } else { + perm->Refresh(); + } +} + +TurnServerAllocation::Permission* TurnServerAllocation::FindPermission( + const rtc::IPAddress& addr) const { + for (PermissionList::const_iterator it = perms_.begin(); + it != perms_.end(); ++it) { + if ((*it)->peer() == addr) + return *it; + } + return NULL; +} + +TurnServerAllocation::Channel* TurnServerAllocation::FindChannel( + int channel_id) const { + for (ChannelList::const_iterator it = channels_.begin(); + it != channels_.end(); ++it) { + if ((*it)->id() == channel_id) + return *it; + } + return NULL; +} + +TurnServerAllocation::Channel* TurnServerAllocation::FindChannel( + const rtc::SocketAddress& addr) const { + for (ChannelList::const_iterator it = channels_.begin(); + it != channels_.end(); ++it) { + if ((*it)->peer() == addr) + return *it; + } + return NULL; +} + +void TurnServerAllocation::SendResponse(TurnMessage* msg) { + // Success responses always have M-I. + msg->AddMessageIntegrity(key_); + server_->SendStun(&conn_, msg); +} + +void TurnServerAllocation::SendBadRequestResponse(const TurnMessage* req) { + SendErrorResponse(req, STUN_ERROR_BAD_REQUEST, STUN_ERROR_REASON_BAD_REQUEST); +} + +void TurnServerAllocation::SendErrorResponse(const TurnMessage* req, int code, + const std::string& reason) { + server_->SendErrorResponse(&conn_, req, code, reason); +} + +void TurnServerAllocation::SendExternal(const void* data, size_t size, + const rtc::SocketAddress& peer) { + rtc::PacketOptions options; + external_socket_->SendTo(data, size, peer, options); +} + +void TurnServerAllocation::OnMessage(rtc::Message* msg) { + ASSERT(msg->message_id == MSG_ALLOCATION_TIMEOUT); + SignalDestroyed(this); + delete this; +} + +void TurnServerAllocation::OnPermissionDestroyed(Permission* perm) { + PermissionList::iterator it = std::find(perms_.begin(), perms_.end(), perm); + ASSERT(it != perms_.end()); + perms_.erase(it); +} + +void TurnServerAllocation::OnChannelDestroyed(Channel* channel) { + ChannelList::iterator it = + std::find(channels_.begin(), channels_.end(), channel); + ASSERT(it != channels_.end()); + channels_.erase(it); +} + +TurnServerAllocation::Permission::Permission(rtc::Thread* thread, + const rtc::IPAddress& peer) + : thread_(thread), peer_(peer) { + Refresh(); +} + +TurnServerAllocation::Permission::~Permission() { + thread_->Clear(this, MSG_ALLOCATION_TIMEOUT); +} + +void TurnServerAllocation::Permission::Refresh() { + thread_->Clear(this, MSG_ALLOCATION_TIMEOUT); + thread_->PostDelayed(kPermissionTimeout, this, MSG_ALLOCATION_TIMEOUT); +} + +void TurnServerAllocation::Permission::OnMessage(rtc::Message* msg) { + ASSERT(msg->message_id == MSG_ALLOCATION_TIMEOUT); + SignalDestroyed(this); + delete this; +} + +TurnServerAllocation::Channel::Channel(rtc::Thread* thread, int id, + const rtc::SocketAddress& peer) + : thread_(thread), id_(id), peer_(peer) { + Refresh(); +} + +TurnServerAllocation::Channel::~Channel() { + thread_->Clear(this, MSG_ALLOCATION_TIMEOUT); +} + +void TurnServerAllocation::Channel::Refresh() { + thread_->Clear(this, MSG_ALLOCATION_TIMEOUT); + thread_->PostDelayed(kChannelTimeout, this, MSG_ALLOCATION_TIMEOUT); +} + +void TurnServerAllocation::Channel::OnMessage(rtc::Message* msg) { + ASSERT(msg->message_id == MSG_ALLOCATION_TIMEOUT); + SignalDestroyed(this); + delete this; +} + +} // namespace cricket diff --git a/webrtc/p2p/base/turnserver.h b/webrtc/p2p/base/turnserver.h new file mode 100644 index 0000000000..d3bd77a866 --- /dev/null +++ b/webrtc/p2p/base/turnserver.h @@ -0,0 +1,272 @@ +/* + * Copyright 2012 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_TURNSERVER_H_ +#define WEBRTC_P2P_BASE_TURNSERVER_H_ + +#include <list> +#include <map> +#include <set> +#include <string> + +#include "webrtc/p2p/base/portinterface.h" +#include "webrtc/base/asyncpacketsocket.h" +#include "webrtc/base/messagequeue.h" +#include "webrtc/base/sigslot.h" +#include "webrtc/base/socketaddress.h" + +namespace rtc { +class ByteBuffer; +class PacketSocketFactory; +class Thread; +} + +namespace cricket { + +class StunMessage; +class TurnMessage; +class TurnServer; + +// The default server port for TURN, as specified in RFC5766. +const int TURN_SERVER_PORT = 3478; + +// Encapsulates the client's connection to the server. +class TurnServerConnection { + public: + TurnServerConnection() : proto_(PROTO_UDP), socket_(NULL) {} + TurnServerConnection(const rtc::SocketAddress& src, + ProtocolType proto, + rtc::AsyncPacketSocket* socket); + const rtc::SocketAddress& src() const { return src_; } + rtc::AsyncPacketSocket* socket() { return socket_; } + bool operator==(const TurnServerConnection& t) const; + bool operator<(const TurnServerConnection& t) const; + std::string ToString() const; + + private: + rtc::SocketAddress src_; + rtc::SocketAddress dst_; + cricket::ProtocolType proto_; + rtc::AsyncPacketSocket* socket_; +}; + +// Encapsulates a TURN allocation. +// The object is created when an allocation request is received, and then +// handles TURN messages (via HandleTurnMessage) and channel data messages +// (via HandleChannelData) for this allocation when received by the server. +// The object self-deletes and informs the server if its lifetime timer expires. +class TurnServerAllocation : public rtc::MessageHandler, + public sigslot::has_slots<> { + public: + TurnServerAllocation(TurnServer* server_, + rtc::Thread* thread, + const TurnServerConnection& conn, + rtc::AsyncPacketSocket* server_socket, + const std::string& key); + virtual ~TurnServerAllocation(); + + TurnServerConnection* conn() { return &conn_; } + const std::string& key() const { return key_; } + const std::string& transaction_id() const { return transaction_id_; } + const std::string& username() const { return username_; } + const std::string& origin() const { return origin_; } + const std::string& last_nonce() const { return last_nonce_; } + void set_last_nonce(const std::string& nonce) { last_nonce_ = nonce; } + + std::string ToString() const; + + void HandleTurnMessage(const TurnMessage* msg); + void HandleChannelData(const char* data, size_t size); + + sigslot::signal1<TurnServerAllocation*> SignalDestroyed; + + private: + class Channel; + class Permission; + typedef std::list<Permission*> PermissionList; + typedef std::list<Channel*> ChannelList; + + void HandleAllocateRequest(const TurnMessage* msg); + void HandleRefreshRequest(const TurnMessage* msg); + void HandleSendIndication(const TurnMessage* msg); + void HandleCreatePermissionRequest(const TurnMessage* msg); + void HandleChannelBindRequest(const TurnMessage* msg); + + void OnExternalPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketTime& packet_time); + + static int ComputeLifetime(const TurnMessage* msg); + bool HasPermission(const rtc::IPAddress& addr); + void AddPermission(const rtc::IPAddress& addr); + Permission* FindPermission(const rtc::IPAddress& addr) const; + Channel* FindChannel(int channel_id) const; + Channel* FindChannel(const rtc::SocketAddress& addr) const; + + void SendResponse(TurnMessage* msg); + void SendBadRequestResponse(const TurnMessage* req); + void SendErrorResponse(const TurnMessage* req, int code, + const std::string& reason); + void SendExternal(const void* data, size_t size, + const rtc::SocketAddress& peer); + + void OnPermissionDestroyed(Permission* perm); + void OnChannelDestroyed(Channel* channel); + virtual void OnMessage(rtc::Message* msg); + + TurnServer* server_; + rtc::Thread* thread_; + TurnServerConnection conn_; + rtc::scoped_ptr<rtc::AsyncPacketSocket> external_socket_; + std::string key_; + std::string transaction_id_; + std::string username_; + std::string origin_; + std::string last_nonce_; + PermissionList perms_; + ChannelList channels_; +}; + +// An interface through which the MD5 credential hash can be retrieved. +class TurnAuthInterface { + public: + // Gets HA1 for the specified user and realm. + // HA1 = MD5(A1) = MD5(username:realm:password). + // Return true if the given username and realm are valid, or false if not. + virtual bool GetKey(const std::string& username, const std::string& realm, + std::string* key) = 0; +}; + +// An interface enables Turn Server to control redirection behavior. +class TurnRedirectInterface { + public: + virtual bool ShouldRedirect(const rtc::SocketAddress& address, + rtc::SocketAddress* out) = 0; + virtual ~TurnRedirectInterface() {} +}; + +// The core TURN server class. Give it a socket to listen on via +// AddInternalServerSocket, and a factory to create external sockets via +// SetExternalSocketFactory, and it's ready to go. +// Not yet wired up: TCP support. +class TurnServer : public sigslot::has_slots<> { + public: + typedef std::map<TurnServerConnection, TurnServerAllocation*> AllocationMap; + + explicit TurnServer(rtc::Thread* thread); + ~TurnServer(); + + // Gets/sets the realm value to use for the server. + const std::string& realm() const { return realm_; } + void set_realm(const std::string& realm) { realm_ = realm; } + + // Gets/sets the value for the SOFTWARE attribute for TURN messages. + const std::string& software() const { return software_; } + void set_software(const std::string& software) { software_ = software; } + + const AllocationMap& allocations() const { return allocations_; } + + // Sets the authentication callback; does not take ownership. + void set_auth_hook(TurnAuthInterface* auth_hook) { auth_hook_ = auth_hook; } + + void set_redirect_hook(TurnRedirectInterface* redirect_hook) { + redirect_hook_ = redirect_hook; + } + + void set_enable_otu_nonce(bool enable) { enable_otu_nonce_ = enable; } + + // Starts listening for packets from internal clients. + void AddInternalSocket(rtc::AsyncPacketSocket* socket, + ProtocolType proto); + // Starts listening for the connections on this socket. When someone tries + // to connect, the connection will be accepted and a new internal socket + // will be added. + void AddInternalServerSocket(rtc::AsyncSocket* socket, + ProtocolType proto); + // Specifies the factory to use for creating external sockets. + void SetExternalSocketFactory(rtc::PacketSocketFactory* factory, + const rtc::SocketAddress& address); + + private: + void OnInternalPacket(rtc::AsyncPacketSocket* socket, const char* data, + size_t size, const rtc::SocketAddress& address, + const rtc::PacketTime& packet_time); + + void OnNewInternalConnection(rtc::AsyncSocket* socket); + + // Accept connections on this server socket. + void AcceptConnection(rtc::AsyncSocket* server_socket); + void OnInternalSocketClose(rtc::AsyncPacketSocket* socket, int err); + + void HandleStunMessage( + TurnServerConnection* conn, const char* data, size_t size); + void HandleBindingRequest(TurnServerConnection* conn, const StunMessage* msg); + void HandleAllocateRequest(TurnServerConnection* conn, const TurnMessage* msg, + const std::string& key); + + bool GetKey(const StunMessage* msg, std::string* key); + bool CheckAuthorization(TurnServerConnection* conn, const StunMessage* msg, + const char* data, size_t size, + const std::string& key); + std::string GenerateNonce() const; + bool ValidateNonce(const std::string& nonce) const; + + TurnServerAllocation* FindAllocation(TurnServerConnection* conn); + TurnServerAllocation* CreateAllocation( + TurnServerConnection* conn, int proto, const std::string& key); + + void SendErrorResponse(TurnServerConnection* conn, const StunMessage* req, + int code, const std::string& reason); + + void SendErrorResponseWithRealmAndNonce(TurnServerConnection* conn, + const StunMessage* req, + int code, + const std::string& reason); + + void SendErrorResponseWithAlternateServer(TurnServerConnection* conn, + const StunMessage* req, + const rtc::SocketAddress& addr); + + void SendStun(TurnServerConnection* conn, StunMessage* msg); + void Send(TurnServerConnection* conn, const rtc::ByteBuffer& buf); + + void OnAllocationDestroyed(TurnServerAllocation* allocation); + void DestroyInternalSocket(rtc::AsyncPacketSocket* socket); + + typedef std::map<rtc::AsyncPacketSocket*, + ProtocolType> InternalSocketMap; + typedef std::map<rtc::AsyncSocket*, + ProtocolType> ServerSocketMap; + + rtc::Thread* thread_; + std::string nonce_key_; + std::string realm_; + std::string software_; + TurnAuthInterface* auth_hook_; + TurnRedirectInterface* redirect_hook_; + // otu - one-time-use. Server will respond with 438 if it's + // sees the same nonce in next transaction. + bool enable_otu_nonce_; + + InternalSocketMap server_sockets_; + ServerSocketMap server_listen_sockets_; + rtc::scoped_ptr<rtc::PacketSocketFactory> + external_socket_factory_; + rtc::SocketAddress external_addr_; + + AllocationMap allocations_; + + friend class TurnServerAllocation; +}; + +} // namespace cricket + +#endif // WEBRTC_P2P_BASE_TURNSERVER_H_ diff --git a/webrtc/p2p/base/udpport.h b/webrtc/p2p/base/udpport.h new file mode 100644 index 0000000000..9f868644ee --- /dev/null +++ b/webrtc/p2p/base/udpport.h @@ -0,0 +1,17 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef WEBRTC_P2P_BASE_UDPPORT_H_ +#define WEBRTC_P2P_BASE_UDPPORT_H_ + +// StunPort will be handling UDPPort functionality. +#include "webrtc/p2p/base/stunport.h" + +#endif // WEBRTC_P2P_BASE_UDPPORT_H_ |