diff options
Diffstat (limited to 'mojo/public/cpp/bindings/lib/message.cc')
-rw-r--r-- | mojo/public/cpp/bindings/lib/message.cc | 332 |
1 files changed, 332 insertions, 0 deletions
diff --git a/mojo/public/cpp/bindings/lib/message.cc b/mojo/public/cpp/bindings/lib/message.cc new file mode 100644 index 0000000000..e5f3808117 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message.cc @@ -0,0 +1,332 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/message.h" + +#include <stddef.h> +#include <stdint.h> +#include <stdlib.h> + +#include <algorithm> +#include <utility> + +#include "base/bind.h" +#include "base/lazy_instance.h" +#include "base/logging.h" +#include "base/strings/stringprintf.h" +#include "base/threading/thread_local.h" +#include "mojo/public/cpp/bindings/associated_group_controller.h" +#include "mojo/public/cpp/bindings/lib/array_internal.h" + +namespace mojo { + +namespace { + +base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>:: + DestructorAtExit g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER; + +base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>:: + DestructorAtExit g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER; + +void DoNotifyBadMessage(Message message, const std::string& error) { + message.NotifyBadMessage(error); +} + +} // namespace + +Message::Message() { +} + +Message::Message(Message&& other) + : buffer_(std::move(other.buffer_)), + handles_(std::move(other.handles_)), + associated_endpoint_handles_( + std::move(other.associated_endpoint_handles_)) {} + +Message::~Message() { + CloseHandles(); +} + +Message& Message::operator=(Message&& other) { + Reset(); + std::swap(other.buffer_, buffer_); + std::swap(other.handles_, handles_); + std::swap(other.associated_endpoint_handles_, associated_endpoint_handles_); + return *this; +} + +void Message::Reset() { + CloseHandles(); + handles_.clear(); + associated_endpoint_handles_.clear(); + buffer_.reset(); +} + +void Message::Initialize(size_t capacity, bool zero_initialized) { + DCHECK(!buffer_); + buffer_.reset(new internal::MessageBuffer(capacity, zero_initialized)); +} + +void Message::InitializeFromMojoMessage(ScopedMessageHandle message, + uint32_t num_bytes, + std::vector<Handle>* handles) { + DCHECK(!buffer_); + buffer_.reset(new internal::MessageBuffer(std::move(message), num_bytes)); + handles_.swap(*handles); +} + +const uint8_t* Message::payload() const { + if (version() < 2) + return data() + header()->num_bytes; + + return static_cast<const uint8_t*>(header_v2()->payload.Get()); +} + +uint32_t Message::payload_num_bytes() const { + DCHECK_GE(data_num_bytes(), header()->num_bytes); + size_t num_bytes; + if (version() < 2) { + num_bytes = data_num_bytes() - header()->num_bytes; + } else { + auto payload = reinterpret_cast<uintptr_t>(header_v2()->payload.Get()); + if (!payload) { + num_bytes = 0; + } else { + auto payload_end = + reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get()); + if (!payload_end) + payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes()); + DCHECK_GE(payload_end, payload); + num_bytes = payload_end - payload; + } + } + DCHECK_LE(num_bytes, std::numeric_limits<uint32_t>::max()); + return static_cast<uint32_t>(num_bytes); +} + +uint32_t Message::payload_num_interface_ids() const { + auto* array_pointer = + version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get(); + return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0; +} + +const uint32_t* Message::payload_interface_ids() const { + auto* array_pointer = + version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get(); + return array_pointer ? array_pointer->storage() : nullptr; +} + +ScopedMessageHandle Message::TakeMojoMessage() { + // If there are associated endpoints transferred, + // SerializeAssociatedEndpointHandles() must be called before this method. + DCHECK(associated_endpoint_handles_.empty()); + + if (handles_.empty()) // Fast path for the common case: No handles. + return buffer_->TakeMessage(); + + // Allocate a new message with space for the handles, then copy the buffer + // contents into it. + // + // TODO(rockot): We could avoid this copy by extending GetSerializedSize() + // behavior to collect handles. It's unoptimized for now because it's much + // more common to have messages with no handles. + ScopedMessageHandle new_message; + MojoResult rv = AllocMessage( + data_num_bytes(), + handles_.empty() ? nullptr + : reinterpret_cast<const MojoHandle*>(handles_.data()), + handles_.size(), + MOJO_ALLOC_MESSAGE_FLAG_NONE, + &new_message); + CHECK_EQ(rv, MOJO_RESULT_OK); + handles_.clear(); + + void* new_buffer = nullptr; + rv = GetMessageBuffer(new_message.get(), &new_buffer); + CHECK_EQ(rv, MOJO_RESULT_OK); + + memcpy(new_buffer, data(), data_num_bytes()); + buffer_.reset(); + + return new_message; +} + +void Message::NotifyBadMessage(const std::string& error) { + DCHECK(buffer_); + buffer_->NotifyBadMessage(error); +} + +void Message::CloseHandles() { + for (std::vector<Handle>::iterator it = handles_.begin(); + it != handles_.end(); ++it) { + if (it->is_valid()) + CloseRaw(*it); + } +} + +void Message::SerializeAssociatedEndpointHandles( + AssociatedGroupController* group_controller) { + if (associated_endpoint_handles_.empty()) + return; + + DCHECK_GE(version(), 2u); + DCHECK(header_v2()->payload_interface_ids.is_null()); + + size_t size = associated_endpoint_handles_.size(); + auto* data = internal::Array_Data<uint32_t>::New(size, buffer()); + header_v2()->payload_interface_ids.Set(data); + + for (size_t i = 0; i < size; ++i) { + ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i]; + + DCHECK(handle.pending_association()); + data->storage()[i] = + group_controller->AssociateInterface(std::move(handle)); + } + associated_endpoint_handles_.clear(); +} + +bool Message::DeserializeAssociatedEndpointHandles( + AssociatedGroupController* group_controller) { + associated_endpoint_handles_.clear(); + + uint32_t num_ids = payload_num_interface_ids(); + if (num_ids == 0) + return true; + + associated_endpoint_handles_.reserve(num_ids); + uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage(); + bool result = true; + for (uint32_t i = 0; i < num_ids; ++i) { + auto handle = group_controller->CreateLocalEndpointHandle(ids[i]); + if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) { + // |ids[i]| itself is valid but handle creation failed. In that case, mark + // deserialization as failed but continue to deserialize the rest of + // handles. + result = false; + } + + associated_endpoint_handles_.push_back(std::move(handle)); + ids[i] = kInvalidInterfaceId; + } + return result; +} + +PassThroughFilter::PassThroughFilter() {} + +PassThroughFilter::~PassThroughFilter() {} + +bool PassThroughFilter::Accept(Message* message) { return true; } + +SyncMessageResponseContext::SyncMessageResponseContext() + : outer_context_(current()) { + g_tls_sync_response_context.Get().Set(this); +} + +SyncMessageResponseContext::~SyncMessageResponseContext() { + DCHECK_EQ(current(), this); + g_tls_sync_response_context.Get().Set(outer_context_); +} + +// static +SyncMessageResponseContext* SyncMessageResponseContext::current() { + return g_tls_sync_response_context.Get().Get(); +} + +void SyncMessageResponseContext::ReportBadMessage(const std::string& error) { + GetBadMessageCallback().Run(error); +} + +const ReportBadMessageCallback& +SyncMessageResponseContext::GetBadMessageCallback() { + if (bad_message_callback_.is_null()) { + bad_message_callback_ = + base::Bind(&DoNotifyBadMessage, base::Passed(&response_)); + } + return bad_message_callback_; +} + +MojoResult ReadMessage(MessagePipeHandle handle, Message* message) { + MojoResult rv; + + std::vector<Handle> handles; + ScopedMessageHandle mojo_message; + uint32_t num_bytes = 0, num_handles = 0; + rv = ReadMessageNew(handle, + &mojo_message, + &num_bytes, + nullptr, + &num_handles, + MOJO_READ_MESSAGE_FLAG_NONE); + if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) { + DCHECK_GT(num_handles, 0u); + handles.resize(num_handles); + rv = ReadMessageNew(handle, + &mojo_message, + &num_bytes, + reinterpret_cast<MojoHandle*>(handles.data()), + &num_handles, + MOJO_READ_MESSAGE_FLAG_NONE); + } + + if (rv != MOJO_RESULT_OK) + return rv; + + message->InitializeFromMojoMessage( + std::move(mojo_message), num_bytes, &handles); + return MOJO_RESULT_OK; +} + +void ReportBadMessage(const std::string& error) { + internal::MessageDispatchContext* context = + internal::MessageDispatchContext::current(); + DCHECK(context); + context->GetBadMessageCallback().Run(error); +} + +ReportBadMessageCallback GetBadMessageCallback() { + internal::MessageDispatchContext* context = + internal::MessageDispatchContext::current(); + DCHECK(context); + return context->GetBadMessageCallback(); +} + +namespace internal { + +MessageHeaderV2::MessageHeaderV2() = default; + +MessageDispatchContext::MessageDispatchContext(Message* message) + : outer_context_(current()), message_(message) { + g_tls_message_dispatch_context.Get().Set(this); +} + +MessageDispatchContext::~MessageDispatchContext() { + DCHECK_EQ(current(), this); + g_tls_message_dispatch_context.Get().Set(outer_context_); +} + +// static +MessageDispatchContext* MessageDispatchContext::current() { + return g_tls_message_dispatch_context.Get().Get(); +} + +const ReportBadMessageCallback& +MessageDispatchContext::GetBadMessageCallback() { + if (bad_message_callback_.is_null()) { + bad_message_callback_ = + base::Bind(&DoNotifyBadMessage, base::Passed(message_)); + } + return bad_message_callback_; +} + +// static +void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) { + SyncMessageResponseContext* context = SyncMessageResponseContext::current(); + if (context) + context->response_ = std::move(*message); +} + +} // namespace internal + +} // namespace mojo |