summaryrefslogtreecommitdiff
path: root/mojo/core/channel_win.cc
diff options
context:
space:
mode:
Diffstat (limited to 'mojo/core/channel_win.cc')
-rw-r--r--mojo/core/channel_win.cc377
1 files changed, 377 insertions, 0 deletions
diff --git a/mojo/core/channel_win.cc b/mojo/core/channel_win.cc
new file mode 100644
index 0000000000..30a14867be
--- /dev/null
+++ b/mojo/core/channel_win.cc
@@ -0,0 +1,377 @@
+// Copyright 2016 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "mojo/core/channel.h"
+
+#include <stdint.h>
+#include <windows.h>
+
+#include <algorithm>
+#include <limits>
+#include <memory>
+
+#include "base/bind.h"
+#include "base/containers/queue.h"
+#include "base/location.h"
+#include "base/macros.h"
+#include "base/memory/ref_counted.h"
+#include "base/message_loop/message_loop_current.h"
+#include "base/message_loop/message_pump_for_io.h"
+#include "base/process/process_handle.h"
+#include "base/synchronization/lock.h"
+#include "base/task_runner.h"
+#include "base/win/scoped_handle.h"
+#include "base/win/win_util.h"
+
+namespace mojo {
+namespace core {
+
+namespace {
+
+class ChannelWin : public Channel,
+ public base::MessageLoopCurrent::DestructionObserver,
+ public base::MessagePumpForIO::IOHandler {
+ public:
+ ChannelWin(Delegate* delegate,
+ ConnectionParams connection_params,
+ scoped_refptr<base::TaskRunner> io_task_runner)
+ : Channel(delegate), self_(this), io_task_runner_(io_task_runner) {
+ if (connection_params.server_endpoint().is_valid()) {
+ handle_ = connection_params.TakeServerEndpoint()
+ .TakePlatformHandle()
+ .TakeHandle();
+ needs_connection_ = true;
+ } else {
+ handle_ =
+ connection_params.TakeEndpoint().TakePlatformHandle().TakeHandle();
+ }
+
+ CHECK(handle_.IsValid());
+ }
+
+ void Start() override {
+ io_task_runner_->PostTask(
+ FROM_HERE, base::BindOnce(&ChannelWin::StartOnIOThread, this));
+ }
+
+ void ShutDownImpl() override {
+ // Always shut down asynchronously when called through the public interface.
+ io_task_runner_->PostTask(
+ FROM_HERE, base::BindOnce(&ChannelWin::ShutDownOnIOThread, this));
+ }
+
+ void Write(MessagePtr message) override {
+ if (remote_process().is_valid()) {
+ // If we know the remote process handle, we transfer all outgoing handles
+ // to the process now rewriting them in the message.
+ std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
+ for (auto& handle : handles) {
+ if (handle.handle().is_valid())
+ handle.TransferToProcess(remote_process().Clone());
+ }
+ message->SetHandles(std::move(handles));
+ }
+
+ bool write_error = false;
+ {
+ base::AutoLock lock(write_lock_);
+ if (reject_writes_)
+ return;
+
+ bool write_now = !delay_writes_ && outgoing_messages_.empty();
+ outgoing_messages_.emplace_back(std::move(message));
+ if (write_now && !WriteNoLock(outgoing_messages_.front()))
+ reject_writes_ = write_error = true;
+ }
+ if (write_error) {
+ // Do not synchronously invoke OnWriteError(). Write() may have been
+ // called by the delegate and we don't want to re-enter it.
+ io_task_runner_->PostTask(FROM_HERE,
+ base::BindOnce(&ChannelWin::OnWriteError, this,
+ Error::kDisconnected));
+ }
+ }
+
+ void LeakHandle() override {
+ DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
+ leak_handle_ = true;
+ }
+
+ bool GetReadPlatformHandles(const void* payload,
+ size_t payload_size,
+ size_t num_handles,
+ const void* extra_header,
+ size_t extra_header_size,
+ std::vector<PlatformHandle>* handles,
+ bool* deferred) override {
+ DCHECK(extra_header);
+ if (num_handles > std::numeric_limits<uint16_t>::max())
+ return false;
+ using HandleEntry = Channel::Message::HandleEntry;
+ size_t handles_size = sizeof(HandleEntry) * num_handles;
+ if (handles_size > extra_header_size)
+ return false;
+ handles->reserve(num_handles);
+ const HandleEntry* extra_header_handles =
+ reinterpret_cast<const HandleEntry*>(extra_header);
+ for (size_t i = 0; i < num_handles; i++) {
+ HANDLE handle_value =
+ base::win::Uint32ToHandle(extra_header_handles[i].handle);
+ if (remote_process().is_valid()) {
+ // If we know the remote process's handle, we assume it doesn't know
+ // ours; that means any handle values still belong to that process, and
+ // we need to transfer them to this process.
+ handle_value = PlatformHandleInTransit::TakeIncomingRemoteHandle(
+ handle_value, remote_process().get())
+ .ReleaseHandle();
+ }
+ handles->emplace_back(base::win::ScopedHandle(std::move(handle_value)));
+ }
+ return true;
+ }
+
+ private:
+ // May run on any thread.
+ ~ChannelWin() override {}
+
+ void StartOnIOThread() {
+ base::MessageLoopCurrent::Get()->AddDestructionObserver(this);
+ base::MessageLoopCurrentForIO::Get()->RegisterIOHandler(handle_.Get(),
+ this);
+
+ if (needs_connection_) {
+ BOOL ok = ::ConnectNamedPipe(handle_.Get(), &connect_context_.overlapped);
+ if (ok) {
+ PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
+ OnError(Error::kConnectionFailed);
+ return;
+ }
+
+ const DWORD err = GetLastError();
+ switch (err) {
+ case ERROR_PIPE_CONNECTED:
+ break;
+ case ERROR_IO_PENDING:
+ is_connect_pending_ = true;
+ AddRef();
+ return;
+ case ERROR_NO_DATA:
+ default:
+ OnError(Error::kConnectionFailed);
+ return;
+ }
+ }
+
+ // Now that we have registered our IOHandler, we can start writing.
+ {
+ base::AutoLock lock(write_lock_);
+ if (delay_writes_) {
+ delay_writes_ = false;
+ WriteNextNoLock();
+ }
+ }
+
+ // Keep this alive in case we synchronously run shutdown, via OnError(),
+ // as a result of a ReadFile() failure on the channel.
+ scoped_refptr<ChannelWin> keep_alive(this);
+ ReadMore(0);
+ }
+
+ void ShutDownOnIOThread() {
+ base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this);
+
+ // TODO(https://crbug.com/583525): This function is expected to be called
+ // once, and |handle_| should be valid at this point.
+ CHECK(handle_.IsValid());
+ CancelIo(handle_.Get());
+ if (leak_handle_)
+ ignore_result(handle_.Take());
+ else
+ handle_.Close();
+
+ // Allow |this| to be destroyed as soon as no IO is pending.
+ self_ = nullptr;
+ }
+
+ // base::MessageLoopCurrent::DestructionObserver:
+ void WillDestroyCurrentMessageLoop() override {
+ DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
+ if (self_)
+ ShutDownOnIOThread();
+ }
+
+ // base::MessageLoop::IOHandler:
+ void OnIOCompleted(base::MessagePumpForIO::IOContext* context,
+ DWORD bytes_transfered,
+ DWORD error) override {
+ if (error != ERROR_SUCCESS) {
+ if (context == &write_context_) {
+ {
+ base::AutoLock lock(write_lock_);
+ reject_writes_ = true;
+ }
+ OnWriteError(Error::kDisconnected);
+ } else {
+ OnError(Error::kDisconnected);
+ }
+ } else if (context == &connect_context_) {
+ DCHECK(is_connect_pending_);
+ is_connect_pending_ = false;
+ ReadMore(0);
+
+ base::AutoLock lock(write_lock_);
+ if (delay_writes_) {
+ delay_writes_ = false;
+ WriteNextNoLock();
+ }
+ } else if (context == &read_context_) {
+ OnReadDone(static_cast<size_t>(bytes_transfered));
+ } else {
+ CHECK(context == &write_context_);
+ OnWriteDone(static_cast<size_t>(bytes_transfered));
+ }
+ Release();
+ }
+
+ void OnReadDone(size_t bytes_read) {
+ DCHECK(is_read_pending_);
+ is_read_pending_ = false;
+
+ if (bytes_read > 0) {
+ size_t next_read_size = 0;
+ if (OnReadComplete(bytes_read, &next_read_size)) {
+ ReadMore(next_read_size);
+ } else {
+ OnError(Error::kReceivedMalformedData);
+ }
+ } else if (bytes_read == 0) {
+ OnError(Error::kDisconnected);
+ }
+ }
+
+ void OnWriteDone(size_t bytes_written) {
+ if (bytes_written == 0)
+ return;
+
+ bool write_error = false;
+ {
+ base::AutoLock lock(write_lock_);
+
+ DCHECK(is_write_pending_);
+ is_write_pending_ = false;
+ DCHECK(!outgoing_messages_.empty());
+
+ Channel::MessagePtr message = std::move(outgoing_messages_.front());
+ outgoing_messages_.pop_front();
+
+ // Invalidate all the scoped handles so we don't attempt to close them.
+ std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
+ for (auto& handle : handles)
+ handle.CompleteTransit();
+
+ // Overlapped WriteFile() to a pipe should always fully complete.
+ if (message->data_num_bytes() != bytes_written)
+ reject_writes_ = write_error = true;
+ else if (!WriteNextNoLock())
+ reject_writes_ = write_error = true;
+ }
+ if (write_error)
+ OnWriteError(Error::kDisconnected);
+ }
+
+ void ReadMore(size_t next_read_size_hint) {
+ DCHECK(!is_read_pending_);
+
+ size_t buffer_capacity = next_read_size_hint;
+ char* buffer = GetReadBuffer(&buffer_capacity);
+ DCHECK_GT(buffer_capacity, 0u);
+
+ BOOL ok =
+ ::ReadFile(handle_.Get(), buffer, static_cast<DWORD>(buffer_capacity),
+ NULL, &read_context_.overlapped);
+ if (ok || GetLastError() == ERROR_IO_PENDING) {
+ is_read_pending_ = true;
+ AddRef();
+ } else {
+ OnError(Error::kDisconnected);
+ }
+ }
+
+ // Attempts to write a message directly to the channel. If the full message
+ // cannot be written, it's queued and a wait is initiated to write the message
+ // ASAP on the I/O thread.
+ bool WriteNoLock(const Channel::MessagePtr& message) {
+ BOOL ok = WriteFile(handle_.Get(), message->data(),
+ static_cast<DWORD>(message->data_num_bytes()), NULL,
+ &write_context_.overlapped);
+ if (ok || GetLastError() == ERROR_IO_PENDING) {
+ is_write_pending_ = true;
+ AddRef();
+ return true;
+ }
+ return false;
+ }
+
+ bool WriteNextNoLock() {
+ if (outgoing_messages_.empty())
+ return true;
+ return WriteNoLock(outgoing_messages_.front());
+ }
+
+ void OnWriteError(Error error) {
+ DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
+ DCHECK(reject_writes_);
+
+ if (error == Error::kDisconnected) {
+ // If we can't write because the pipe is disconnected then continue
+ // reading to fetch any in-flight messages, relying on end-of-stream to
+ // signal the actual disconnection.
+ if (is_read_pending_ || is_connect_pending_)
+ return;
+ }
+
+ OnError(error);
+ }
+
+ // Keeps the Channel alive at least until explicit shutdown on the IO thread.
+ scoped_refptr<Channel> self_;
+
+ // The pipe handle this Channel uses for communication.
+ base::win::ScopedHandle handle_;
+
+ // Indicates whether |handle_| must wait for a connection.
+ bool needs_connection_ = false;
+
+ const scoped_refptr<base::TaskRunner> io_task_runner_;
+
+ base::MessagePumpForIO::IOContext connect_context_;
+ base::MessagePumpForIO::IOContext read_context_;
+ bool is_connect_pending_ = false;
+ bool is_read_pending_ = false;
+
+ // Protects all fields potentially accessed on multiple threads via Write().
+ base::Lock write_lock_;
+ base::MessagePumpForIO::IOContext write_context_;
+ base::circular_deque<Channel::MessagePtr> outgoing_messages_;
+ bool delay_writes_ = true;
+ bool reject_writes_ = false;
+ bool is_write_pending_ = false;
+
+ bool leak_handle_ = false;
+
+ DISALLOW_COPY_AND_ASSIGN(ChannelWin);
+};
+
+} // namespace
+
+// static
+scoped_refptr<Channel> Channel::Create(
+ Delegate* delegate,
+ ConnectionParams connection_params,
+ scoped_refptr<base::TaskRunner> io_task_runner) {
+ return new ChannelWin(delegate, std::move(connection_params), io_task_runner);
+}
+
+} // namespace core
+} // namespace mojo