/* * Copyright (C) 2016 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "UnixSocket.h" #include #include #include #include #include #include #include "IOEventLoop.h" static bool CreateUnixSocketAddress(const std::string& server_path, bool is_abstract, sockaddr_un& serv_addr) { memset(&serv_addr, 0, sizeof(serv_addr)); serv_addr.sun_family = AF_UNIX; size_t sun_path_len = sizeof(serv_addr.sun_path); char* p = serv_addr.sun_path; if (is_abstract) { sun_path_len--; p++; } if (server_path.size() + 1 > sun_path_len) { LOG(ERROR) << "can't create unix domain socket as server path is too long: " << server_path; return false; } strcpy(p, server_path.c_str()); return true; } std::unique_ptr UnixSocketServer::Create( const std::string& server_path, bool is_abstract) { int sockfd = socket(AF_UNIX, SOCK_STREAM, 0); if (sockfd < 0) { PLOG(ERROR) << "socket() failed"; return nullptr; } sockaddr_un serv_addr; if (!CreateUnixSocketAddress(server_path, is_abstract, serv_addr)) { return nullptr; } if (bind(sockfd, reinterpret_cast(&serv_addr), sizeof(serv_addr)) < 0) { PLOG(ERROR) << "bind() failed for " << server_path; return nullptr; } if (listen(sockfd, 1) < 0) { PLOG(ERROR) << "listen() failed"; return nullptr; } return std::unique_ptr( new UnixSocketServer(sockfd, server_path)); } UnixSocketServer::~UnixSocketServer() { close(server_fd_); } std::unique_ptr UnixSocketServer::AcceptConnection() { int sockfd = accept(server_fd_, nullptr, nullptr); if (sockfd < 0) { PLOG(ERROR) << "accept() failed"; return nullptr; } return std::unique_ptr( new UnixSocketConnection(sockfd)); } std::unique_ptr UnixSocketConnection::Connect( const std::string& server_path, bool is_abstract) { int sockfd = socket(AF_UNIX, SOCK_STREAM, 0); if (sockfd < 0) { PLOG(DEBUG) << "socket() failed"; return nullptr; } sockaddr_un serv_addr; if (!CreateUnixSocketAddress(server_path, is_abstract, serv_addr)) { return nullptr; } if (connect(sockfd, reinterpret_cast(&serv_addr), sizeof(serv_addr)) < 0) { PLOG(DEBUG) << "connect() failed, server_path = " << server_path; return nullptr; } return std::unique_ptr( new UnixSocketConnection(sockfd)); } bool UnixSocketConnection::PrepareForIO( IOEventLoop& loop, const std::function& receive_message_callback, const std::function& close_connection_callback) { read_callback_ = receive_message_callback; close_callback_ = close_connection_callback; read_event_ = loop.AddReadEvent(fd_, [&]() { return ReadData(); }); if (read_event_ == nullptr) { return false; } std::lock_guard lock(send_buffer_and_write_event_mtx_); write_event_ = loop.AddWriteEvent(fd_, [&]() { return WriteData(); }); if (write_event_ == nullptr) { return false; } return DisableWriteEventWithLock(); } bool UnixSocketConnection::WriteData() { const char* write_data; size_t write_data_size; if (!GetDataFromSendBuffer(&write_data, &write_data_size)) { return false; } if (write_data_size == 0u) { return true; } // Use MSG_NOSIGNAL to prevent receiving SIGPIPE. ssize_t result = TEMP_FAILURE_RETRY(send(fd_, write_data, write_data_size, MSG_NOSIGNAL)); if (result >= 0) { std::lock_guard lock(send_buffer_and_write_event_mtx_); send_buffer_.CommitData(result); } else if (errno != EAGAIN) { PLOG(ERROR) << "send() failed"; return false; } return true; } bool UnixSocketConnection::GetDataFromSendBuffer(const char** pdata, size_t* pdata_size) { { std::lock_guard lock(send_buffer_and_write_event_mtx_); *pdata_size = send_buffer_.PeekData(pdata); if (*pdata_size != 0u) { return true; } // The send buffer is empty. If we can receive more messages, just disable // the write event temporarily, otherwise close the connection. if (!no_more_message_) { return DisableWriteEventWithLock(); } } return CloseConnection(); } bool UnixSocketConnection::ReadData() { ssize_t result = TEMP_FAILURE_RETRY(read(fd_, &read_buffer_[read_buffer_size_], read_buffer_.size() - read_buffer_size_)); if (result < 0) { if (errno == EAGAIN) { return true; } PLOG(ERROR) << "read() failed"; return false; } else if (result == 0) { // The connection is closed, and no need to write pending messages. return CloseConnection(); } read_buffer_size_ += result; return ConsumeDataInReadBuffer(); } bool UnixSocketConnection::ConsumeDataInReadBuffer() { char* p = read_buffer_.data(); size_t left_size = read_buffer_size_; uint32_t aligned_len = 0; while (left_size >= sizeof(UnixSocketMessage)) { UnixSocketMessage* msg = reinterpret_cast(p); aligned_len = Align(msg->len, UnixSocketMessageAlignment); if (left_size < aligned_len) { break; } if (!read_callback_(*msg)) { return false; } p += aligned_len; left_size -= aligned_len; } if (left_size > 0u) { // Move the unfinished message to the start of read_buffer_. memmove(read_buffer_.data(), p, left_size); // Extend the buffer to store this big message. if (aligned_len > read_buffer_.size()) { read_buffer_.resize(aligned_len); } } read_buffer_size_ = left_size; return true; } bool UnixSocketConnection::CloseConnection() { // Disable read_event and write_event here, so ReadData() and WriteData() // won't be called in the future. if (!IOEventLoop::DisableEvent(read_event_)) { return false; } { std::lock_guard lock(send_buffer_and_write_event_mtx_); no_more_message_ = true; if (!DisableWriteEventWithLock()) { return false; } } close(fd_); fd_ = -1; return close_callback_(); } UnixSocketConnection::~UnixSocketConnection() { if (fd_ != -1) { // It only happens when IO operations are not finished properly by // CloseConnection(). Don't call CloseConnection() here as the // IOEventLoop used to register read/write events may have been destroyed. close(fd_); } }