diff options
Diffstat (limited to 'mojo/public/cpp/bindings/lib')
70 files changed, 2947 insertions, 1983 deletions
diff --git a/mojo/public/cpp/bindings/lib/array_internal.h b/mojo/public/cpp/bindings/lib/array_internal.h index eecfcfbc28..574be9b6f5 100644 --- a/mojo/public/cpp/bindings/lib/array_internal.h +++ b/mojo/public/cpp/bindings/lib/array_internal.h @@ -11,9 +11,10 @@ #include <limits> #include <new> +#include "base/component_export.h" #include "base/logging.h" +#include "base/macros.h" #include "mojo/public/c/system/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" #include "mojo/public/cpp/bindings/lib/buffer.h" #include "mojo/public/cpp/bindings/lib/serialization_util.h" @@ -29,13 +30,15 @@ namespace internal { template <typename K, typename V> class Map_Data; -MOJO_CPP_BINDINGS_EXPORT std::string -MakeMessageWithArrayIndex(const char* message, size_t size, size_t index); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +std::string MakeMessageWithArrayIndex(const char* message, + size_t size, + size_t index); -MOJO_CPP_BINDINGS_EXPORT std::string MakeMessageWithExpectedArraySize( - const char* message, - size_t size, - size_t expected_size); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +std::string MakeMessageWithExpectedArraySize(const char* message, + size_t size, + size_t expected_size); template <typename T> struct ArrayDataTraits { @@ -68,7 +71,7 @@ template <> struct ArrayDataTraits<bool> { // Helper class to emulate a reference to a bool, used for direct element // access. - class MOJO_CPP_BINDINGS_EXPORT BitRef { + class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) BitRef { public: ~BitRef(); BitRef& operator=(bool value); @@ -268,17 +271,35 @@ class Array_Data { std::is_same<T, Handle_Data>::value>; using Element = T; - // Returns null if |num_elements| or the corresponding storage size cannot be - // stored in uint32_t. - static Array_Data<T>* New(size_t num_elements, Buffer* buf) { - if (num_elements > Traits::kMaxNumElements) - return nullptr; + class BufferWriter { + public: + BufferWriter() = default; + + void Allocate(size_t num_elements, Buffer* buffer) { + if (num_elements > Traits::kMaxNumElements) + return; + + uint32_t num_bytes = + Traits::GetStorageSize(static_cast<uint32_t>(num_elements)); + buffer_ = buffer; + index_ = buffer_->Allocate(num_bytes); + new (data()) + Array_Data<T>(num_bytes, static_cast<uint32_t>(num_elements)); + } - uint32_t num_bytes = - Traits::GetStorageSize(static_cast<uint32_t>(num_elements)); - return new (buf->Allocate(num_bytes)) - Array_Data<T>(num_bytes, static_cast<uint32_t>(num_elements)); - } + bool is_null() const { return !buffer_; } + Array_Data<T>* data() { + DCHECK(!is_null()); + return buffer_->Get<Array_Data<T>>(index_); + } + Array_Data<T>* operator->() { return data(); } + + private: + Buffer* buffer_ = nullptr; + size_t index_ = 0; + + DISALLOW_COPY_AND_ASSIGN(BufferWriter); + }; static bool Validate(const void* data, ValidationContext* validation_context, diff --git a/mojo/public/cpp/bindings/lib/array_serialization.h b/mojo/public/cpp/bindings/lib/array_serialization.h index d2f8ecfd72..8323b5f9a1 100644 --- a/mojo/public/cpp/bindings/lib/array_serialization.h +++ b/mojo/public/cpp/bindings/lib/array_serialization.h @@ -112,20 +112,19 @@ struct ArraySerializer< using DataElement = typename Data::Element; using Element = typename MojomType::Element; using Traits = ArrayTraits<UserType>; + using BufferWriter = typename Data::BufferWriter; static_assert(std::is_same<Element, DataElement>::value, "Incorrect array serializer"); - static_assert(std::is_same<Element, typename Traits::Element>::value, - "Incorrect array serializer"); - - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - return sizeof(Data) + Align(input->GetSize() * sizeof(DataElement)); - } + static_assert( + std::is_same< + Element, + typename std::remove_const<typename Traits::Element>::type>::value, + "Incorrect array serializer"); static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { DCHECK(!validate_params->element_is_nullable) @@ -138,6 +137,7 @@ struct ArraySerializer< return; auto data = input->GetDataIfExists(); + Data* output = writer->data(); if (data) { memcpy(output->storage(), data, size * sizeof(DataElement)); } else { @@ -180,18 +180,14 @@ struct ArraySerializer< using DataElement = typename Data::Element; using Element = typename MojomType::Element; using Traits = ArrayTraits<UserType>; + using BufferWriter = typename Data::BufferWriter; static_assert(sizeof(Element) == sizeof(DataElement), "Incorrect array serializer"); - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - return sizeof(Data) + Align(input->GetSize() * sizeof(DataElement)); - } - static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { DCHECK(!validate_params->element_is_nullable) @@ -199,6 +195,7 @@ struct ArraySerializer< DCHECK(!validate_params->element_validate_params) << "Primitive type should not have array validate params"; + Data* output = writer->data(); size_t size = input->GetSize(); for (size_t i = 0; i < size; ++i) Serialize<Element>(input->GetNext(), output->storage() + i); @@ -231,18 +228,14 @@ struct ArraySerializer<MojomType, using UserType = typename std::remove_const<MaybeConstUserType>::type; using Traits = ArrayTraits<UserType>; using Data = typename MojomTypeTraits<MojomType>::Data; + using BufferWriter = typename Data::BufferWriter; static_assert(std::is_same<bool, typename Traits::Element>::value, "Incorrect array serializer"); - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - return sizeof(Data) + Align((input->GetSize() + 7) / 8); - } - static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { DCHECK(!validate_params->element_is_nullable) @@ -250,6 +243,7 @@ struct ArraySerializer<MojomType, DCHECK(!validate_params->element_validate_params) << "Primitive type should not have array validate params"; + Data* output = writer->data(); size_t size = input->GetSize(); for (size_t i = 0; i < size; ++i) output->at(i) = input->GetNext(); @@ -278,37 +272,23 @@ struct ArraySerializer< BelongsTo<typename MojomType::Element, MojomTypeCategory::ASSOCIATED_INTERFACE | MojomTypeCategory::ASSOCIATED_INTERFACE_REQUEST | - MojomTypeCategory::HANDLE | - MojomTypeCategory::INTERFACE | + MojomTypeCategory::HANDLE | MojomTypeCategory::INTERFACE | MojomTypeCategory::INTERFACE_REQUEST>::value>::type> { using UserType = typename std::remove_const<MaybeConstUserType>::type; using Data = typename MojomTypeTraits<MojomType>::Data; using Element = typename MojomType::Element; using Traits = ArrayTraits<UserType>; - - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - size_t element_count = input->GetSize(); - if (BelongsTo<Element, - MojomTypeCategory::ASSOCIATED_INTERFACE | - MojomTypeCategory::ASSOCIATED_INTERFACE_REQUEST>::value) { - for (size_t i = 0; i < element_count; ++i) { - typename UserTypeIterator::GetNextResult next = input->GetNext(); - size_t size = PrepareToSerialize<Element>(next, context); - DCHECK_EQ(size, 0u); - } - } - return sizeof(Data) + Align(element_count * sizeof(typename Data::Element)); - } + using BufferWriter = typename Data::BufferWriter; static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { DCHECK(!validate_params->element_validate_params) << "Handle or interface type should not have array validate params"; + Data* output = writer->data(); size_t size = input->GetSize(); for (size_t i = 0; i < size; ++i) { typename UserTypeIterator::GetNextResult next = input->GetNext(); @@ -360,35 +340,27 @@ struct ArraySerializer<MojomType, using UserType = typename std::remove_const<MaybeConstUserType>::type; using Data = typename MojomTypeTraits<MojomType>::Data; using Element = typename MojomType::Element; - using DataElementPtr = typename MojomTypeTraits<Element>::Data*; + using DataElementWriter = + typename MojomTypeTraits<Element>::Data::BufferWriter; using Traits = ArrayTraits<UserType>; - - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - size_t element_count = input->GetSize(); - size_t size = sizeof(Data) + element_count * sizeof(typename Data::Element); - for (size_t i = 0; i < element_count; ++i) { - typename UserTypeIterator::GetNextResult next = input->GetNext(); - size += PrepareToSerialize<Element>(next, context); - } - return size; - } + using BufferWriter = typename Data::BufferWriter; static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { size_t size = input->GetSize(); for (size_t i = 0; i < size; ++i) { - DataElementPtr data_ptr; + DataElementWriter data_writer; typename UserTypeIterator::GetNextResult next = input->GetNext(); - SerializeCaller<Element>::Run(next, buf, &data_ptr, + SerializeCaller<Element>::Run(next, buf, &data_writer, validate_params->element_validate_params, context); - output->at(i).Set(data_ptr); + writer->data()->at(i).Set(data_writer.is_null() ? nullptr + : data_writer.data()); MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( - !validate_params->element_is_nullable && !data_ptr, + !validate_params->element_is_nullable && data_writer.is_null(), VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, MakeMessageWithArrayIndex("null in array expecting valid pointers", size, i)); @@ -417,10 +389,10 @@ struct ArraySerializer<MojomType, template <typename InputElementType> static void Run(InputElementType&& input, Buffer* buf, - DataElementPtr* output, + DataElementWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { - Serialize<T>(std::forward<InputElementType>(input), buf, output, context); + Serialize<T>(std::forward<InputElementType>(input), buf, writer, context); } }; @@ -429,10 +401,10 @@ struct ArraySerializer<MojomType, template <typename InputElementType> static void Run(InputElementType&& input, Buffer* buf, - DataElementPtr* output, + DataElementWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { - Serialize<T>(std::forward<InputElementType>(input), buf, output, + Serialize<T>(std::forward<InputElementType>(input), buf, writer, validate_params, context); } }; @@ -451,33 +423,24 @@ struct ArraySerializer< using UserType = typename std::remove_const<MaybeConstUserType>::type; using Data = typename MojomTypeTraits<MojomType>::Data; using Element = typename MojomType::Element; + using ElementWriter = typename Data::Element::BufferWriter; using Traits = ArrayTraits<UserType>; - - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - size_t element_count = input->GetSize(); - size_t size = sizeof(Data); - for (size_t i = 0; i < element_count; ++i) { - // Call with |inlined| set to false, so that it will account for both the - // data in the union and the space in the array used to hold the union. - typename UserTypeIterator::GetNextResult next = input->GetNext(); - size += PrepareToSerialize<Element>(next, false, context); - } - return size; - } + using BufferWriter = typename Data::BufferWriter; static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { size_t size = input->GetSize(); for (size_t i = 0; i < size; ++i) { - typename Data::Element* result = output->storage() + i; + ElementWriter result; + result.AllocateInline(buf, writer->data()->storage() + i); typename UserTypeIterator::GetNextResult next = input->GetNext(); Serialize<Element>(next, buf, &result, true, context); MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( - !validate_params->element_is_nullable && output->at(i).is_null(), + !validate_params->element_is_nullable && + writer->data()->at(i).is_null(), VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, MakeMessageWithArrayIndex("null in array expecting valid unions", size, i)); @@ -506,38 +469,27 @@ struct Serializer<ArrayDataView<Element>, MaybeConstUserType> { MaybeConstUserType, ArrayIterator<Traits, MaybeConstUserType>>; using Data = typename MojomTypeTraits<ArrayDataView<Element>>::Data; - - static size_t PrepareToSerialize(MaybeConstUserType& input, - SerializationContext* context) { - if (CallIsNullIfExists<Traits>(input)) - return 0; - ArrayIterator<Traits, MaybeConstUserType> iterator(input); - return Impl::GetSerializedSize(&iterator, context); - } + using BufferWriter = typename Data::BufferWriter; static void Serialize(MaybeConstUserType& input, Buffer* buf, - Data** output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { - if (!CallIsNullIfExists<Traits>(input)) { - MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( - validate_params->expected_num_elements != 0 && - Traits::GetSize(input) != validate_params->expected_num_elements, - internal::VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER, - internal::MakeMessageWithExpectedArraySize( - "fixed-size array has wrong number of elements", - Traits::GetSize(input), validate_params->expected_num_elements)); - Data* result = Data::New(Traits::GetSize(input), buf); - if (result) { - ArrayIterator<Traits, MaybeConstUserType> iterator(input); - Impl::SerializeElements(&iterator, buf, result, validate_params, - context); - } - *output = result; - } else { - *output = nullptr; - } + if (CallIsNullIfExists<Traits>(input)) + return; + + const size_t size = Traits::GetSize(input); + MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( + validate_params->expected_num_elements != 0 && + size != validate_params->expected_num_elements, + internal::VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER, + internal::MakeMessageWithExpectedArraySize( + "fixed-size array has wrong number of elements", size, + validate_params->expected_num_elements)); + writer->Allocate(size, buf); + ArrayIterator<Traits, MaybeConstUserType> iterator(input); + Impl::SerializeElements(&iterator, buf, writer, validate_params, context); } static bool Deserialize(Data* input, diff --git a/mojo/public/cpp/bindings/lib/associated_binding.cc b/mojo/public/cpp/bindings/lib/associated_binding.cc index 6788e68e07..c7eddc2372 100644 --- a/mojo/public/cpp/bindings/lib/associated_binding.cc +++ b/mojo/public/cpp/bindings/lib/associated_binding.cc @@ -4,6 +4,9 @@ #include "mojo/public/cpp/bindings/associated_binding.h" +#include "base/single_thread_task_runner.h" +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" + namespace mojo { AssociatedBindingBase::AssociatedBindingBase() {} @@ -27,15 +30,16 @@ void AssociatedBindingBase::CloseWithReason(uint32_t custom_reason, } void AssociatedBindingBase::set_connection_error_handler( - const base::Closure& error_handler) { + base::OnceClosure error_handler) { DCHECK(is_bound()); - endpoint_client_->set_connection_error_handler(error_handler); + endpoint_client_->set_connection_error_handler(std::move(error_handler)); } void AssociatedBindingBase::set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { DCHECK(is_bound()); - endpoint_client_->set_connection_error_with_reason_handler(error_handler); + endpoint_client_->set_connection_error_with_reason_handler( + std::move(error_handler)); } void AssociatedBindingBase::FlushForTesting() { @@ -56,7 +60,9 @@ void AssociatedBindingBase::BindImpl( endpoint_client_.reset(new InterfaceEndpointClient( std::move(handle), receiver, std::move(payload_validator), - expect_sync_requests, std::move(runner), interface_version)); + expect_sync_requests, + internal::GetTaskRunnerToUseFromUserProvidedTaskRunner(std::move(runner)), + interface_version)); } } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc b/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc index 78281eda9a..453e47a995 100644 --- a/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc +++ b/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc @@ -6,12 +6,12 @@ namespace mojo { -void GetIsolatedInterface(ScopedInterfaceEndpointHandle handle) { +void AssociateWithDisconnectedPipe(ScopedInterfaceEndpointHandle handle) { MessagePipe pipe; scoped_refptr<internal::MultiplexRouter> router = - new internal::MultiplexRouter(std::move(pipe.handle0), - internal::MultiplexRouter::MULTI_INTERFACE, - false, base::ThreadTaskRunnerHandle::Get()); + new internal::MultiplexRouter( + std::move(pipe.handle0), internal::MultiplexRouter::MULTI_INTERFACE, + false, base::SequencedTaskRunnerHandle::Get()); router->AssociateInterface(std::move(handle)); } diff --git a/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.cc b/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.cc new file mode 100644 index 0000000000..dd3a2510f1 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.cc @@ -0,0 +1,81 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h" + +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" + +namespace mojo { +namespace internal { + +AssociatedInterfacePtrStateBase::AssociatedInterfacePtrStateBase() = default; + +AssociatedInterfacePtrStateBase::~AssociatedInterfacePtrStateBase() = default; + +void AssociatedInterfacePtrStateBase::QueryVersion( + const base::Callback<void(uint32_t)>& callback) { + // It is safe to capture |this| because the callback won't be run after this + // object goes away. + endpoint_client_->QueryVersion( + base::Bind(&AssociatedInterfacePtrStateBase::OnQueryVersion, + base::Unretained(this), callback)); +} + +void AssociatedInterfacePtrStateBase::RequireVersion(uint32_t version) { + if (version <= version_) + return; + + version_ = version; + endpoint_client_->RequireVersion(version); +} + +void AssociatedInterfacePtrStateBase::OnQueryVersion( + const base::Callback<void(uint32_t)>& callback, + uint32_t version) { + version_ = version; + callback.Run(version); +} + +void AssociatedInterfacePtrStateBase::FlushForTesting() { + endpoint_client_->FlushForTesting(); +} + +void AssociatedInterfacePtrStateBase::CloseWithReason( + uint32_t custom_reason, + const std::string& description) { + endpoint_client_->CloseWithReason(custom_reason, description); +} + +void AssociatedInterfacePtrStateBase::Swap( + AssociatedInterfacePtrStateBase* other) { + using std::swap; + swap(other->endpoint_client_, endpoint_client_); + swap(other->version_, version_); +} + +void AssociatedInterfacePtrStateBase::Bind( + ScopedInterfaceEndpointHandle handle, + uint32_t version, + std::unique_ptr<MessageReceiver> validator, + scoped_refptr<base::SequencedTaskRunner> runner) { + DCHECK(!endpoint_client_); + DCHECK_EQ(0u, version_); + DCHECK(handle.is_valid()); + + version_ = version; + // The version is only queried from the client so the value passed here + // will not be used. + endpoint_client_ = std::make_unique<InterfaceEndpointClient>( + std::move(handle), nullptr, std::move(validator), false, + GetTaskRunnerToUseFromUserProvidedTaskRunner(std::move(runner)), 0u); +} + +ScopedInterfaceEndpointHandle AssociatedInterfacePtrStateBase::PassHandle() { + auto handle = endpoint_client_->PassHandle(); + endpoint_client_.reset(); + return handle; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h b/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h index a4b51882d2..79ec2bec93 100644 --- a/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h +++ b/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h @@ -17,9 +17,10 @@ #include "base/macros.h" #include "base/memory/ptr_util.h" #include "base/memory/ref_counted.h" -#include "base/single_thread_task_runner.h" +#include "base/sequenced_task_runner.h" #include "mojo/public/cpp/bindings/associated_group.h" #include "mojo/public/cpp/bindings/associated_interface_ptr_info.h" +#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/connection_error_callback.h" #include "mojo/public/cpp/bindings/interface_endpoint_client.h" #include "mojo/public/cpp/bindings/interface_id.h" @@ -29,77 +30,17 @@ namespace mojo { namespace internal { -template <typename Interface> -class AssociatedInterfacePtrState { +class MOJO_CPP_BINDINGS_EXPORT AssociatedInterfacePtrStateBase { public: - AssociatedInterfacePtrState() : version_(0u) {} - - ~AssociatedInterfacePtrState() { - endpoint_client_.reset(); - proxy_.reset(); - } - - Interface* instance() { - // This will be null if the object is not bound. - return proxy_.get(); - } + AssociatedInterfacePtrStateBase(); + ~AssociatedInterfacePtrStateBase(); uint32_t version() const { return version_; } - void QueryVersion(const base::Callback<void(uint32_t)>& callback) { - // It is safe to capture |this| because the callback won't be run after this - // object goes away. - endpoint_client_->QueryVersion( - base::Bind(&AssociatedInterfacePtrState::OnQueryVersion, - base::Unretained(this), callback)); - } - - void RequireVersion(uint32_t version) { - if (version <= version_) - return; - - version_ = version; - endpoint_client_->RequireVersion(version); - } - - void FlushForTesting() { endpoint_client_->FlushForTesting(); } - - void CloseWithReason(uint32_t custom_reason, const std::string& description) { - endpoint_client_->CloseWithReason(custom_reason, description); - } - - void Swap(AssociatedInterfacePtrState* other) { - using std::swap; - swap(other->endpoint_client_, endpoint_client_); - swap(other->proxy_, proxy_); - swap(other->version_, version_); - } - - void Bind(AssociatedInterfacePtrInfo<Interface> info, - scoped_refptr<base::SingleThreadTaskRunner> runner) { - DCHECK(!endpoint_client_); - DCHECK(!proxy_); - DCHECK_EQ(0u, version_); - DCHECK(info.is_valid()); - - version_ = info.version(); - // The version is only queried from the client so the value passed here - // will not be used. - endpoint_client_.reset(new InterfaceEndpointClient( - info.PassHandle(), nullptr, - base::WrapUnique(new typename Interface::ResponseValidator_()), false, - std::move(runner), 0u)); - proxy_.reset(new Proxy(endpoint_client_.get())); - } - - // After this method is called, the object is in an invalid state and - // shouldn't be reused. - AssociatedInterfacePtrInfo<Interface> PassInterface() { - ScopedInterfaceEndpointHandle handle = endpoint_client_->PassHandle(); - endpoint_client_.reset(); - proxy_.reset(); - return AssociatedInterfacePtrInfo<Interface>(std::move(handle), version_); - } + void QueryVersion(const base::Callback<void(uint32_t)>& callback); + void RequireVersion(uint32_t version); + void FlushForTesting(); + void CloseWithReason(uint32_t custom_reason, const std::string& description); bool is_bound() const { return !!endpoint_client_; } @@ -107,15 +48,16 @@ class AssociatedInterfacePtrState { return endpoint_client_ ? endpoint_client_->encountered_error() : false; } - void set_connection_error_handler(const base::Closure& error_handler) { + void set_connection_error_handler(base::OnceClosure error_handler) { DCHECK(endpoint_client_); - endpoint_client_->set_connection_error_handler(error_handler); + endpoint_client_->set_connection_error_handler(std::move(error_handler)); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { DCHECK(endpoint_client_); - endpoint_client_->set_connection_error_with_reason_handler(error_handler); + endpoint_client_->set_connection_error_with_reason_handler( + std::move(error_handler)); } // Returns true if bound and awaiting a response to a message. @@ -134,19 +76,62 @@ class AssociatedInterfacePtrState { endpoint_client_->AcceptWithResponder(&message, std::move(responder)); } + protected: + void Swap(AssociatedInterfacePtrStateBase* other); + void Bind(ScopedInterfaceEndpointHandle handle, + uint32_t version, + std::unique_ptr<MessageReceiver> validator, + scoped_refptr<base::SequencedTaskRunner> runner); + ScopedInterfaceEndpointHandle PassHandle(); + + InterfaceEndpointClient* endpoint_client() { return endpoint_client_.get(); } + private: + void OnQueryVersion(const base::Callback<void(uint32_t)>& callback, + uint32_t version); + + std::unique_ptr<InterfaceEndpointClient> endpoint_client_; + uint32_t version_ = 0; +}; + +template <typename Interface> +class AssociatedInterfacePtrState : public AssociatedInterfacePtrStateBase { + public: using Proxy = typename Interface::Proxy_; - void OnQueryVersion(const base::Callback<void(uint32_t)>& callback, - uint32_t version) { - version_ = version; - callback.Run(version); + AssociatedInterfacePtrState() {} + ~AssociatedInterfacePtrState() = default; + + Proxy* instance() { + // This will be null if the object is not bound. + return proxy_.get(); } - std::unique_ptr<InterfaceEndpointClient> endpoint_client_; - std::unique_ptr<Proxy> proxy_; + void Swap(AssociatedInterfacePtrState* other) { + AssociatedInterfacePtrStateBase::Swap(other); + std::swap(other->proxy_, proxy_); + } + + void Bind(AssociatedInterfacePtrInfo<Interface> info, + scoped_refptr<base::SequencedTaskRunner> runner) { + DCHECK(!proxy_); + AssociatedInterfacePtrStateBase::Bind( + info.PassHandle(), info.version(), + std::make_unique<typename Interface::ResponseValidator_>(), + std::move(runner)); + proxy_.reset(new Proxy(endpoint_client())); + } + + // After this method is called, the object is in an invalid state and + // shouldn't be reused. + AssociatedInterfacePtrInfo<Interface> PassInterface() { + AssociatedInterfacePtrInfo<Interface> info(PassHandle(), version()); + proxy_.reset(); + return info; + } - uint32_t version_; + private: + std::unique_ptr<Proxy> proxy_; DISALLOW_COPY_AND_ASSIGN(AssociatedInterfacePtrState); }; diff --git a/mojo/public/cpp/bindings/lib/binding_state.cc b/mojo/public/cpp/bindings/lib/binding_state.cc index b34cb47e28..bb4a20f39b 100644 --- a/mojo/public/cpp/bindings/lib/binding_state.cc +++ b/mojo/public/cpp/bindings/lib/binding_state.cc @@ -4,10 +4,12 @@ #include "mojo/public/cpp/bindings/lib/binding_state.h" +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" + namespace mojo { namespace internal { -BindingStateBase::BindingStateBase() = default; +BindingStateBase::BindingStateBase() : weak_ptr_factory_(this) {} BindingStateBase::~BindingStateBase() = default; @@ -24,6 +26,7 @@ void BindingStateBase::PauseIncomingMethodCallProcessing() { DCHECK(router_); router_->PauseIncomingMethodCallProcessing(); } + void BindingStateBase::ResumeIncomingMethodCallProcessing() { DCHECK(router_); router_->ResumeIncomingMethodCallProcessing(); @@ -51,6 +54,17 @@ void BindingStateBase::CloseWithReason(uint32_t custom_reason, Close(); } +ReportBadMessageCallback BindingStateBase::GetBadMessageCallback() { + return base::BindOnce( + [](ReportBadMessageCallback inner_callback, + base::WeakPtr<BindingStateBase> binding, const std::string& error) { + std::move(inner_callback).Run(error); + if (binding) + binding->Close(); + }, + mojo::GetBadMessageCallback(), weak_ptr_factory_.GetWeakPtr()); +} + void BindingStateBase::FlushForTesting() { endpoint_client_->FlushForTesting(); } @@ -60,6 +74,10 @@ void BindingStateBase::EnableTestingMode() { router_->EnableTestingMode(); } +scoped_refptr<internal::MultiplexRouter> BindingStateBase::RouterForTesting() { + return router_; +} + void BindingStateBase::BindInternal( ScopedMessagePipeHandle handle, scoped_refptr<base::SingleThreadTaskRunner> runner, @@ -69,21 +87,25 @@ void BindingStateBase::BindInternal( bool has_sync_methods, MessageReceiverWithResponderStatus* stub, uint32_t interface_version) { - DCHECK(!router_); + DCHECK(!is_bound()) << "Attempting to bind interface that is already bound: " + << interface_name; + auto sequenced_runner = + GetTaskRunnerToUseFromUserProvidedTaskRunner(std::move(runner)); MultiplexRouter::Config config = passes_associated_kinds ? MultiplexRouter::MULTI_INTERFACE : (has_sync_methods ? MultiplexRouter::SINGLE_INTERFACE_WITH_SYNC_METHODS : MultiplexRouter::SINGLE_INTERFACE); - router_ = new MultiplexRouter(std::move(handle), config, false, runner); + router_ = + new MultiplexRouter(std::move(handle), config, false, sequenced_runner); router_->SetMasterInterfaceName(interface_name); endpoint_client_.reset(new InterfaceEndpointClient( router_->CreateLocalEndpointHandle(kMasterInterfaceId), stub, - std::move(request_validator), has_sync_methods, std::move(runner), - interface_version)); + std::move(request_validator), has_sync_methods, + std::move(sequenced_runner), interface_version)); } } // namesapce internal diff --git a/mojo/public/cpp/bindings/lib/binding_state.h b/mojo/public/cpp/bindings/lib/binding_state.h index 0b0dbee002..d1c561c748 100644 --- a/mojo/public/cpp/bindings/lib/binding_state.h +++ b/mojo/public/cpp/bindings/lib/binding_state.h @@ -15,6 +15,7 @@ #include "base/macros.h" #include "base/memory/ptr_util.h" #include "base/memory/ref_counted.h" +#include "base/sequenced_task_runner.h" #include "base/single_thread_task_runner.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/connection_error_callback.h" @@ -50,15 +51,18 @@ class MOJO_CPP_BINDINGS_EXPORT BindingStateBase { void Close(); void CloseWithReason(uint32_t custom_reason, const std::string& description); - void set_connection_error_handler(const base::Closure& error_handler) { + void RaiseError() { endpoint_client_->RaiseError(); } + + void set_connection_error_handler(base::OnceClosure error_handler) { DCHECK(is_bound()); - endpoint_client_->set_connection_error_handler(error_handler); + endpoint_client_->set_connection_error_handler(std::move(error_handler)); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { DCHECK(is_bound()); - endpoint_client_->set_connection_error_with_reason_handler(error_handler); + endpoint_client_->set_connection_error_with_reason_handler( + std::move(error_handler)); } bool is_bound() const { return !!router_; } @@ -68,10 +72,14 @@ class MOJO_CPP_BINDINGS_EXPORT BindingStateBase { return router_->handle(); } + ReportBadMessageCallback GetBadMessageCallback(); + void FlushForTesting(); void EnableTestingMode(); + scoped_refptr<internal::MultiplexRouter> RouterForTesting(); + protected: void BindInternal(ScopedMessagePipeHandle handle, scoped_refptr<base::SingleThreadTaskRunner> runner, @@ -84,6 +92,8 @@ class MOJO_CPP_BINDINGS_EXPORT BindingStateBase { scoped_refptr<internal::MultiplexRouter> router_; std::unique_ptr<InterfaceEndpointClient> endpoint_client_; + + base::WeakPtrFactory<BindingStateBase> weak_ptr_factory_; }; template <typename Interface, typename ImplRefTraits> @@ -101,20 +111,24 @@ class BindingState : public BindingStateBase { scoped_refptr<base::SingleThreadTaskRunner> runner) { BindingStateBase::BindInternal( std::move(handle), runner, Interface::Name_, - base::MakeUnique<typename Interface::RequestValidator_>(), + std::make_unique<typename Interface::RequestValidator_>(), Interface::PassesAssociatedKinds_, Interface::HasSyncMethods_, &stub_, Interface::Version_); } InterfaceRequest<Interface> Unbind() { endpoint_client_.reset(); - InterfaceRequest<Interface> request = - MakeRequest<Interface>(router_->PassMessagePipe()); + InterfaceRequest<Interface> request(router_->PassMessagePipe()); router_ = nullptr; return request; } Interface* impl() { return ImplRefTraits::GetRawPointer(&stub_.sink()); } + ImplPointerType SwapImplForTesting(ImplPointerType new_impl) { + Interface* old_impl = impl(); + stub_.set_sink(std::move(new_impl)); + return old_impl; + } private: typename Interface::template Stub_<ImplRefTraits> stub_; diff --git a/mojo/public/cpp/bindings/lib/bindings_internal.h b/mojo/public/cpp/bindings/lib/bindings_internal.h index 631daec392..8bdb9c7b77 100644 --- a/mojo/public/cpp/bindings/lib/bindings_internal.h +++ b/mojo/public/cpp/bindings/lib/bindings_internal.h @@ -8,8 +8,9 @@ #include <stdint.h> #include <functional> +#include <type_traits> -#include "base/template_util.h" +#include "mojo/public/cpp/bindings/enum_traits.h" #include "mojo/public/cpp/bindings/interface_id.h" #include "mojo/public/cpp/bindings/lib/template_util.h" #include "mojo/public/cpp/system/core.h" @@ -34,8 +35,6 @@ class InterfaceRequestDataView; template <typename K, typename V> class MapDataView; -class NativeStructDataView; - class StringDataView; namespace internal { @@ -54,8 +53,6 @@ class Array_Data; template <typename K, typename V> class Map_Data; -class NativeStruct_Data; - using String_Data = Array_Data<char>; inline size_t Align(size_t size) { @@ -299,14 +296,6 @@ struct MojomTypeTraits<MapDataView<K, V>, false> { }; template <> -struct MojomTypeTraits<NativeStructDataView, false> { - using Data = internal::NativeStruct_Data; - using DataAsArrayElement = Pointer<Data>; - - static const MojomTypeCategory category = MojomTypeCategory::STRUCT; -}; - -template <> struct MojomTypeTraits<StringDataView, false> { using Data = String_Data; using DataAsArrayElement = Pointer<Data>; @@ -325,11 +314,19 @@ struct EnumHashImpl { static_assert(std::is_enum<T>::value, "Incorrect hash function."); size_t operator()(T input) const { - using UnderlyingType = typename base::underlying_type<T>::type; + using UnderlyingType = typename std::underlying_type<T>::type; return std::hash<UnderlyingType>()(static_cast<UnderlyingType>(input)); } }; +template <typename MojomType, typename T> +T ConvertEnumValue(MojomType input) { + T output; + bool result = EnumTraits<MojomType, T>::FromMojom(input, &output); + DCHECK(result); + return output; +} + } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/buffer.cc b/mojo/public/cpp/bindings/lib/buffer.cc new file mode 100644 index 0000000000..2444cf4e54 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/buffer.cc @@ -0,0 +1,136 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/buffer.h" + +#include "base/logging.h" +#include "base/numerics/safe_math.h" +#include "mojo/public/c/system/message_pipe.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" + +namespace mojo { +namespace internal { + +Buffer::Buffer() = default; + +Buffer::Buffer(void* data, size_t size, size_t cursor) + : data_(data), size_(size), cursor_(cursor) { + DCHECK(IsAligned(data_)); +} + +Buffer::Buffer(MessageHandle message, + size_t message_payload_size, + void* data, + size_t size) + : message_(message), + message_payload_size_(message_payload_size), + data_(data), + size_(size), + cursor_(0) { + DCHECK(IsAligned(data_)); +} + +Buffer::Buffer(Buffer&& other) { + *this = std::move(other); +} + +Buffer::~Buffer() = default; + +Buffer& Buffer::operator=(Buffer&& other) { + message_ = other.message_; + message_payload_size_ = other.message_payload_size_; + data_ = other.data_; + size_ = other.size_; + cursor_ = other.cursor_; + other.Reset(); + return *this; +} + +size_t Buffer::Allocate(size_t num_bytes) { + const size_t aligned_num_bytes = Align(num_bytes); + const size_t new_cursor = cursor_ + aligned_num_bytes; + if (new_cursor < cursor_ || (new_cursor > size_ && !message_.is_valid())) { + // Either we've overflowed or exceeded a fixed capacity. + NOTREACHED(); + return 0; + } + + if (new_cursor > size_) { + // If we have an underlying message object we can extend its payload to + // obtain more storage capacity. + DCHECK_LE(message_payload_size_, new_cursor); + size_t additional_bytes = new_cursor - message_payload_size_; + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(additional_bytes)); + uint32_t new_size; + MojoResult rv = MojoAppendMessageData( + message_.value(), static_cast<uint32_t>(additional_bytes), nullptr, 0, + nullptr, &data_, &new_size); + DCHECK_EQ(MOJO_RESULT_OK, rv); + message_payload_size_ = new_cursor; + size_ = new_size; + } + + DCHECK_LE(new_cursor, size_); + size_t block_start = cursor_; + cursor_ = new_cursor; + + // Ensure that all the allocated space is zeroed to avoid uninitialized bits + // leaking into messages. + // + // TODO(rockot): We should consider only clearing the alignment padding. This + // means being careful about generated bindings zeroing padding explicitly, + // which itself gets particularly messy with e.g. packed bool bitfields. + memset(static_cast<uint8_t*>(data_) + block_start, 0, aligned_num_bytes); + + return block_start; +} + +void Buffer::AttachHandles(std::vector<ScopedHandle>* handles) { + DCHECK(message_.is_valid()); + + uint32_t new_size = 0; + MojoResult rv = MojoAppendMessageData( + message_.value(), 0, reinterpret_cast<MojoHandle*>(handles->data()), + static_cast<uint32_t>(handles->size()), nullptr, &data_, &new_size); + if (rv != MOJO_RESULT_OK) + return; + + size_ = new_size; + for (auto& handle : *handles) + ignore_result(handle.release()); +} + +void Buffer::Seal() { + if (!message_.is_valid()) + return; + + // Ensure that the backing message has the final accumulated payload size. + DCHECK_LE(message_payload_size_, cursor_); + size_t additional_bytes = cursor_ - message_payload_size_; + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(additional_bytes)); + + MojoAppendMessageDataOptions options; + options.struct_size = sizeof(options); + options.flags = MOJO_APPEND_MESSAGE_DATA_FLAG_COMMIT_SIZE; + void* data; + uint32_t size; + MojoResult rv = MojoAppendMessageData(message_.value(), + static_cast<uint32_t>(additional_bytes), + nullptr, 0, &options, &data, &size); + DCHECK_EQ(MOJO_RESULT_OK, rv); + message_ = MessageHandle(); + message_payload_size_ = cursor_; + data_ = data; + size_ = size; +} + +void Buffer::Reset() { + message_ = MessageHandle(); + data_ = nullptr; + size_ = 0; + cursor_ = 0; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/buffer.h b/mojo/public/cpp/bindings/lib/buffer.h index 213a44590f..9f2a768490 100644 --- a/mojo/public/cpp/bindings/lib/buffer.h +++ b/mojo/public/cpp/bindings/lib/buffer.h @@ -6,60 +6,121 @@ #define MOJO_PUBLIC_CPP_BINDINGS_LIB_BUFFER_H_ #include <stddef.h> +#include <stdint.h> -#include "base/logging.h" +#include <vector> + +#include "base/component_export.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/system/handle.h" +#include "mojo/public/cpp/system/message.h" namespace mojo { namespace internal { // Buffer provides an interface to allocate memory blocks which are 8-byte -// aligned and zero-initialized. It doesn't own the underlying memory. Users -// must ensure that the memory stays valid while using the allocated blocks from -// Buffer. -class Buffer { +// aligned. It doesn't own the underlying memory. Users must ensure that the +// memory stays valid while using the allocated blocks from Buffer. +// +// A Buffer may be moved around. A moved-from Buffer is reset and may no longer +// be used to Allocate memory unless re-Initialized. +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) Buffer { public: - Buffer() {} + // Constructs an invalid Buffer. May not call Allocate(). + Buffer(); + + // Constructs a Buffer which can Allocate() blocks from a buffer of fixed size + // |size| at |data|. Allocations start at |cursor|, so if |cursor| == |size| + // then no allocations are allowed. + // + // |data| is not owned. + Buffer(void* data, size_t size, size_t cursor); + + // Like above, but gives the Buffer an underlying message object which can + // have its payload extended to acquire more storage capacity on Allocate(). + // + // |data| and |size| must correspond to |message|'s data buffer at the time of + // construction. + // + // |payload_size| is the length of the payload as known by |message|, and it + // must be less than or equal to |size|. + // + // |message| is NOT owned and must outlive this Buffer. + Buffer(MessageHandle message, + size_t message_payload_size, + void* data, + size_t size); + + Buffer(Buffer&& other); + ~Buffer(); + + Buffer& operator=(Buffer&& other); - // The memory must have been zero-initialized. |data| must be 8-byte - // aligned. - void Initialize(void* data, size_t size) { - DCHECK(IsAligned(data)); + void* data() const { return data_; } + size_t size() const { return size_; } + size_t cursor() const { return cursor_; } - data_ = data; - size_ = size; - cursor_ = reinterpret_cast<uintptr_t>(data); - data_end_ = cursor_ + size; + bool is_valid() const { + return data_ != nullptr || (size_ == 0 && !message_.is_valid()); } - size_t size() const { return size_; } + // Allocates |num_bytes| from the buffer and returns an index to the start of + // the allocated block. The resulting index is 8-byte aligned and can be + // resolved to an address using Get<T>() below. + size_t Allocate(size_t num_bytes); + + // Returns a typed address within the Buffer corresponding to |index|. Note + // that this address is NOT stable across calls to |Allocate()| and thus must + // not be cached accordingly. + template <typename T> + T* Get(size_t index) { + DCHECK_LT(index, cursor_); + return reinterpret_cast<T*>(static_cast<uint8_t*>(data_) + index); + } - void* data() const { return data_; } + // A template helper combining Allocate() and Get<T>() above to allocate and + // return a block of size |sizeof(T)|. + template <typename T> + T* AllocateAndGet() { + return Get<T>(Allocate(sizeof(T))); + } - // Allocates |num_bytes| from the buffer and returns a pointer to the start of - // the allocated block. - // The resulting address is 8-byte aligned, and the content of the memory is - // zero-filled. - void* Allocate(size_t num_bytes) { - num_bytes = Align(num_bytes); - uintptr_t result = cursor_; - cursor_ += num_bytes; - if (cursor_ > data_end_ || cursor_ < result) { - NOTREACHED(); - cursor_ -= num_bytes; - return nullptr; - } - - return reinterpret_cast<void*>(result); + // A helper which combines Allocate() and Get<void>() for a specified number + // of bytes. + void* AllocateAndGet(size_t num_bytes) { + return Get<void>(Allocate(num_bytes)); } + // Serializes |handles| into the buffer object. Only valid to call when this + // Buffer is backed by a message object. + void AttachHandles(std::vector<ScopedHandle>* handles); + + // Seals this Buffer so it can no longer be used for allocation, and ensures + // the backing message object has a complete accounting of the size of the + // meaningful payload bytes. + void Seal(); + + // Resets the buffer to an invalid state. Can no longer be used to Allocate(). + void Reset(); + private: + MessageHandle message_; + + // The payload size from the message's internal perspective. This differs from + // |size_| as Mojo may intentionally over-allocate space to account for future + // growth. It differs from |cursor_| because we don't push payload size + // updates to the message object as frequently as we update |cursor_|, for + // performance. + size_t message_payload_size_ = 0; + + // The storage location and capacity currently backing |message_|. Owned by + // the message object internally, not by this Buffer. void* data_ = nullptr; size_t size_ = 0; - uintptr_t cursor_ = 0; - uintptr_t data_end_ = 0; + // The current write offset into |data_| if this Buffer is being used for + // message creation. + size_t cursor_ = 0; DISALLOW_COPY_AND_ASSIGN(Buffer); }; diff --git a/mojo/public/cpp/bindings/lib/connector.cc b/mojo/public/cpp/bindings/lib/connector.cc index d93e45ed93..352c51815f 100644 --- a/mojo/public/cpp/bindings/lib/connector.cc +++ b/mojo/public/cpp/bindings/lib/connector.cc @@ -5,7 +5,6 @@ #include "mojo/public/cpp/bindings/connector.h" #include <stdint.h> -#include <utility> #include "base/bind.h" #include "base/lazy_instance.h" @@ -13,23 +12,37 @@ #include "base/logging.h" #include "base/macros.h" #include "base/memory/ptr_util.h" -#include "base/message_loop/message_loop.h" +#include "base/message_loop/message_loop_current.h" +#include "base/run_loop.h" #include "base/synchronization/lock.h" #include "base/threading/thread_local.h" +#include "base/trace_event/trace_event.h" #include "mojo/public/cpp/bindings/lib/may_auto_lock.h" +#include "mojo/public/cpp/bindings/mojo_buildflags.h" #include "mojo/public/cpp/bindings/sync_handle_watcher.h" #include "mojo/public/cpp/system/wait.h" +#if defined(ENABLE_IPC_FUZZER) +#include "mojo/public/cpp/bindings/message_dumper.h" +#endif + namespace mojo { namespace { // The NestingObserver for each thread. Note that this is always a -// Connector::MessageLoopNestingObserver; we use the base type here because that +// Connector::RunLoopNestingObserver; we use the base type here because that // subclass is private to Connector. -base::LazyInstance< - base::ThreadLocalPointer<base::MessageLoop::NestingObserver>>::Leaky - g_tls_nesting_observer = LAZY_INSTANCE_INITIALIZER; +base::LazyInstance<base::ThreadLocalPointer<base::RunLoop::NestingObserver>>:: + Leaky g_tls_nesting_observer = LAZY_INSTANCE_INITIALIZER; + +// The default outgoing serialization mode for new Connectors. +Connector::OutgoingSerializationMode g_default_outgoing_serialization_mode = + Connector::OutgoingSerializationMode::kLazy; + +// The default incoming serialization mode for new Connectors. +Connector::IncomingSerializationMode g_default_incoming_serialization_mode = + Connector::IncomingSerializationMode::kDispatchAsIs; } // namespace @@ -44,7 +57,7 @@ class Connector::ActiveDispatchTracker { private: const base::WeakPtr<Connector> connector_; - MessageLoopNestingObserver* const nesting_observer_; + RunLoopNestingObserver* const nesting_observer_; ActiveDispatchTracker* outer_tracker_ = nullptr; ActiveDispatchTracker* inner_tracker_ = nullptr; @@ -52,41 +65,40 @@ class Connector::ActiveDispatchTracker { }; // Watches the MessageLoop on the current thread. Notifies the current chain of -// ActiveDispatchTrackers when a nested message loop is started. -class Connector::MessageLoopNestingObserver - : public base::MessageLoop::NestingObserver, - public base::MessageLoop::DestructionObserver { +// ActiveDispatchTrackers when a nested run loop is started. +class Connector::RunLoopNestingObserver + : public base::RunLoop::NestingObserver, + public base::MessageLoopCurrent::DestructionObserver { public: - MessageLoopNestingObserver() { - base::MessageLoop::current()->AddNestingObserver(this); - base::MessageLoop::current()->AddDestructionObserver(this); + RunLoopNestingObserver() { + base::RunLoop::AddNestingObserverOnCurrentThread(this); + base::MessageLoopCurrent::Get()->AddDestructionObserver(this); } - ~MessageLoopNestingObserver() override {} + ~RunLoopNestingObserver() override {} - // base::MessageLoop::NestingObserver: - void OnBeginNestedMessageLoop() override { + // base::RunLoop::NestingObserver: + void OnBeginNestedRunLoop() override { if (top_tracker_) top_tracker_->NotifyBeginNesting(); } - // base::MessageLoop::DestructionObserver: + // base::MessageLoopCurrent::DestructionObserver: void WillDestroyCurrentMessageLoop() override { - base::MessageLoop::current()->RemoveNestingObserver(this); - base::MessageLoop::current()->RemoveDestructionObserver(this); + base::RunLoop::RemoveNestingObserverOnCurrentThread(this); + base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this); DCHECK_EQ(this, g_tls_nesting_observer.Get().Get()); g_tls_nesting_observer.Get().Set(nullptr); delete this; } - static MessageLoopNestingObserver* GetForThread() { - if (!base::MessageLoop::current() || - !base::MessageLoop::current()->nesting_allowed()) + static RunLoopNestingObserver* GetForThread() { + if (!base::MessageLoopCurrent::Get()) return nullptr; - auto* observer = static_cast<MessageLoopNestingObserver*>( + auto* observer = static_cast<RunLoopNestingObserver*>( g_tls_nesting_observer.Get().Get()); if (!observer) { - observer = new MessageLoopNestingObserver; + observer = new RunLoopNestingObserver; g_tls_nesting_observer.Get().Set(observer); } return observer; @@ -97,7 +109,7 @@ class Connector::MessageLoopNestingObserver ActiveDispatchTracker* top_tracker_ = nullptr; - DISALLOW_COPY_AND_ASSIGN(MessageLoopNestingObserver); + DISALLOW_COPY_AND_ASSIGN(RunLoopNestingObserver); }; Connector::ActiveDispatchTracker::ActiveDispatchTracker( @@ -129,14 +141,22 @@ void Connector::ActiveDispatchTracker::NotifyBeginNesting() { Connector::Connector(ScopedMessagePipeHandle message_pipe, ConnectorConfig config, - scoped_refptr<base::SingleThreadTaskRunner> runner) + scoped_refptr<base::SequencedTaskRunner> runner) : message_pipe_(std::move(message_pipe)), task_runner_(std::move(runner)), - nesting_observer_(MessageLoopNestingObserver::GetForThread()), + error_(false), + outgoing_serialization_mode_(g_default_outgoing_serialization_mode), + incoming_serialization_mode_(g_default_incoming_serialization_mode), + nesting_observer_(RunLoopNestingObserver::GetForThread()), weak_factory_(this) { if (config == MULTI_THREADED_SEND) lock_.emplace(); +#if defined(ENABLE_IPC_FUZZER) + if (!MessageDumper::GetMessageDumpDirectory().empty()) + message_dumper_ = std::make_unique<MessageDumper>(); +#endif + weak_self_ = weak_factory_.GetWeakPtr(); // Even though we don't have an incoming receiver, we still want to monitor // the message pipe to know if is closed or encounters an error. @@ -145,23 +165,34 @@ Connector::Connector(ScopedMessagePipeHandle message_pipe, Connector::~Connector() { { - // Allow for quick destruction on any thread if the pipe is already closed. + // Allow for quick destruction on any sequence if the pipe is already + // closed. base::AutoLock lock(connected_lock_); if (!connected_) return; } - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); CancelWait(); } +void Connector::SetOutgoingSerializationMode(OutgoingSerializationMode mode) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + outgoing_serialization_mode_ = mode; +} + +void Connector::SetIncomingSerializationMode(IncomingSerializationMode mode) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + incoming_serialization_mode_ = mode; +} + void Connector::CloseMessagePipe() { // Throw away the returned message pipe. PassMessagePipe(); } ScopedMessagePipeHandle Connector::PassMessagePipe() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); CancelWait(); internal::MayAutoLock locker(&lock_); @@ -175,13 +206,13 @@ ScopedMessagePipeHandle Connector::PassMessagePipe() { } void Connector::RaiseError() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); HandleError(true, true); } bool Connector::WaitForIncomingMessage(MojoDeadline deadline) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (error_) return false; @@ -211,7 +242,7 @@ bool Connector::WaitForIncomingMessage(MojoDeadline deadline) { } void Connector::PauseIncomingMethodCallProcessing() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (paused_) return; @@ -221,7 +252,7 @@ void Connector::PauseIncomingMethodCallProcessing() { } void Connector::ResumeIncomingMethodCallProcessing() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!paused_) return; @@ -230,12 +261,18 @@ void Connector::ResumeIncomingMethodCallProcessing() { WaitToReadMore(); } +bool Connector::PrefersSerializedMessages() { + if (outgoing_serialization_mode_ == OutgoingSerializationMode::kEager) + return true; + DCHECK_EQ(OutgoingSerializationMode::kLazy, outgoing_serialization_mode_); + return peer_remoteness_tracker_ && + peer_remoteness_tracker_->last_known_state().peer_remote(); +} + bool Connector::Accept(Message* message) { - DCHECK(lock_ || thread_checker_.CalledOnValidThread()); + if (!lock_) + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - // It shouldn't hurt even if |error_| may be changed by a different thread at - // the same time. The outcome is that we may write into |message_pipe_| after - // encountering an error, which should be fine. if (error_) return false; @@ -244,6 +281,13 @@ bool Connector::Accept(Message* message) { if (!message_pipe_.is_valid() || drop_writes_) return true; +#if defined(ENABLE_IPC_FUZZER) + if (message_dumper_ && message->is_serialized()) { + bool dump_result = message_dumper_->Accept(message); + DCHECK(dump_result); + } +#endif + MojoResult rv = WriteMessageNew(message_pipe_.get(), message->TakeMojoMessage(), MOJO_WRITE_MESSAGE_FLAG_NONE); @@ -261,10 +305,10 @@ bool Connector::Accept(Message* message) { case MOJO_RESULT_BUSY: // We'd get a "busy" result if one of the message's handles is: // - |message_pipe_|'s own handle; - // - simultaneously being used on another thread; or + // - simultaneously being used on another sequence; or // - in a "busy" state that prohibits it from being transferred (e.g., // a data pipe handle in the middle of a two-phase read/write, - // regardless of which thread that two-phase read/write is happening + // regardless of which sequence that two-phase read/write is happening // on). // TODO(vtl): I wonder if this should be a |DCHECK()|. (But, until // crbug.com/389666, etc. are resolved, this will make tests fail quickly @@ -280,7 +324,7 @@ bool Connector::Accept(Message* message) { } void Connector::AllowWokenUpBySyncWatchOnSameThread() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); allow_woken_up_by_others_ = true; @@ -289,7 +333,7 @@ void Connector::AllowWokenUpBySyncWatchOnSameThread() { } bool Connector::SyncWatch(const bool* should_stop) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (error_) return false; @@ -301,12 +345,21 @@ bool Connector::SyncWatch(const bool* should_stop) { } void Connector::SetWatcherHeapProfilerTag(const char* tag) { - heap_profiler_tag_ = tag; - if (handle_watcher_) { - handle_watcher_->set_heap_profiler_tag(tag); + if (tag) { + heap_profiler_tag_ = tag; + if (handle_watcher_) + handle_watcher_->set_heap_profiler_tag(tag); } } +// static +void Connector::OverrideDefaultSerializationBehaviorForTesting( + OutgoingSerializationMode outgoing_mode, + IncomingSerializationMode incoming_mode) { + g_default_outgoing_serialization_mode = outgoing_mode; + g_default_incoming_serialization_mode = incoming_mode; +} + void Connector::OnWatcherHandleReady(MojoResult result) { OnHandleReadyInternal(result); } @@ -324,7 +377,7 @@ void Connector::OnSyncHandleWatcherHandleReady(MojoResult result) { } void Connector::OnHandleReadyInternal(MojoResult result) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (result != MOJO_RESULT_OK) { HandleError(result != MOJO_RESULT_FAILED_PRECONDITION, false); @@ -341,12 +394,16 @@ void Connector::WaitToReadMore() { handle_watcher_.reset(new SimpleWatcher( FROM_HERE, SimpleWatcher::ArmingPolicy::MANUAL, task_runner_)); - if (heap_profiler_tag_) - handle_watcher_->set_heap_profiler_tag(heap_profiler_tag_); + handle_watcher_->set_heap_profiler_tag(heap_profiler_tag_); MojoResult rv = handle_watcher_->Watch( message_pipe_.get(), MOJO_HANDLE_SIGNAL_READABLE, base::Bind(&Connector::OnWatcherHandleReady, base::Unretained(this))); + if (message_pipe_.is_valid()) { + peer_remoteness_tracker_.emplace(message_pipe_.get(), + MOJO_HANDLE_SIGNAL_PEER_REMOTE); + } + if (rv != MOJO_RESULT_OK) { // If the watch failed because the handle is invalid or its conditions can // no longer be met, we signal the error asynchronously to avoid reentry. @@ -383,6 +440,19 @@ bool Connector::ReadSingleMessage(MojoResult* read_result) { dispatch_tracker.emplace(weak_self); } + if (incoming_serialization_mode_ == + IncomingSerializationMode::kSerializeBeforeDispatchForTesting) { + message.SerializeIfNecessary(); + } else { + DCHECK_EQ(IncomingSerializationMode::kDispatchAsIs, + incoming_serialization_mode_); + } + +#if !BUILDFLAG(MOJO_TRACE_ENABLED) + // This emits just full class name, and is inferior to mojo tracing. + TRACE_EVENT0("mojom", heap_profiler_tag_); +#endif + receiver_result = incoming_receiver_ && incoming_receiver_->Accept(&message); @@ -443,6 +513,7 @@ void Connector::ReadAllAvailableMessages() { } void Connector::CancelWait() { + peer_remoteness_tracker_.reset(); handle_watcher_.reset(); sync_watcher_.reset(); } @@ -476,8 +547,8 @@ void Connector::HandleError(bool force_pipe_reset, bool force_async_handler) { WaitToReadMore(); } else { error_ = true; - if (!connection_error_handler_.is_null()) - connection_error_handler_.Run(); + if (connection_error_handler_) + std::move(connection_error_handler_).Run(); } } diff --git a/mojo/public/cpp/bindings/lib/control_message_handler.cc b/mojo/public/cpp/bindings/lib/control_message_handler.cc index 1b7bb78e5f..b87c11c874 100644 --- a/mojo/public/cpp/bindings/lib/control_message_handler.cc +++ b/mojo/public/cpp/bindings/lib/control_message_handler.cc @@ -10,9 +10,9 @@ #include "base/logging.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/lib/message_builder.h" #include "mojo/public/cpp/bindings/lib/serialization.h" #include "mojo/public/cpp/bindings/lib/validation_util.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/interfaces/bindings/interface_control_messages.mojom.h" namespace mojo { @@ -115,19 +115,15 @@ bool ControlMessageHandler::Run( auto response_params_ptr = interface_control::RunResponseMessageParams::New(); response_params_ptr->output = std::move(output); - size_t size = - PrepareToSerialize<interface_control::RunResponseMessageParamsDataView>( - response_params_ptr, &context_); - MessageBuilder builder(interface_control::kRunMessageId, - Message::kFlagIsResponse, size, 0); - builder.message()->set_request_id(message->request_id()); - - interface_control::internal::RunResponseMessageParams_Data* response_params = - nullptr; + Message response_message(interface_control::kRunMessageId, + Message::kFlagIsResponse, 0, 0, nullptr); + response_message.set_request_id(message->request_id()); + interface_control::internal::RunResponseMessageParams_Data::BufferWriter + response_params; Serialize<interface_control::RunResponseMessageParamsDataView>( - response_params_ptr, builder.buffer(), &response_params, &context_); - ignore_result(responder->Accept(builder.message())); - + response_params_ptr, response_message.payload_buffer(), &response_params, + &context_); + ignore_result(responder->Accept(&response_message)); return true; } diff --git a/mojo/public/cpp/bindings/lib/control_message_handler.h b/mojo/public/cpp/bindings/lib/control_message_handler.h index 5d1f716ea8..daa884bb52 100644 --- a/mojo/public/cpp/bindings/lib/control_message_handler.h +++ b/mojo/public/cpp/bindings/lib/control_message_handler.h @@ -18,7 +18,7 @@ namespace internal { // Handlers for request messages defined in interface_control_messages.mojom. class MOJO_CPP_BINDINGS_EXPORT ControlMessageHandler - : NON_EXPORTED_BASE(public MessageReceiverWithResponderStatus) { + : public MessageReceiverWithResponderStatus { public: static bool IsControlMessage(const Message* message); diff --git a/mojo/public/cpp/bindings/lib/control_message_proxy.cc b/mojo/public/cpp/bindings/lib/control_message_proxy.cc index d082b49fb3..9fd7bf4173 100644 --- a/mojo/public/cpp/bindings/lib/control_message_proxy.cc +++ b/mojo/public/cpp/bindings/lib/control_message_proxy.cc @@ -12,7 +12,6 @@ #include "base/callback_helpers.h" #include "base/macros.h" #include "base/run_loop.h" -#include "mojo/public/cpp/bindings/lib/message_builder.h" #include "mojo/public/cpp/bindings/lib/serialization.h" #include "mojo/public/cpp/bindings/lib/validation_util.h" #include "mojo/public/cpp/bindings/message.h" @@ -73,49 +72,37 @@ bool RunResponseForwardToCallback::Accept(Message* message) { void SendRunMessage(MessageReceiverWithResponder* receiver, interface_control::RunInputPtr input_ptr, const RunCallback& callback) { - SerializationContext context; - auto params_ptr = interface_control::RunMessageParams::New(); params_ptr->input = std::move(input_ptr); - size_t size = PrepareToSerialize<interface_control::RunMessageParamsDataView>( - params_ptr, &context); - MessageBuilder builder(interface_control::kRunMessageId, - Message::kFlagExpectsResponse, size, 0); - - interface_control::internal::RunMessageParams_Data* params = nullptr; + Message message(interface_control::kRunMessageId, + Message::kFlagExpectsResponse, 0, 0, nullptr); + SerializationContext context; + interface_control::internal::RunMessageParams_Data::BufferWriter params; Serialize<interface_control::RunMessageParamsDataView>( - params_ptr, builder.buffer(), ¶ms, &context); + params_ptr, message.payload_buffer(), ¶ms, &context); std::unique_ptr<MessageReceiver> responder = - base::MakeUnique<RunResponseForwardToCallback>(callback); - ignore_result( - receiver->AcceptWithResponder(builder.message(), std::move(responder))); + std::make_unique<RunResponseForwardToCallback>(callback); + ignore_result(receiver->AcceptWithResponder(&message, std::move(responder))); } Message ConstructRunOrClosePipeMessage( interface_control::RunOrClosePipeInputPtr input_ptr) { - SerializationContext context; - auto params_ptr = interface_control::RunOrClosePipeMessageParams::New(); params_ptr->input = std::move(input_ptr); - - size_t size = PrepareToSerialize< - interface_control::RunOrClosePipeMessageParamsDataView>(params_ptr, - &context); - MessageBuilder builder(interface_control::kRunOrClosePipeMessageId, 0, size, - 0); - - interface_control::internal::RunOrClosePipeMessageParams_Data* params = - nullptr; + Message message(interface_control::kRunOrClosePipeMessageId, 0, 0, 0, + nullptr); + SerializationContext context; + interface_control::internal::RunOrClosePipeMessageParams_Data::BufferWriter + params; Serialize<interface_control::RunOrClosePipeMessageParamsDataView>( - params_ptr, builder.buffer(), ¶ms, &context); - return std::move(*builder.message()); + params_ptr, message.payload_buffer(), ¶ms, &context); + return message; } void SendRunOrClosePipeMessage( MessageReceiverWithResponder* receiver, interface_control::RunOrClosePipeInputPtr input_ptr) { Message message(ConstructRunOrClosePipeMessage(std::move(input_ptr))); - ignore_result(receiver->Accept(&message)); } @@ -163,7 +150,7 @@ void ControlMessageProxy::FlushForTesting() { auto input_ptr = interface_control::RunInput::New(); input_ptr->set_flush_for_testing(interface_control::FlushForTesting::New()); - base::RunLoop run_loop; + base::RunLoop run_loop(base::RunLoop::Type::kNestableTasksAllowed); run_loop_quit_closure_ = run_loop.QuitClosure(); SendRunMessage( receiver_, std::move(input_ptr), diff --git a/mojo/public/cpp/bindings/lib/equals_traits.h b/mojo/public/cpp/bindings/lib/equals_traits.h deleted file mode 100644 index 53c7dce693..0000000000 --- a/mojo/public/cpp/bindings/lib/equals_traits.h +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_EQUALS_TRAITS_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_LIB_EQUALS_TRAITS_H_ - -#include <type_traits> -#include <unordered_map> -#include <vector> - -#include "base/optional.h" -#include "mojo/public/cpp/bindings/lib/template_util.h" - -namespace mojo { -namespace internal { - -template <typename T> -struct HasEqualsMethod { - template <typename U> - static char Test(decltype(&U::Equals)); - template <typename U> - static int Test(...); - static const bool value = sizeof(Test<T>(0)) == sizeof(char); - - private: - EnsureTypeIsComplete<T> check_t_; -}; - -template <typename T, bool has_equals_method = HasEqualsMethod<T>::value> -struct EqualsTraits; - -template <typename T> -bool Equals(const T& a, const T& b); - -template <typename T> -struct EqualsTraits<T, true> { - static bool Equals(const T& a, const T& b) { return a.Equals(b); } -}; - -template <typename T> -struct EqualsTraits<T, false> { - static bool Equals(const T& a, const T& b) { return a == b; } -}; - -template <typename T> -struct EqualsTraits<base::Optional<T>, false> { - static bool Equals(const base::Optional<T>& a, const base::Optional<T>& b) { - if (!a && !b) - return true; - if (!a || !b) - return false; - - return internal::Equals(*a, *b); - } -}; - -template <typename T> -struct EqualsTraits<std::vector<T>, false> { - static bool Equals(const std::vector<T>& a, const std::vector<T>& b) { - if (a.size() != b.size()) - return false; - for (size_t i = 0; i < a.size(); ++i) { - if (!internal::Equals(a[i], b[i])) - return false; - } - return true; - } -}; - -template <typename K, typename V> -struct EqualsTraits<std::unordered_map<K, V>, false> { - static bool Equals(const std::unordered_map<K, V>& a, - const std::unordered_map<K, V>& b) { - if (a.size() != b.size()) - return false; - for (const auto& element : a) { - auto iter = b.find(element.first); - if (iter == b.end() || !internal::Equals(element.second, iter->second)) - return false; - } - return true; - } -}; - -template <typename T> -bool Equals(const T& a, const T& b) { - return EqualsTraits<T>::Equals(a, b); -} - -} // namespace internal -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_EQUALS_TRAITS_H_ diff --git a/mojo/public/cpp/bindings/lib/fixed_buffer.cc b/mojo/public/cpp/bindings/lib/fixed_buffer.cc index 725a193cd7..3d595cc063 100644 --- a/mojo/public/cpp/bindings/lib/fixed_buffer.cc +++ b/mojo/public/cpp/bindings/lib/fixed_buffer.cc @@ -6,25 +6,17 @@ #include <stdlib.h> +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" + namespace mojo { namespace internal { -FixedBufferForTesting::FixedBufferForTesting(size_t size) { - size = internal::Align(size); - // Use calloc here to ensure all message memory is zero'd out. - void* ptr = calloc(size, 1); - Initialize(ptr, size); -} +FixedBufferForTesting::FixedBufferForTesting(size_t size) + : Buffer(calloc(Align(size), 1), Align(size), 0) {} FixedBufferForTesting::~FixedBufferForTesting() { free(data()); } -void* FixedBufferForTesting::Leak() { - void* ptr = data(); - Initialize(nullptr, 0); - return ptr; -} - } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/fixed_buffer.h b/mojo/public/cpp/bindings/lib/fixed_buffer.h index 070b0c8cef..147ce7b115 100644 --- a/mojo/public/cpp/bindings/lib/fixed_buffer.h +++ b/mojo/public/cpp/bindings/lib/fixed_buffer.h @@ -5,11 +5,10 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_FIXED_BUFFER_H_ #define MOJO_PUBLIC_CPP_BINDINGS_LIB_FIXED_BUFFER_H_ -#include <stddef.h> +#include <cstddef> -#include "base/compiler_specific.h" +#include "base/component_export.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/buffer.h" namespace mojo { @@ -17,18 +16,12 @@ namespace internal { // FixedBufferForTesting owns its buffer. The Leak method may be used to steal // the underlying memory. -class MOJO_CPP_BINDINGS_EXPORT FixedBufferForTesting - : NON_EXPORTED_BASE(public Buffer) { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) FixedBufferForTesting + : public Buffer { public: explicit FixedBufferForTesting(size_t size); ~FixedBufferForTesting(); - // Returns the internal memory owned by the Buffer to the caller. The Buffer - // relinquishes its pointer, effectively resetting the state of the Buffer - // and leaving the caller responsible for freeing the returned memory address - // when no longer needed. - void* Leak(); - private: DISALLOW_COPY_AND_ASSIGN(FixedBufferForTesting); }; diff --git a/mojo/public/cpp/bindings/lib/handle_serialization.h b/mojo/public/cpp/bindings/lib/handle_serialization.h new file mode 100644 index 0000000000..6e1294e0a2 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/handle_serialization.h @@ -0,0 +1,35 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_SERIALIZATION_H_ + +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_context.h" +#include "mojo/public/cpp/bindings/lib/serialization_forward.h" +#include "mojo/public/cpp/system/handle.h" + +namespace mojo { +namespace internal { + +template <typename T> +struct Serializer<ScopedHandleBase<T>, ScopedHandleBase<T>> { + static void Serialize(ScopedHandleBase<T>& input, + Handle_Data* output, + SerializationContext* context) { + context->AddHandle(ScopedHandle::From(std::move(input)), output); + } + + static bool Deserialize(Handle_Data* input, + ScopedHandleBase<T>* output, + SerializationContext* context) { + *output = context->TakeHandleAs<T>(*input); + return true; + } +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc index 4682e72fad..6f119e4c1d 100644 --- a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc +++ b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc @@ -6,18 +6,17 @@ #include <stdint.h> -#include <utility> - #include "base/bind.h" #include "base/location.h" #include "base/logging.h" #include "base/macros.h" #include "base/memory/ptr_util.h" -#include "base/single_thread_task_runner.h" +#include "base/sequenced_task_runner.h" #include "base/stl_util.h" #include "mojo/public/cpp/bindings/associated_group.h" #include "mojo/public/cpp/bindings/associated_group_controller.h" #include "mojo/public/cpp/bindings/interface_endpoint_controller.h" +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" #include "mojo/public/cpp/bindings/lib/validation_util.h" #include "mojo/public/cpp/bindings/sync_call_restrictions.h" @@ -27,10 +26,10 @@ namespace mojo { namespace { -void DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient>& client, - const std::string& message) { - bool is_valid = client && !client->encountered_error(); - DCHECK(!is_valid) << message; +void DetermineIfEndpointIsConnected( + const base::WeakPtr<InterfaceEndpointClient>& client, + base::OnceCallback<void(bool)> callback) { + std::move(callback).Run(client && !client->encountered_error()); } // When receiving an incoming message which expects a repsonse, @@ -41,7 +40,7 @@ class ResponderThunk : public MessageReceiverWithStatus { public: explicit ResponderThunk( const base::WeakPtr<InterfaceEndpointClient>& endpoint_client, - scoped_refptr<base::SingleThreadTaskRunner> runner) + scoped_refptr<base::SequencedTaskRunner> runner) : endpoint_client_(endpoint_client), accept_was_invoked_(false), task_runner_(std::move(runner)) {} @@ -52,7 +51,7 @@ class ResponderThunk : public MessageReceiverWithStatus { // We raise an error to signal the calling application that an error // condition occurred. Without this the calling application would have no // way of knowing it should stop waiting for a response. - if (task_runner_->RunsTasksOnCurrentThread()) { + if (task_runner_->RunsTasksInCurrentSequence()) { // Please note that even if this code is run from a different task // runner on the same thread as |task_runner_|, it is okay to directly // call InterfaceEndpointClient::RaiseError(), because it will raise @@ -69,8 +68,12 @@ class ResponderThunk : public MessageReceiverWithStatus { } // MessageReceiver implementation: + bool PrefersSerializedMessages() override { + return endpoint_client_ && endpoint_client_->PrefersSerializedMessages(); + } + bool Accept(Message* message) override { - DCHECK(task_runner_->RunsTasksOnCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); accept_was_invoked_ = true; DCHECK(message->has_flag(Message::kFlagIsResponse)); @@ -83,24 +86,25 @@ class ResponderThunk : public MessageReceiverWithStatus { } // MessageReceiverWithStatus implementation: - bool IsValid() override { - DCHECK(task_runner_->RunsTasksOnCurrentThread()); + bool IsConnected() override { + DCHECK(task_runner_->RunsTasksInCurrentSequence()); return endpoint_client_ && !endpoint_client_->encountered_error(); } - void DCheckInvalid(const std::string& message) override { - if (task_runner_->RunsTasksOnCurrentThread()) { - DCheckIfInvalid(endpoint_client_, message); + void IsConnectedAsync(base::OnceCallback<void(bool)> callback) override { + if (task_runner_->RunsTasksInCurrentSequence()) { + DetermineIfEndpointIsConnected(endpoint_client_, std::move(callback)); } else { task_runner_->PostTask( - FROM_HERE, base::Bind(&DCheckIfInvalid, endpoint_client_, message)); + FROM_HERE, base::BindOnce(&DetermineIfEndpointIsConnected, + endpoint_client_, std::move(callback))); } - } + } private: base::WeakPtr<InterfaceEndpointClient> endpoint_client_; bool accept_was_invoked_; - scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; DISALLOW_COPY_AND_ASSIGN(ResponderThunk); }; @@ -136,7 +140,7 @@ InterfaceEndpointClient::InterfaceEndpointClient( MessageReceiverWithResponderStatus* receiver, std::unique_ptr<MessageReceiver> payload_validator, bool expect_sync_requests, - scoped_refptr<base::SingleThreadTaskRunner> runner, + scoped_refptr<base::SequencedTaskRunner> runner, uint32_t interface_version) : expect_sync_requests_(expect_sync_requests), handle_(std::move(handle)), @@ -163,20 +167,19 @@ InterfaceEndpointClient::InterfaceEndpointClient( } InterfaceEndpointClient::~InterfaceEndpointClient() { - DCHECK(thread_checker_.CalledOnValidThread()); - + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (controller_) handle_.group_controller()->DetachEndpointClient(handle_); } AssociatedGroup* InterfaceEndpointClient::associated_group() { if (!associated_group_) - associated_group_ = base::MakeUnique<AssociatedGroup>(handle_); + associated_group_ = std::make_unique<AssociatedGroup>(handle_); return associated_group_.get(); } ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(!has_pending_responders()); if (!handle_.is_valid()) @@ -199,7 +202,7 @@ void InterfaceEndpointClient::AddFilter( } void InterfaceEndpointClient::RaiseError() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!handle_.pending_association()) handle_.group_controller()->RaiseError(); @@ -207,14 +210,19 @@ void InterfaceEndpointClient::RaiseError() { void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason, const std::string& description) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); auto handle = PassHandle(); handle.ResetWithReason(custom_reason, description); } +bool InterfaceEndpointClient::PrefersSerializedMessages() { + auto* controller = handle_.group_controller(); + return controller && controller->PrefersSerializedMessages(); +} + bool InterfaceEndpointClient::Accept(Message* message) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(!message->has_flag(Message::kFlagExpectsResponse)); DCHECK(!handle_.pending_association()); @@ -237,7 +245,7 @@ bool InterfaceEndpointClient::Accept(Message* message) { bool InterfaceEndpointClient::AcceptWithResponder( Message* message, std::unique_ptr<MessageReceiver> responder) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(message->has_flag(Message::kFlagExpectsResponse)); DCHECK(!handle_.pending_association()); @@ -270,7 +278,7 @@ bool InterfaceEndpointClient::AcceptWithResponder( bool response_received = false; sync_responses_.insert(std::make_pair( - request_id, base::MakeUnique<SyncResponseInfo>(&response_received))); + request_id, std::make_unique<SyncResponseInfo>(&response_received))); base::WeakPtr<InterfaceEndpointClient> weak_self = weak_ptr_factory_.GetWeakPtr(); @@ -280,8 +288,13 @@ bool InterfaceEndpointClient::AcceptWithResponder( DCHECK(base::ContainsKey(sync_responses_, request_id)); auto iter = sync_responses_.find(request_id); DCHECK_EQ(&response_received, iter->second->response_received); - if (response_received) + if (response_received) { ignore_result(responder->Accept(&iter->second->response)); + } else { + DVLOG(1) << "Mojo sync call returns without receiving a response. " + << "Typcially it is because the interface has been " + << "disconnected."; + } sync_responses_.erase(iter); } @@ -289,13 +302,13 @@ bool InterfaceEndpointClient::AcceptWithResponder( } bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return filters_.Accept(message); } void InterfaceEndpointClient::NotifyError( const base::Optional<DisconnectReason>& reason) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (encountered_error_) return; @@ -309,16 +322,14 @@ void InterfaceEndpointClient::NotifyError( control_message_proxy_.OnConnectionError(); - if (!error_handler_.is_null()) { - base::Closure error_handler = std::move(error_handler_); - error_handler.Run(); - } else if (!error_with_reason_handler_.is_null()) { - ConnectionErrorWithReasonCallback error_with_reason_handler = - std::move(error_with_reason_handler_); + if (error_handler_) { + std::move(error_handler_).Run(); + } else if (error_with_reason_handler_) { if (reason) { - error_with_reason_handler.Run(reason->custom_reason, reason->description); + std::move(error_with_reason_handler_) + .Run(reason->custom_reason, reason->description); } else { - error_with_reason_handler.Run(0, std::string()); + std::move(error_with_reason_handler_).Run(0, std::string()); } } } @@ -374,7 +385,7 @@ bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) { if (message->has_flag(Message::kFlagExpectsResponse)) { std::unique_ptr<MessageReceiverWithStatus> responder = - base::MakeUnique<ResponderThunk>(weak_ptr_factory_.GetWeakPtr(), + std::make_unique<ResponderThunk>(weak_ptr_factory_.GetWeakPtr(), task_runner_); if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) { return control_message_handler_.AcceptWithResponder(message, diff --git a/mojo/public/cpp/bindings/lib/interface_ptr_state.cc b/mojo/public/cpp/bindings/lib/interface_ptr_state.cc new file mode 100644 index 0000000000..8cd23ea067 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/interface_ptr_state.cc @@ -0,0 +1,94 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/interface_ptr_state.h" + +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" + +namespace mojo { +namespace internal { + +InterfacePtrStateBase::InterfacePtrStateBase() = default; + +InterfacePtrStateBase::~InterfacePtrStateBase() { + endpoint_client_.reset(); + if (router_) + router_->CloseMessagePipe(); +} + +void InterfacePtrStateBase::QueryVersion( + const base::Callback<void(uint32_t)>& callback) { + // It is safe to capture |this| because the callback won't be run after this + // object goes away. + endpoint_client_->QueryVersion( + base::Bind(&InterfacePtrStateBase::OnQueryVersion, base::Unretained(this), + callback)); +} + +void InterfacePtrStateBase::RequireVersion(uint32_t version) { + if (version <= version_) + return; + + version_ = version; + endpoint_client_->RequireVersion(version); +} + +void InterfacePtrStateBase::Swap(InterfacePtrStateBase* other) { + using std::swap; + swap(other->router_, router_); + swap(other->endpoint_client_, endpoint_client_); + handle_.swap(other->handle_); + runner_.swap(other->runner_); + swap(other->version_, version_); +} + +void InterfacePtrStateBase::Bind( + ScopedMessagePipeHandle handle, + uint32_t version, + scoped_refptr<base::SequencedTaskRunner> task_runner) { + DCHECK(!router_); + DCHECK(!endpoint_client_); + DCHECK(!handle_.is_valid()); + DCHECK_EQ(0u, version_); + DCHECK(handle.is_valid()); + + handle_ = std::move(handle); + version_ = version; + runner_ = + GetTaskRunnerToUseFromUserProvidedTaskRunner(std::move(task_runner)); +} + +void InterfacePtrStateBase::OnQueryVersion( + const base::Callback<void(uint32_t)>& callback, + uint32_t version) { + version_ = version; + callback.Run(version); +} + +bool InterfacePtrStateBase::InitializeEndpointClient( + bool passes_associated_kinds, + bool has_sync_methods, + std::unique_ptr<MessageReceiver> payload_validator) { + // The object hasn't been bound. + if (!handle_.is_valid()) + return false; + + MultiplexRouter::Config config = + passes_associated_kinds + ? MultiplexRouter::MULTI_INTERFACE + : (has_sync_methods + ? MultiplexRouter::SINGLE_INTERFACE_WITH_SYNC_METHODS + : MultiplexRouter::SINGLE_INTERFACE); + router_ = new MultiplexRouter(std::move(handle_), config, true, runner_); + endpoint_client_.reset(new InterfaceEndpointClient( + router_->CreateLocalEndpointHandle(kMasterInterfaceId), nullptr, + std::move(payload_validator), false, std::move(runner_), + // The version is only queried from the client so the value passed here + // will not be used. + 0u)); + return true; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/interface_ptr_state.h b/mojo/public/cpp/bindings/lib/interface_ptr_state.h index fa54979795..2e73564a80 100644 --- a/mojo/public/cpp/bindings/lib/interface_ptr_state.h +++ b/mojo/public/cpp/bindings/lib/interface_ptr_state.h @@ -18,8 +18,9 @@ #include "base/macros.h" #include "base/memory/ptr_util.h" #include "base/memory/ref_counted.h" -#include "base/single_thread_task_runner.h" +#include "base/sequenced_task_runner.h" #include "mojo/public/cpp/bindings/associated_group.h" +#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/connection_error_callback.h" #include "mojo/public/cpp/bindings/filter_chain.h" #include "mojo/public/cpp/bindings/interface_endpoint_client.h" @@ -32,191 +33,190 @@ namespace mojo { namespace internal { -template <typename Interface> -class InterfacePtrState { +class MOJO_CPP_BINDINGS_EXPORT InterfacePtrStateBase { public: - InterfacePtrState() : version_(0u) {} + InterfacePtrStateBase(); + ~InterfacePtrStateBase(); + + MessagePipeHandle handle() const { + return router_ ? router_->handle() : handle_.get(); + } + + uint32_t version() const { return version_; } + + bool is_bound() const { return handle_.is_valid() || endpoint_client_; } + + bool encountered_error() const { + return endpoint_client_ ? endpoint_client_->encountered_error() : false; + } + + bool HasAssociatedInterfaces() const { + return router_ ? router_->HasAssociatedEndpoints() : false; + } + + // Returns true if bound and awaiting a response to a message. + bool has_pending_callbacks() const { + return endpoint_client_ && endpoint_client_->has_pending_responders(); + } + + protected: + InterfaceEndpointClient* endpoint_client() const { + return endpoint_client_.get(); + } + MultiplexRouter* router() const { return router_.get(); } - ~InterfacePtrState() { + void QueryVersion(const base::Callback<void(uint32_t)>& callback); + void RequireVersion(uint32_t version); + void Swap(InterfacePtrStateBase* other); + void Bind(ScopedMessagePipeHandle handle, + uint32_t version, + scoped_refptr<base::SequencedTaskRunner> task_runner); + + ScopedMessagePipeHandle PassMessagePipe() { endpoint_client_.reset(); - proxy_.reset(); - if (router_) - router_->CloseMessagePipe(); + return router_ ? router_->PassMessagePipe() : std::move(handle_); } - Interface* instance() { + bool InitializeEndpointClient( + bool passes_associated_kinds, + bool has_sync_methods, + std::unique_ptr<MessageReceiver> payload_validator); + + private: + void OnQueryVersion(const base::Callback<void(uint32_t)>& callback, + uint32_t version); + + scoped_refptr<MultiplexRouter> router_; + + std::unique_ptr<InterfaceEndpointClient> endpoint_client_; + + // |router_| (as well as other members above) is not initialized until + // read/write with the message pipe handle is needed. |handle_| is valid + // between the Bind() call and the initialization of |router_|. + ScopedMessagePipeHandle handle_; + scoped_refptr<base::SequencedTaskRunner> runner_; + + uint32_t version_ = 0; + + DISALLOW_COPY_AND_ASSIGN(InterfacePtrStateBase); +}; + +template <typename Interface> +class InterfacePtrState : public InterfacePtrStateBase { + public: + using Proxy = typename Interface::Proxy_; + + InterfacePtrState() = default; + ~InterfacePtrState() = default; + + Proxy* instance() { ConfigureProxyIfNecessary(); // This will be null if the object is not bound. return proxy_.get(); } - uint32_t version() const { return version_; } - void QueryVersion(const base::Callback<void(uint32_t)>& callback) { ConfigureProxyIfNecessary(); - - // It is safe to capture |this| because the callback won't be run after this - // object goes away. - endpoint_client_->QueryVersion(base::Bind( - &InterfacePtrState::OnQueryVersion, base::Unretained(this), callback)); + InterfacePtrStateBase::QueryVersion(callback); } void RequireVersion(uint32_t version) { ConfigureProxyIfNecessary(); - - if (version <= version_) - return; - - version_ = version; - endpoint_client_->RequireVersion(version); + InterfacePtrStateBase::RequireVersion(version); } void FlushForTesting() { ConfigureProxyIfNecessary(); - endpoint_client_->FlushForTesting(); + endpoint_client()->FlushForTesting(); } void CloseWithReason(uint32_t custom_reason, const std::string& description) { ConfigureProxyIfNecessary(); - endpoint_client_->CloseWithReason(custom_reason, description); + endpoint_client()->CloseWithReason(custom_reason, description); } void Swap(InterfacePtrState* other) { using std::swap; - swap(other->router_, router_); - swap(other->endpoint_client_, endpoint_client_); swap(other->proxy_, proxy_); - handle_.swap(other->handle_); - runner_.swap(other->runner_); - swap(other->version_, version_); + InterfacePtrStateBase::Swap(other); } void Bind(InterfacePtrInfo<Interface> info, - scoped_refptr<base::SingleThreadTaskRunner> runner) { - DCHECK(!router_); - DCHECK(!endpoint_client_); + scoped_refptr<base::SequencedTaskRunner> runner) { DCHECK(!proxy_); - DCHECK(!handle_.is_valid()); - DCHECK_EQ(0u, version_); - DCHECK(info.is_valid()); - - handle_ = info.PassHandle(); - version_ = info.version(); - runner_ = std::move(runner); - } - - bool HasAssociatedInterfaces() const { - return router_ ? router_->HasAssociatedEndpoints() : false; + InterfacePtrStateBase::Bind(info.PassHandle(), info.version(), + std::move(runner)); } // After this method is called, the object is in an invalid state and // shouldn't be reused. InterfacePtrInfo<Interface> PassInterface() { - endpoint_client_.reset(); proxy_.reset(); - return InterfacePtrInfo<Interface>( - router_ ? router_->PassMessagePipe() : std::move(handle_), version_); + return InterfacePtrInfo<Interface>(PassMessagePipe(), version()); } - bool is_bound() const { return handle_.is_valid() || endpoint_client_; } - - bool encountered_error() const { - return endpoint_client_ ? endpoint_client_->encountered_error() : false; - } - - void set_connection_error_handler(const base::Closure& error_handler) { + void set_connection_error_handler(base::OnceClosure error_handler) { ConfigureProxyIfNecessary(); - DCHECK(endpoint_client_); - endpoint_client_->set_connection_error_handler(error_handler); + DCHECK(endpoint_client()); + endpoint_client()->set_connection_error_handler(std::move(error_handler)); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { ConfigureProxyIfNecessary(); - DCHECK(endpoint_client_); - endpoint_client_->set_connection_error_with_reason_handler(error_handler); - } - - // Returns true if bound and awaiting a response to a message. - bool has_pending_callbacks() const { - return endpoint_client_ && endpoint_client_->has_pending_responders(); + DCHECK(endpoint_client()); + endpoint_client()->set_connection_error_with_reason_handler( + std::move(error_handler)); } AssociatedGroup* associated_group() { ConfigureProxyIfNecessary(); - return endpoint_client_->associated_group(); + return endpoint_client()->associated_group(); } void EnableTestingMode() { ConfigureProxyIfNecessary(); - router_->EnableTestingMode(); + router()->EnableTestingMode(); } void ForwardMessage(Message message) { ConfigureProxyIfNecessary(); - endpoint_client_->Accept(&message); + endpoint_client()->Accept(&message); } void ForwardMessageWithResponder(Message message, std::unique_ptr<MessageReceiver> responder) { ConfigureProxyIfNecessary(); - endpoint_client_->AcceptWithResponder(&message, std::move(responder)); + endpoint_client()->AcceptWithResponder(&message, std::move(responder)); } - private: - using Proxy = typename Interface::Proxy_; + void RaiseError() { + ConfigureProxyIfNecessary(); + endpoint_client()->RaiseError(); + } + private: void ConfigureProxyIfNecessary() { // The proxy has been configured. if (proxy_) { - DCHECK(router_); - DCHECK(endpoint_client_); + DCHECK(router()); + DCHECK(endpoint_client()); return; } - // The object hasn't been bound. - if (!handle_.is_valid()) - return; - MultiplexRouter::Config config = - Interface::PassesAssociatedKinds_ - ? MultiplexRouter::MULTI_INTERFACE - : (Interface::HasSyncMethods_ - ? MultiplexRouter::SINGLE_INTERFACE_WITH_SYNC_METHODS - : MultiplexRouter::SINGLE_INTERFACE); - router_ = new MultiplexRouter(std::move(handle_), config, true, runner_); - router_->SetMasterInterfaceName(Interface::Name_); - endpoint_client_.reset(new InterfaceEndpointClient( - router_->CreateLocalEndpointHandle(kMasterInterfaceId), nullptr, - base::WrapUnique(new typename Interface::ResponseValidator_()), false, - std::move(runner_), - // The version is only queried from the client so the value passed here - // will not be used. - 0u)); - proxy_.reset(new Proxy(endpoint_client_.get())); - } - - void OnQueryVersion(const base::Callback<void(uint32_t)>& callback, - uint32_t version) { - version_ = version; - callback.Run(version); + if (InitializeEndpointClient( + Interface::PassesAssociatedKinds_, Interface::HasSyncMethods_, + std::make_unique<typename Interface::ResponseValidator_>())) { + router()->SetMasterInterfaceName(Interface::Name_); + proxy_ = std::make_unique<Proxy>(endpoint_client()); + } } - scoped_refptr<MultiplexRouter> router_; - - std::unique_ptr<InterfaceEndpointClient> endpoint_client_; std::unique_ptr<Proxy> proxy_; - // |router_| (as well as other members above) is not initialized until - // read/write with the message pipe handle is needed. |handle_| is valid - // between the Bind() call and the initialization of |router_|. - ScopedMessagePipeHandle handle_; - scoped_refptr<base::SingleThreadTaskRunner> runner_; - - uint32_t version_; - DISALLOW_COPY_AND_ASSIGN(InterfacePtrState); }; diff --git a/mojo/public/cpp/bindings/lib/handle_interface_serialization.h b/mojo/public/cpp/bindings/lib/interface_serialization.h index 14ed21f0ac..00954de261 100644 --- a/mojo/public/cpp/bindings/lib/handle_interface_serialization.h +++ b/mojo/public/cpp/bindings/lib/interface_serialization.h @@ -1,9 +1,9 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. +// Copyright 2018 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_INTERFACE_SERIALIZATION_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_INTERFACE_SERIALIZATION_H_ +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_INTERFACE_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_INTERFACE_SERIALIZATION_H_ #include <type_traits> @@ -17,6 +17,7 @@ #include "mojo/public/cpp/bindings/lib/serialization_context.h" #include "mojo/public/cpp/bindings/lib/serialization_forward.h" #include "mojo/public/cpp/system/handle.h" +#include "mojo/public/cpp/system/message_pipe.h" namespace mojo { namespace internal { @@ -26,40 +27,24 @@ struct Serializer<AssociatedInterfacePtrInfoDataView<Base>, AssociatedInterfacePtrInfo<T>> { static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); - static size_t PrepareToSerialize(const AssociatedInterfacePtrInfo<T>& input, - SerializationContext* context) { - if (input.handle().is_valid()) - context->associated_endpoint_count++; - return 0; - } - static void Serialize(AssociatedInterfacePtrInfo<T>& input, AssociatedInterface_Data* output, SerializationContext* context) { DCHECK(!input.handle().is_valid() || input.handle().pending_association()); - if (input.handle().is_valid()) { - // Set to the index of the element pushed to the back of the vector. - output->handle.value = - static_cast<uint32_t>(context->associated_endpoint_handles.size()); - context->associated_endpoint_handles.push_back(input.PassHandle()); - } else { - output->handle.value = kEncodedInvalidHandleValue; - } - output->version = input.version(); + context->AddAssociatedInterfaceInfo(input.PassHandle(), input.version(), + output); } static bool Deserialize(AssociatedInterface_Data* input, AssociatedInterfacePtrInfo<T>* output, SerializationContext* context) { - if (input->handle.is_valid()) { - DCHECK_LT(input->handle.value, - context->associated_endpoint_handles.size()); - output->set_handle( - std::move(context->associated_endpoint_handles[input->handle.value])); + auto handle = context->TakeAssociatedEndpointHandle(input->handle); + if (!handle.is_valid()) { + *output = AssociatedInterfacePtrInfo<T>(); } else { - output->set_handle(ScopedInterfaceEndpointHandle()); + output->set_handle(std::move(handle)); + output->set_version(input->version); } - output->set_version(input->version); return true; } }; @@ -69,37 +54,21 @@ struct Serializer<AssociatedInterfaceRequestDataView<Base>, AssociatedInterfaceRequest<T>> { static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); - static size_t PrepareToSerialize(const AssociatedInterfaceRequest<T>& input, - SerializationContext* context) { - if (input.handle().is_valid()) - context->associated_endpoint_count++; - return 0; - } - static void Serialize(AssociatedInterfaceRequest<T>& input, AssociatedEndpointHandle_Data* output, SerializationContext* context) { DCHECK(!input.handle().is_valid() || input.handle().pending_association()); - if (input.handle().is_valid()) { - // Set to the index of the element pushed to the back of the vector. - output->value = - static_cast<uint32_t>(context->associated_endpoint_handles.size()); - context->associated_endpoint_handles.push_back(input.PassHandle()); - } else { - output->value = kEncodedInvalidHandleValue; - } + context->AddAssociatedEndpoint(input.PassHandle(), output); } static bool Deserialize(AssociatedEndpointHandle_Data* input, AssociatedInterfaceRequest<T>* output, SerializationContext* context) { - if (input->is_valid()) { - DCHECK_LT(input->value, context->associated_endpoint_handles.size()); - output->Bind( - std::move(context->associated_endpoint_handles[input->value])); - } else { - output->Bind(ScopedInterfaceEndpointHandle()); - } + auto handle = context->TakeAssociatedEndpointHandle(*input); + if (!handle.is_valid()) + *output = AssociatedInterfaceRequest<T>(); + else + *output = AssociatedInterfaceRequest<T>(std::move(handle)); return true; } }; @@ -108,69 +77,58 @@ template <typename Base, typename T> struct Serializer<InterfacePtrDataView<Base>, InterfacePtr<T>> { static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); - static size_t PrepareToSerialize(const InterfacePtr<T>& input, - SerializationContext* context) { - return 0; - } - static void Serialize(InterfacePtr<T>& input, Interface_Data* output, SerializationContext* context) { InterfacePtrInfo<T> info = input.PassInterface(); - output->handle = context->handles.AddHandle(info.PassHandle().release()); - output->version = info.version(); + context->AddInterfaceInfo(info.PassHandle(), info.version(), output); } static bool Deserialize(Interface_Data* input, InterfacePtr<T>* output, SerializationContext* context) { output->Bind(InterfacePtrInfo<T>( - context->handles.TakeHandleAs<mojo::MessagePipeHandle>(input->handle), + context->TakeHandleAs<mojo::MessagePipeHandle>(input->handle), input->version)); return true; } }; template <typename Base, typename T> -struct Serializer<InterfaceRequestDataView<Base>, InterfaceRequest<T>> { +struct Serializer<InterfacePtrDataView<Base>, InterfacePtrInfo<T>> { static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); - static size_t PrepareToSerialize(const InterfaceRequest<T>& input, - SerializationContext* context) { - return 0; - } - - static void Serialize(InterfaceRequest<T>& input, - Handle_Data* output, + static void Serialize(InterfacePtrInfo<T>& input, + Interface_Data* output, SerializationContext* context) { - *output = context->handles.AddHandle(input.PassMessagePipe().release()); + context->AddInterfaceInfo(input.PassHandle(), input.version(), output); } - static bool Deserialize(Handle_Data* input, - InterfaceRequest<T>* output, + static bool Deserialize(Interface_Data* input, + InterfacePtrInfo<T>* output, SerializationContext* context) { - output->Bind(context->handles.TakeHandleAs<MessagePipeHandle>(*input)); + *output = InterfacePtrInfo<T>( + context->TakeHandleAs<mojo::MessagePipeHandle>(input->handle), + input->version); return true; } }; -template <typename T> -struct Serializer<ScopedHandleBase<T>, ScopedHandleBase<T>> { - static size_t PrepareToSerialize(const ScopedHandleBase<T>& input, - SerializationContext* context) { - return 0; - } +template <typename Base, typename T> +struct Serializer<InterfaceRequestDataView<Base>, InterfaceRequest<T>> { + static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); - static void Serialize(ScopedHandleBase<T>& input, + static void Serialize(InterfaceRequest<T>& input, Handle_Data* output, SerializationContext* context) { - *output = context->handles.AddHandle(input.release()); + context->AddHandle(ScopedHandle::From(input.PassMessagePipe()), output); } static bool Deserialize(Handle_Data* input, - ScopedHandleBase<T>* output, + InterfaceRequest<T>* output, SerializationContext* context) { - *output = context->handles.TakeHandleAs<T>(*input); + *output = + InterfaceRequest<T>(context->TakeHandleAs<MessagePipeHandle>(*input)); return true; } }; @@ -178,4 +136,4 @@ struct Serializer<ScopedHandleBase<T>, ScopedHandleBase<T>> { } // namespace internal } // namespace mojo -#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_INTERFACE_SERIALIZATION_H_ +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_INTERFACE_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/map_data_internal.h b/mojo/public/cpp/bindings/lib/map_data_internal.h index f8e3d2918f..217904fd43 100644 --- a/mojo/public/cpp/bindings/lib/map_data_internal.h +++ b/mojo/public/cpp/bindings/lib/map_data_internal.h @@ -5,6 +5,7 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MAP_DATA_INTERNAL_H_ #define MOJO_PUBLIC_CPP_BINDINGS_LIB_MAP_DATA_INTERNAL_H_ +#include "base/macros.h" #include "mojo/public/cpp/bindings/lib/array_internal.h" #include "mojo/public/cpp/bindings/lib/validate_params.h" #include "mojo/public/cpp/bindings/lib/validation_errors.h" @@ -18,9 +19,29 @@ namespace internal { template <typename Key, typename Value> class Map_Data { public: - static Map_Data* New(Buffer* buf) { - return new (buf->Allocate(sizeof(Map_Data))) Map_Data(); - } + class BufferWriter { + public: + BufferWriter() = default; + + void Allocate(Buffer* buffer) { + buffer_ = buffer; + index_ = buffer_->Allocate(sizeof(Map_Data)); + new (data()) Map_Data(); + } + + bool is_null() const { return !buffer_; } + Map_Data* data() { + DCHECK(!is_null()); + return buffer_->Get<Map_Data>(index_); + } + Map_Data* operator->() { return data(); } + + private: + Buffer* buffer_ = nullptr; + size_t index_ = 0; + + DISALLOW_COPY_AND_ASSIGN(BufferWriter); + }; // |validate_params| must have non-null |key_validate_params| and // |element_validate_params| members. @@ -41,16 +62,13 @@ class Map_Data { return false; } - if (!ValidatePointerNonNullable( - object->keys, "null key array in map struct", validation_context) || + if (!ValidatePointerNonNullable(object->keys, 0, validation_context) || !ValidateContainer(object->keys, validation_context, validate_params->key_validate_params)) { return false; } - if (!ValidatePointerNonNullable(object->values, - "null value array in map struct", - validation_context) || + if (!ValidatePointerNonNullable(object->values, 1, validation_context) || !ValidateContainer(object->values, validation_context, validate_params->element_validate_params)) { return false; diff --git a/mojo/public/cpp/bindings/lib/map_serialization.h b/mojo/public/cpp/bindings/lib/map_serialization.h index 718a76307d..b114f4995c 100644 --- a/mojo/public/cpp/bindings/lib/map_serialization.h +++ b/mojo/public/cpp/bindings/lib/map_serialization.h @@ -95,57 +95,34 @@ struct Serializer<MapDataView<Key, Value>, MaybeConstUserType> { std::vector<UserValue>, MapValueReader<MaybeConstUserType>>; - static size_t PrepareToSerialize(MaybeConstUserType& input, - SerializationContext* context) { - if (CallIsNullIfExists<Traits>(input)) - return 0; - - size_t struct_overhead = sizeof(Data); - MapKeyReader<MaybeConstUserType> key_reader(input); - size_t keys_size = - KeyArraySerializer::GetSerializedSize(&key_reader, context); - MapValueReader<MaybeConstUserType> value_reader(input); - size_t values_size = - ValueArraySerializer::GetSerializedSize(&value_reader, context); - - return struct_overhead + keys_size + values_size; - } - static void Serialize(MaybeConstUserType& input, Buffer* buf, - Data** output, + typename Data::BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { DCHECK(validate_params->key_validate_params); DCHECK(validate_params->element_validate_params); - if (CallIsNullIfExists<Traits>(input)) { - *output = nullptr; + if (CallIsNullIfExists<Traits>(input)) return; - } - auto result = Data::New(buf); - if (result) { - auto keys_ptr = MojomTypeTraits<ArrayDataView<Key>>::Data::New( - Traits::GetSize(input), buf); - if (keys_ptr) { - MapKeyReader<MaybeConstUserType> key_reader(input); - KeyArraySerializer::SerializeElements( - &key_reader, buf, keys_ptr, validate_params->key_validate_params, - context); - result->keys.Set(keys_ptr); - } - - auto values_ptr = MojomTypeTraits<ArrayDataView<Value>>::Data::New( - Traits::GetSize(input), buf); - if (values_ptr) { - MapValueReader<MaybeConstUserType> value_reader(input); - ValueArraySerializer::SerializeElements( - &value_reader, buf, values_ptr, - validate_params->element_validate_params, context); - result->values.Set(values_ptr); - } - } - *output = result; + writer->Allocate(buf); + typename MojomTypeTraits<ArrayDataView<Key>>::Data::BufferWriter + keys_writer; + keys_writer.Allocate(Traits::GetSize(input), buf); + MapKeyReader<MaybeConstUserType> key_reader(input); + KeyArraySerializer::SerializeElements(&key_reader, buf, &keys_writer, + validate_params->key_validate_params, + context); + (*writer)->keys.Set(keys_writer.data()); + + typename MojomTypeTraits<ArrayDataView<Value>>::Data::BufferWriter + values_writer; + values_writer.Allocate(Traits::GetSize(input), buf); + MapValueReader<MaybeConstUserType> value_reader(input); + ValueArraySerializer::SerializeElements( + &value_reader, buf, &values_writer, + validate_params->element_validate_params, context); + (*writer)->values.Set(values_writer.data()); } static bool Deserialize(Data* input, diff --git a/mojo/public/cpp/bindings/lib/may_auto_lock.h b/mojo/public/cpp/bindings/lib/may_auto_lock.h index 06091fee90..78cb89fa77 100644 --- a/mojo/public/cpp/bindings/lib/may_auto_lock.h +++ b/mojo/public/cpp/bindings/lib/may_auto_lock.h @@ -5,6 +5,7 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MAY_AUTO_LOCK_H_ #define MOJO_PUBLIC_CPP_BINDINGS_LIB_MAY_AUTO_LOCK_H_ +#include "base/component_export.h" #include "base/macros.h" #include "base/optional.h" #include "base/synchronization/lock.h" @@ -14,7 +15,7 @@ namespace internal { // Similar to base::AutoLock, except that it does nothing if |lock| passed into // the constructor is null. -class MayAutoLock { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) MayAutoLock { public: explicit MayAutoLock(base::Optional<base::Lock>* lock) : lock_(lock->has_value() ? &lock->value() : nullptr) { @@ -36,7 +37,7 @@ class MayAutoLock { // Similar to base::AutoUnlock, except that it does nothing if |lock| passed // into the constructor is null. -class MayAutoUnlock { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) MayAutoUnlock { public: explicit MayAutoUnlock(base::Optional<base::Lock>* lock) : lock_(lock->has_value() ? &lock->value() : nullptr) { diff --git a/mojo/public/cpp/bindings/lib/message.cc b/mojo/public/cpp/bindings/lib/message.cc index e5f3808117..8972d9efd1 100644 --- a/mojo/public/cpp/bindings/lib/message.cc +++ b/mojo/public/cpp/bindings/lib/message.cc @@ -14,72 +14,279 @@ #include "base/bind.h" #include "base/lazy_instance.h" #include "base/logging.h" +#include "base/numerics/safe_math.h" #include "base/strings/stringprintf.h" #include "base/threading/thread_local.h" #include "mojo/public/cpp/bindings/associated_group_controller.h" #include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/bindings/lib/unserialized_message_context.h" namespace mojo { namespace { base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>:: - DestructorAtExit g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER; + Leaky g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER; -base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>:: - DestructorAtExit g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER; +base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::Leaky + g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER; void DoNotifyBadMessage(Message message, const std::string& error) { message.NotifyBadMessage(error); } -} // namespace +template <typename HeaderType> +void AllocateHeaderFromBuffer(internal::Buffer* buffer, HeaderType** header) { + *header = buffer->AllocateAndGet<HeaderType>(); + (*header)->num_bytes = sizeof(HeaderType); +} + +void WriteMessageHeader(uint32_t name, + uint32_t flags, + size_t payload_interface_id_count, + internal::Buffer* payload_buffer) { + if (payload_interface_id_count > 0) { + // Version 2 + internal::MessageHeaderV2* header; + AllocateHeaderFromBuffer(payload_buffer, &header); + header->version = 2; + header->name = name; + header->flags = flags; + // The payload immediately follows the header. + header->payload.Set(header + 1); + } else if (flags & + (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { + // Version 1 + internal::MessageHeaderV1* header; + AllocateHeaderFromBuffer(payload_buffer, &header); + header->version = 1; + header->name = name; + header->flags = flags; + } else { + internal::MessageHeader* header; + AllocateHeaderFromBuffer(payload_buffer, &header); + header->version = 0; + header->name = name; + header->flags = flags; + } +} + +void CreateSerializedMessageObject(uint32_t name, + uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count, + std::vector<ScopedHandle>* handles, + ScopedMessageHandle* out_handle, + internal::Buffer* out_buffer) { + ScopedMessageHandle handle; + MojoResult rv = mojo::CreateMessage(&handle); + DCHECK_EQ(MOJO_RESULT_OK, rv); + DCHECK(handle.is_valid()); + + void* buffer; + uint32_t buffer_size; + size_t total_size = internal::ComputeSerializedMessageSize( + flags, payload_size, payload_interface_id_count); + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(total_size)); + DCHECK(!handles || + base::IsValueInRangeForNumericType<uint32_t>(handles->size())); + rv = MojoAppendMessageData( + handle->value(), static_cast<uint32_t>(total_size), + handles ? reinterpret_cast<MojoHandle*>(handles->data()) : nullptr, + handles ? static_cast<uint32_t>(handles->size()) : 0, nullptr, &buffer, + &buffer_size); + DCHECK_EQ(MOJO_RESULT_OK, rv); + if (handles) { + // Handle ownership has been taken by MojoAppendMessageData. + for (size_t i = 0; i < handles->size(); ++i) + ignore_result(handles->at(i).release()); + } + + internal::Buffer payload_buffer(handle.get(), total_size, buffer, + buffer_size); + + // Make sure we zero the memory first! + memset(payload_buffer.data(), 0, total_size); + WriteMessageHeader(name, flags, payload_interface_id_count, &payload_buffer); + + *out_handle = std::move(handle); + *out_buffer = std::move(payload_buffer); +} + +void SerializeUnserializedContext(MojoMessageHandle message, + uintptr_t context_value) { + auto* context = + reinterpret_cast<internal::UnserializedMessageContext*>(context_value); + void* buffer; + uint32_t buffer_size; + MojoResult attach_result = MojoAppendMessageData( + message, 0, nullptr, 0, nullptr, &buffer, &buffer_size); + if (attach_result != MOJO_RESULT_OK) + return; + + internal::Buffer payload_buffer(MessageHandle(message), 0, buffer, + buffer_size); + WriteMessageHeader(context->message_name(), context->message_flags(), + 0 /* payload_interface_id_count */, &payload_buffer); + + // We need to copy additional header data which may have been set after + // message construction, as this codepath may be reached at some arbitrary + // time between message send and message dispatch. + static_cast<internal::MessageHeader*>(buffer)->interface_id = + context->header()->interface_id; + if (context->header()->flags & + (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { + DCHECK_GE(context->header()->version, 1u); + static_cast<internal::MessageHeaderV1*>(buffer)->request_id = + context->header()->request_id; + } + + internal::SerializationContext serialization_context; + context->Serialize(&serialization_context, &payload_buffer); + + // TODO(crbug.com/753433): Support lazy serialization of associated endpoint + // handles. See corresponding TODO in the bindings generator for proof that + // this DCHECK is indeed valid. + DCHECK(serialization_context.associated_endpoint_handles()->empty()); + if (!serialization_context.handles()->empty()) + payload_buffer.AttachHandles(serialization_context.mutable_handles()); + payload_buffer.Seal(); +} + +void DestroyUnserializedContext(uintptr_t context) { + delete reinterpret_cast<internal::UnserializedMessageContext*>(context); +} -Message::Message() { +ScopedMessageHandle CreateUnserializedMessageObject( + std::unique_ptr<internal::UnserializedMessageContext> context) { + ScopedMessageHandle handle; + MojoResult rv = mojo::CreateMessage(&handle); + DCHECK_EQ(MOJO_RESULT_OK, rv); + DCHECK(handle.is_valid()); + + rv = MojoSetMessageContext( + handle->value(), reinterpret_cast<uintptr_t>(context.release()), + &SerializeUnserializedContext, &DestroyUnserializedContext, nullptr); + DCHECK_EQ(MOJO_RESULT_OK, rv); + return handle; } +} // namespace + +Message::Message() = default; + Message::Message(Message&& other) - : buffer_(std::move(other.buffer_)), + : handle_(std::move(other.handle_)), + payload_buffer_(std::move(other.payload_buffer_)), handles_(std::move(other.handles_)), associated_endpoint_handles_( - std::move(other.associated_endpoint_handles_)) {} + std::move(other.associated_endpoint_handles_)), + transferable_(other.transferable_), + serialized_(other.serialized_) { + other.transferable_ = false; + other.serialized_ = false; +#if defined(ENABLE_IPC_FUZZER) + interface_name_ = other.interface_name_; + method_name_ = other.method_name_; +#endif +} + +Message::Message(std::unique_ptr<internal::UnserializedMessageContext> context) + : Message(CreateUnserializedMessageObject(std::move(context))) {} + +Message::Message(uint32_t name, + uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count, + std::vector<ScopedHandle>* handles) { + CreateSerializedMessageObject(name, flags, payload_size, + payload_interface_id_count, handles, &handle_, + &payload_buffer_); + transferable_ = true; + serialized_ = true; +} + +Message::Message(ScopedMessageHandle handle) { + DCHECK(handle.is_valid()); + + uintptr_t context_value = 0; + MojoResult get_context_result = + MojoGetMessageContext(handle->value(), nullptr, &context_value); + if (get_context_result == MOJO_RESULT_NOT_FOUND) { + // It's a serialized message. Extract handles if possible. + uint32_t num_bytes; + void* buffer; + uint32_t num_handles = 0; + MojoResult rv = MojoGetMessageData(handle->value(), nullptr, &buffer, + &num_bytes, nullptr, &num_handles); + if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) { + handles_.resize(num_handles); + rv = MojoGetMessageData(handle->value(), nullptr, &buffer, &num_bytes, + reinterpret_cast<MojoHandle*>(handles_.data()), + &num_handles); + } else { + // No handles, so it's safe to retransmit this message if the caller + // really wants to. + transferable_ = true; + } -Message::~Message() { - CloseHandles(); + if (rv != MOJO_RESULT_OK) { + // Failed to deserialize handles. Leave the Message uninitialized. + return; + } + + payload_buffer_ = internal::Buffer(buffer, num_bytes, num_bytes); + serialized_ = true; + } else { + DCHECK_EQ(MOJO_RESULT_OK, get_context_result); + auto* context = + reinterpret_cast<internal::UnserializedMessageContext*>(context_value); + // Dummy data address so common header accessors still behave properly. The + // choice is V1 reflects unserialized message capabilities: we may or may + // not need to support request IDs (which require at least V1), but we never + // (for now, anyway) need to support associated interface handles (V2). + payload_buffer_ = + internal::Buffer(context->header(), sizeof(internal::MessageHeaderV1), + sizeof(internal::MessageHeaderV1)); + transferable_ = true; + serialized_ = false; + } + + handle_ = std::move(handle); } +Message::~Message() = default; + Message& Message::operator=(Message&& other) { - Reset(); - std::swap(other.buffer_, buffer_); - std::swap(other.handles_, handles_); - std::swap(other.associated_endpoint_handles_, associated_endpoint_handles_); + handle_ = std::move(other.handle_); + payload_buffer_ = std::move(other.payload_buffer_); + handles_ = std::move(other.handles_); + associated_endpoint_handles_ = std::move(other.associated_endpoint_handles_); + transferable_ = other.transferable_; + other.transferable_ = false; + serialized_ = other.serialized_; + other.serialized_ = false; +#if defined(ENABLE_IPC_FUZZER) + interface_name_ = other.interface_name_; + method_name_ = other.method_name_; +#endif return *this; } void Message::Reset() { - CloseHandles(); + handle_.reset(); + payload_buffer_.Reset(); handles_.clear(); associated_endpoint_handles_.clear(); - buffer_.reset(); -} - -void Message::Initialize(size_t capacity, bool zero_initialized) { - DCHECK(!buffer_); - buffer_.reset(new internal::MessageBuffer(capacity, zero_initialized)); -} - -void Message::InitializeFromMojoMessage(ScopedMessageHandle message, - uint32_t num_bytes, - std::vector<Handle>* handles) { - DCHECK(!buffer_); - buffer_.reset(new internal::MessageBuffer(std::move(message), num_bytes)); - handles_.swap(*handles); + transferable_ = false; + serialized_ = false; } const uint8_t* Message::payload() const { if (version() < 2) return data() + header()->num_bytes; + DCHECK(!header_v2()->payload.is_null()); return static_cast<const uint8_t*>(header_v2()->payload.Get()); } @@ -89,19 +296,16 @@ uint32_t Message::payload_num_bytes() const { if (version() < 2) { num_bytes = data_num_bytes() - header()->num_bytes; } else { - auto payload = reinterpret_cast<uintptr_t>(header_v2()->payload.Get()); - if (!payload) { - num_bytes = 0; - } else { - auto payload_end = - reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get()); - if (!payload_end) - payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes()); - DCHECK_GE(payload_end, payload); - num_bytes = payload_end - payload; - } + auto payload_begin = + reinterpret_cast<uintptr_t>(header_v2()->payload.Get()); + auto payload_end = + reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get()); + if (!payload_end) + payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes()); + DCHECK_GE(payload_end, payload_begin); + num_bytes = payload_end - payload_begin; } - DCHECK_LE(num_bytes, std::numeric_limits<uint32_t>::max()); + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(num_bytes)); return static_cast<uint32_t>(num_bytes); } @@ -117,52 +321,52 @@ const uint32_t* Message::payload_interface_ids() const { return array_pointer ? array_pointer->storage() : nullptr; } -ScopedMessageHandle Message::TakeMojoMessage() { - // If there are associated endpoints transferred, - // SerializeAssociatedEndpointHandles() must be called before this method. - DCHECK(associated_endpoint_handles_.empty()); +void Message::AttachHandlesFromSerializationContext( + internal::SerializationContext* context) { + if (context->handles()->empty() && + context->associated_endpoint_handles()->empty()) { + // No handles attached, so no extra serialization work. + return; + } - if (handles_.empty()) // Fast path for the common case: No handles. - return buffer_->TakeMessage(); + if (context->associated_endpoint_handles()->empty()) { + // Attaching only non-associated handles is easier since we don't have to + // modify the message header. Faster path for that. + payload_buffer_.AttachHandles(context->mutable_handles()); + return; + } - // Allocate a new message with space for the handles, then copy the buffer - // contents into it. + // Allocate a new message with enough space to hold all attached handles. Copy + // this message's contents into the new one and use it to replace ourself. // - // TODO(rockot): We could avoid this copy by extending GetSerializedSize() - // behavior to collect handles. It's unoptimized for now because it's much - // more common to have messages with no handles. - ScopedMessageHandle new_message; - MojoResult rv = AllocMessage( - data_num_bytes(), - handles_.empty() ? nullptr - : reinterpret_cast<const MojoHandle*>(handles_.data()), - handles_.size(), - MOJO_ALLOC_MESSAGE_FLAG_NONE, - &new_message); - CHECK_EQ(rv, MOJO_RESULT_OK); - handles_.clear(); - - void* new_buffer = nullptr; - rv = GetMessageBuffer(new_message.get(), &new_buffer); - CHECK_EQ(rv, MOJO_RESULT_OK); - - memcpy(new_buffer, data(), data_num_bytes()); - buffer_.reset(); - - return new_message; + // TODO(rockot): We could avoid the extra full message allocation by instead + // growing the buffer and carefully moving its contents around. This errs on + // the side of less complexity with probably only marginal performance cost. + uint32_t payload_size = payload_num_bytes(); + mojo::Message new_message(name(), header()->flags, payload_size, + context->associated_endpoint_handles()->size(), + context->mutable_handles()); + std::swap(*context->mutable_associated_endpoint_handles(), + new_message.associated_endpoint_handles_); + memcpy(new_message.payload_buffer()->AllocateAndGet(payload_size), payload(), + payload_size); + *this = std::move(new_message); } -void Message::NotifyBadMessage(const std::string& error) { - DCHECK(buffer_); - buffer_->NotifyBadMessage(error); +ScopedMessageHandle Message::TakeMojoMessage() { + // If there are associated endpoints transferred, + // SerializeAssociatedEndpointHandles() must be called before this method. + DCHECK(associated_endpoint_handles_.empty()); + DCHECK(transferable_); + payload_buffer_.Seal(); + auto handle = std::move(handle_); + Reset(); + return handle; } -void Message::CloseHandles() { - for (std::vector<Handle>::iterator it = handles_.begin(); - it != handles_.end(); ++it) { - if (it->is_valid()) - CloseRaw(*it); - } +void Message::NotifyBadMessage(const std::string& error) { + DCHECK(handle_.is_valid()); + mojo::NotifyBadMessage(handle_.get(), error); } void Message::SerializeAssociatedEndpointHandles( @@ -172,16 +376,20 @@ void Message::SerializeAssociatedEndpointHandles( DCHECK_GE(version(), 2u); DCHECK(header_v2()->payload_interface_ids.is_null()); + DCHECK(payload_buffer_.is_valid()); + DCHECK(handle_.is_valid()); size_t size = associated_endpoint_handles_.size(); - auto* data = internal::Array_Data<uint32_t>::New(size, buffer()); - header_v2()->payload_interface_ids.Set(data); + + internal::Array_Data<uint32_t>::BufferWriter handle_writer; + handle_writer.Allocate(size, &payload_buffer_); + header_v2()->payload_interface_ids.Set(handle_writer.data()); for (size_t i = 0; i < size; ++i) { ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i]; DCHECK(handle.pending_association()); - data->storage()[i] = + handle_writer->storage()[i] = group_controller->AssociateInterface(std::move(handle)); } associated_endpoint_handles_.clear(); @@ -189,6 +397,9 @@ void Message::SerializeAssociatedEndpointHandles( bool Message::DeserializeAssociatedEndpointHandles( AssociatedGroupController* group_controller) { + if (!serialized_) + return true; + associated_endpoint_handles_.clear(); uint32_t num_ids = payload_num_interface_ids(); @@ -213,11 +424,48 @@ bool Message::DeserializeAssociatedEndpointHandles( return result; } +void Message::SerializeIfNecessary() { + MojoResult rv = MojoSerializeMessage(handle_->value(), nullptr); + if (rv == MOJO_RESULT_FAILED_PRECONDITION) + return; + + // Reconstruct this Message instance from the serialized message's handle. + *this = Message(std::move(handle_)); +} + +std::unique_ptr<internal::UnserializedMessageContext> +Message::TakeUnserializedContext( + const internal::UnserializedMessageContext::Tag* tag) { + DCHECK(handle_.is_valid()); + uintptr_t context_value = 0; + MojoResult rv = + MojoGetMessageContext(handle_->value(), nullptr, &context_value); + if (rv == MOJO_RESULT_NOT_FOUND) + return nullptr; + DCHECK_EQ(MOJO_RESULT_OK, rv); + + auto* context = + reinterpret_cast<internal::UnserializedMessageContext*>(context_value); + if (context->tag() != tag) + return nullptr; + + // Detach the context from the message. + rv = MojoSetMessageContext(handle_->value(), 0, nullptr, nullptr, nullptr); + DCHECK_EQ(MOJO_RESULT_OK, rv); + return base::WrapUnique(context); +} + +bool MessageReceiver::PrefersSerializedMessages() { + return false; +} + PassThroughFilter::PassThroughFilter() {} PassThroughFilter::~PassThroughFilter() {} -bool PassThroughFilter::Accept(Message* message) { return true; } +bool PassThroughFilter::Accept(Message* message) { + return true; +} SyncMessageResponseContext::SyncMessageResponseContext() : outer_context_(current()) { @@ -238,43 +486,19 @@ void SyncMessageResponseContext::ReportBadMessage(const std::string& error) { GetBadMessageCallback().Run(error); } -const ReportBadMessageCallback& -SyncMessageResponseContext::GetBadMessageCallback() { - if (bad_message_callback_.is_null()) { - bad_message_callback_ = - base::Bind(&DoNotifyBadMessage, base::Passed(&response_)); - } - return bad_message_callback_; +ReportBadMessageCallback SyncMessageResponseContext::GetBadMessageCallback() { + DCHECK(!response_.IsNull()); + return base::BindOnce(&DoNotifyBadMessage, std::move(response_)); } MojoResult ReadMessage(MessagePipeHandle handle, Message* message) { - MojoResult rv; - - std::vector<Handle> handles; - ScopedMessageHandle mojo_message; - uint32_t num_bytes = 0, num_handles = 0; - rv = ReadMessageNew(handle, - &mojo_message, - &num_bytes, - nullptr, - &num_handles, - MOJO_READ_MESSAGE_FLAG_NONE); - if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) { - DCHECK_GT(num_handles, 0u); - handles.resize(num_handles); - rv = ReadMessageNew(handle, - &mojo_message, - &num_bytes, - reinterpret_cast<MojoHandle*>(handles.data()), - &num_handles, - MOJO_READ_MESSAGE_FLAG_NONE); - } - + ScopedMessageHandle message_handle; + MojoResult rv = + ReadMessageNew(handle, &message_handle, MOJO_READ_MESSAGE_FLAG_NONE); if (rv != MOJO_RESULT_OK) return rv; - message->InitializeFromMojoMessage( - std::move(mojo_message), num_bytes, &handles); + *message = Message(std::move(message_handle)); return MOJO_RESULT_OK; } @@ -311,13 +535,9 @@ MessageDispatchContext* MessageDispatchContext::current() { return g_tls_message_dispatch_context.Get().Get(); } -const ReportBadMessageCallback& -MessageDispatchContext::GetBadMessageCallback() { - if (bad_message_callback_.is_null()) { - bad_message_callback_ = - base::Bind(&DoNotifyBadMessage, base::Passed(message_)); - } - return bad_message_callback_; +ReportBadMessageCallback MessageDispatchContext::GetBadMessageCallback() { + DCHECK(!message_->IsNull()); + return base::BindOnce(&DoNotifyBadMessage, std::move(*message_)); } // static diff --git a/mojo/public/cpp/bindings/lib/message_buffer.cc b/mojo/public/cpp/bindings/lib/message_buffer.cc deleted file mode 100644 index cc12ef6e31..0000000000 --- a/mojo/public/cpp/bindings/lib/message_buffer.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "mojo/public/cpp/bindings/lib/message_buffer.h" - -#include <limits> - -#include "mojo/public/cpp/bindings/lib/serialization_util.h" - -namespace mojo { -namespace internal { - -MessageBuffer::MessageBuffer(size_t capacity, bool zero_initialized) { - DCHECK_LE(capacity, std::numeric_limits<uint32_t>::max()); - - MojoResult rv = AllocMessage(capacity, nullptr, 0, - MOJO_ALLOC_MESSAGE_FLAG_NONE, &message_); - CHECK_EQ(rv, MOJO_RESULT_OK); - - void* buffer = nullptr; - if (capacity != 0) { - rv = GetMessageBuffer(message_.get(), &buffer); - CHECK_EQ(rv, MOJO_RESULT_OK); - - if (zero_initialized) - memset(buffer, 0, capacity); - } - Initialize(buffer, capacity); -} - -MessageBuffer::MessageBuffer(ScopedMessageHandle message, uint32_t num_bytes) { - message_ = std::move(message); - - void* buffer = nullptr; - if (num_bytes != 0) { - MojoResult rv = GetMessageBuffer(message_.get(), &buffer); - CHECK_EQ(rv, MOJO_RESULT_OK); - } - Initialize(buffer, num_bytes); -} - -MessageBuffer::~MessageBuffer() {} - -void MessageBuffer::NotifyBadMessage(const std::string& error) { - DCHECK(message_.is_valid()); - MojoResult result = mojo::NotifyBadMessage(message_.get(), error); - DCHECK_EQ(result, MOJO_RESULT_OK); -} - -} // namespace internal -} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_buffer.h b/mojo/public/cpp/bindings/lib/message_buffer.h deleted file mode 100644 index 96d5140f77..0000000000 --- a/mojo/public/cpp/bindings/lib/message_buffer.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_LIB_MESSAGE_BUFFER_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_LIB_MESSAGE_BUFFER_H_ - -#include <stdint.h> - -#include <utility> - -#include "base/macros.h" -#include "mojo/public/cpp/bindings/lib/buffer.h" -#include "mojo/public/cpp/system/message.h" - -namespace mojo { -namespace internal { - -// A fixed-size Buffer using a Mojo message object for storage. -class MessageBuffer : public Buffer { - public: - // Initializes this buffer to carry a fixed byte capacity and no handles. - MessageBuffer(size_t capacity, bool zero_initialized); - - // Initializes this buffer from an existing Mojo MessageHandle. - MessageBuffer(ScopedMessageHandle message, uint32_t num_bytes); - - ~MessageBuffer(); - - ScopedMessageHandle TakeMessage() { return std::move(message_); } - - void NotifyBadMessage(const std::string& error); - - private: - ScopedMessageHandle message_; - - DISALLOW_COPY_AND_ASSIGN(MessageBuffer); -}; - -} // namespace internal -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_LIB_MESSAGE_BUFFER_H_ diff --git a/mojo/public/cpp/bindings/lib/message_builder.cc b/mojo/public/cpp/bindings/lib/message_builder.cc deleted file mode 100644 index 6806a73213..0000000000 --- a/mojo/public/cpp/bindings/lib/message_builder.cc +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "mojo/public/cpp/bindings/lib/message_builder.h" - -#include "mojo/public/cpp/bindings/lib/array_internal.h" -#include "mojo/public/cpp/bindings/lib/bindings_internal.h" -#include "mojo/public/cpp/bindings/lib/buffer.h" -#include "mojo/public/cpp/bindings/lib/message_internal.h" - -namespace mojo { -namespace internal { - -template <typename Header> -void Allocate(Buffer* buf, Header** header) { - *header = static_cast<Header*>(buf->Allocate(sizeof(Header))); - (*header)->num_bytes = sizeof(Header); -} - -MessageBuilder::MessageBuilder(uint32_t name, - uint32_t flags, - size_t payload_size, - size_t payload_interface_id_count) { - if (payload_interface_id_count > 0) { - // Version 2 - InitializeMessage( - sizeof(MessageHeaderV2) + Align(payload_size) + - ArrayDataTraits<uint32_t>::GetStorageSize( - static_cast<uint32_t>(payload_interface_id_count))); - - MessageHeaderV2* header; - Allocate(message_.buffer(), &header); - header->version = 2; - header->name = name; - header->flags = flags; - // The payload immediately follows the header. - header->payload.Set(header + 1); - } else if (flags & - (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { - // Version 1 - InitializeMessage(sizeof(MessageHeaderV1) + payload_size); - - MessageHeaderV1* header; - Allocate(message_.buffer(), &header); - header->version = 1; - header->name = name; - header->flags = flags; - } else { - InitializeMessage(sizeof(MessageHeader) + payload_size); - - MessageHeader* header; - Allocate(message_.buffer(), &header); - header->version = 0; - header->name = name; - header->flags = flags; - } -} - -MessageBuilder::~MessageBuilder() { -} - -void MessageBuilder::InitializeMessage(size_t size) { - message_.Initialize(static_cast<uint32_t>(Align(size)), - true /* zero_initialized */); -} - -} // namespace internal -} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_builder.h b/mojo/public/cpp/bindings/lib/message_builder.h deleted file mode 100644 index 8a4d5c4690..0000000000 --- a/mojo/public/cpp/bindings/lib/message_builder.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_BUILDER_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_BUILDER_H_ - -#include <stddef.h> -#include <stdint.h> - -#include "base/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" -#include "mojo/public/cpp/bindings/message.h" - -namespace mojo { - -class Message; - -namespace internal { - -class Buffer; - -class MOJO_CPP_BINDINGS_EXPORT MessageBuilder { - public: - MessageBuilder(uint32_t name, - uint32_t flags, - size_t payload_size, - size_t payload_interface_id_count); - ~MessageBuilder(); - - Buffer* buffer() { return message_.buffer(); } - Message* message() { return &message_; } - - private: - void InitializeMessage(size_t size); - - Message message_; - - DISALLOW_COPY_AND_ASSIGN(MessageBuilder); -}; - -} // namespace internal -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_BUILDER_H_ diff --git a/mojo/public/cpp/bindings/lib/message_dumper.cc b/mojo/public/cpp/bindings/lib/message_dumper.cc new file mode 100644 index 0000000000..35696bbcbf --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message_dumper.cc @@ -0,0 +1,96 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/message_dumper.h" + +#include "base/files/file.h" +#include "base/files/file_path.h" +#include "base/files/file_util.h" +#include "base/logging.h" +#include "base/no_destructor.h" +#include "base/process/process.h" +#include "base/rand_util.h" +#include "base/strings/string_number_conversions.h" +#include "base/task_scheduler/post_task.h" +#include "mojo/public/cpp/bindings/message.h" + +namespace { + +base::FilePath& DumpDirectory() { + static base::NoDestructor<base::FilePath> dump_directory; + return *dump_directory; +} + +// void WriteMessage(uint32_t identifier, +// const mojo::MessageDumper::MessageEntry& entry) { +// static uint64_t num = 0; + +// if (!entry.interface_name) +// return; + +// base::FilePath message_directory = +// DumpDirectory() +// .AppendASCII(entry.interface_name) +// .AppendASCII(base::NumberToString(identifier)); + +// if (!base::DirectoryExists(message_directory) && +// !base::CreateDirectory(message_directory)) { +// LOG(ERROR) << "Failed to create" << message_directory.value(); +// return; +// } + +// std::string filename = +// base::NumberToString(num++) + "." + entry.method_name + ".mojomsg"; +// base::FilePath path = message_directory.AppendASCII(filename); +// base::File file(path, +// base::File::FLAG_WRITE | base::File::FLAG_CREATE_ALWAYS); + +// file.WriteAtCurrentPos(reinterpret_cast<const char*>(entry.data_bytes.data()), +// static_cast<int>(entry.data_bytes.size())); +// } + +} // namespace + +namespace mojo { + +MessageDumper::MessageEntry::MessageEntry(const uint8_t* data, + uint32_t data_size, + const char* interface_name, + const char* method_name) + : interface_name(interface_name), + method_name(method_name), + data_bytes(data, data + data_size) {} + +MessageDumper::MessageEntry::MessageEntry(const MessageEntry& entry) = default; + +MessageDumper::MessageEntry::~MessageEntry() {} + +MessageDumper::MessageDumper() : identifier_(base::RandUint64()) {} + +MessageDumper::~MessageDumper() {} + +bool MessageDumper::Accept(mojo::Message* message) { + // MessageEntry entry(message->data(), message->data_num_bytes(), + // "unknown interface", "unknown name"); + + // static base::NoDestructor<scoped_refptr<base::TaskRunner>> task_runner( + // base::CreateSequencedTaskRunnerWithTraits( + // {base::MayBlock(), base::TaskPriority::USER_BLOCKING, + // base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})); + + // (*task_runner) + // ->PostTask(FROM_HERE, + // base::BindOnce(&WriteMessage, identifier_, std::move(entry))); + return true; +} + +void MessageDumper::SetMessageDumpDirectory(const base::FilePath& directory) { + DumpDirectory() = directory; +} + +const base::FilePath& MessageDumper::GetMessageDumpDirectory() { + return DumpDirectory(); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_header_validator.cc b/mojo/public/cpp/bindings/lib/message_header_validator.cc index 9f8c6278c0..46bc5ed6e3 100644 --- a/mojo/public/cpp/bindings/lib/message_header_validator.cc +++ b/mojo/public/cpp/bindings/lib/message_header_validator.cc @@ -73,9 +73,10 @@ bool IsValidMessageHeader(const internal::MessageHeader* header, // payload size). // - Validation of the payload contents will be done separately based on the // payload type. - if (!header_v2->payload.is_null() && - (!internal::ValidatePointer(header_v2->payload, validation_context) || - !validation_context->ClaimMemory(header_v2->payload.Get(), 1))) { + if (!internal::ValidatePointerNonNullable(header_v2->payload, 5, + validation_context) || + !internal::ValidatePointer(header_v2->payload, validation_context) || + !validation_context->ClaimMemory(header_v2->payload.Get(), 1)) { return false; } @@ -115,6 +116,10 @@ void MessageHeaderValidator::SetDescription(const std::string& description) { } bool MessageHeaderValidator::Accept(Message* message) { + // Don't bother validating unserialized message headers. + if (!message->is_serialized()) + return true; + // Pass 0 as number of handles and associated endpoint handles because we // don't expect any in the header, even if |message| contains handles. internal::ValidationContext validation_context( diff --git a/mojo/public/cpp/bindings/lib/message_internal.cc b/mojo/public/cpp/bindings/lib/message_internal.cc new file mode 100644 index 0000000000..445eb4d891 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message_internal.cc @@ -0,0 +1,45 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/message_internal.h" + +#include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/bindings/message.h" + +namespace mojo { +namespace internal { + +namespace { + +size_t ComputeHeaderSize(uint32_t flags, size_t payload_interface_id_count) { + if (payload_interface_id_count > 0) { + // Version 2 + return sizeof(MessageHeaderV2); + } else if (flags & + (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { + // Version 1 + return sizeof(MessageHeaderV1); + } else { + // Version 0 + return sizeof(MessageHeader); + } +} + +} // namespace + +size_t ComputeSerializedMessageSize(uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count) { + const size_t header_size = + ComputeHeaderSize(flags, payload_interface_id_count); + if (payload_interface_id_count > 0) { + return Align(header_size + Align(payload_size) + + ArrayDataTraits<uint32_t>::GetStorageSize( + static_cast<uint32_t>(payload_interface_id_count))); + } + return internal::Align(header_size + payload_size); +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_internal.h b/mojo/public/cpp/bindings/lib/message_internal.h index 6693198f81..40539e27aa 100644 --- a/mojo/public/cpp/bindings/lib/message_internal.h +++ b/mojo/public/cpp/bindings/lib/message_internal.h @@ -10,8 +10,8 @@ #include <string> #include "base/callback.h" +#include "base/component_export.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" namespace mojo { @@ -54,28 +54,32 @@ static_assert(sizeof(MessageHeaderV2) == 48, "Bad sizeof(MessageHeaderV2)"); #pragma pack(pop) -class MOJO_CPP_BINDINGS_EXPORT MessageDispatchContext { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) MessageDispatchContext { public: explicit MessageDispatchContext(Message* message); ~MessageDispatchContext(); static MessageDispatchContext* current(); - const base::Callback<void(const std::string&)>& GetBadMessageCallback(); + base::OnceCallback<void(const std::string&)> GetBadMessageCallback(); private: MessageDispatchContext* outer_context_; Message* message_; - base::Callback<void(const std::string&)> bad_message_callback_; DISALLOW_COPY_AND_ASSIGN(MessageDispatchContext); }; -class MOJO_CPP_BINDINGS_EXPORT SyncMessageResponseSetup { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) SyncMessageResponseSetup { public: static void SetCurrentSyncResponseMessage(Message* message); }; +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +size_t ComputeSerializedMessageSize(uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count); + } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/multiplex_router.cc b/mojo/public/cpp/bindings/lib/multiplex_router.cc index ff7c678289..61833097ef 100644 --- a/mojo/public/cpp/bindings/lib/multiplex_router.cc +++ b/mojo/public/cpp/bindings/lib/multiplex_router.cc @@ -12,14 +12,13 @@ #include "base/location.h" #include "base/macros.h" #include "base/memory/ptr_util.h" -#include "base/single_thread_task_runner.h" +#include "base/sequenced_task_runner.h" #include "base/stl_util.h" #include "base/synchronization/waitable_event.h" -#include "base/threading/thread_task_runner_handle.h" #include "mojo/public/cpp/bindings/interface_endpoint_client.h" #include "mojo/public/cpp/bindings/interface_endpoint_controller.h" #include "mojo/public/cpp/bindings/lib/may_auto_lock.h" -#include "mojo/public/cpp/bindings/sync_event_watcher.h" +#include "mojo/public/cpp/bindings/sequence_local_sync_event_watcher.h" namespace mojo { namespace internal { @@ -41,7 +40,7 @@ class MultiplexRouter::InterfaceEndpoint client_(nullptr) {} // --------------------------------------------------------------------------- - // The following public methods are safe to call from any threads without + // The following public methods are safe to call from any sequence without // locking. InterfaceId id() const { return id_; } @@ -76,29 +75,27 @@ class MultiplexRouter::InterfaceEndpoint disconnect_reason_ = disconnect_reason; } - base::SingleThreadTaskRunner* task_runner() const { - return task_runner_.get(); - } + base::SequencedTaskRunner* task_runner() const { return task_runner_.get(); } InterfaceEndpointClient* client() const { return client_; } void AttachClient(InterfaceEndpointClient* client, - scoped_refptr<base::SingleThreadTaskRunner> runner) { + scoped_refptr<base::SequencedTaskRunner> runner) { router_->AssertLockAcquired(); DCHECK(!client_); DCHECK(!closed_); - DCHECK(runner->BelongsToCurrentThread()); + DCHECK(runner->RunsTasksInCurrentSequence()); task_runner_ = std::move(runner); client_ = client; } - // This method must be called on the same thread as the corresponding + // This method must be called on the same sequence as the corresponding // AttachClient() call. void DetachClient() { router_->AssertLockAcquired(); DCHECK(client_); - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); DCHECK(!closed_); task_runner_ = nullptr; @@ -111,8 +108,8 @@ class MultiplexRouter::InterfaceEndpoint if (sync_message_event_signaled_) return; sync_message_event_signaled_ = true; - if (sync_message_event_) - sync_message_event_->Signal(); + if (sync_watcher_) + sync_watcher_->SignalEvent(); } void ResetSyncMessageSignal() { @@ -120,30 +117,30 @@ class MultiplexRouter::InterfaceEndpoint if (!sync_message_event_signaled_) return; sync_message_event_signaled_ = false; - if (sync_message_event_) - sync_message_event_->Reset(); + if (sync_watcher_) + sync_watcher_->ResetEvent(); } // --------------------------------------------------------------------------- // The following public methods (i.e., InterfaceEndpointController - // implementation) are called by the client on the same thread as the + // implementation) are called by the client on the same sequence as the // AttachClient() call. They are called outside of the router's lock. bool SendMessage(Message* message) override { - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); message->set_interface_id(id_); return router_->connector_.Accept(message); } void AllowWokenUpBySyncWatchOnSameThread() override { - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); EnsureSyncWatcherExists(); - sync_watcher_->AllowWokenUpBySyncWatchOnSameThread(); + sync_watcher_->AllowWokenUpBySyncWatchOnSameSequence(); } bool SyncWatch(const bool* should_stop) override { - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); EnsureSyncWatcherExists(); return sync_watcher_->SyncWatch(should_stop); @@ -156,13 +153,10 @@ class MultiplexRouter::InterfaceEndpoint router_->AssertLockAcquired(); DCHECK(!client_); - DCHECK(closed_); - DCHECK(peer_closed_); - DCHECK(!sync_watcher_); } void OnSyncEventSignaled() { - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); scoped_refptr<MultiplexRouter> router_protector(router_); MayAutoLock locker(&router_->lock_); @@ -184,28 +178,20 @@ class MultiplexRouter::InterfaceEndpoint } void EnsureSyncWatcherExists() { - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); if (sync_watcher_) return; - { - MayAutoLock locker(&router_->lock_); - if (!sync_message_event_) { - sync_message_event_.emplace( - base::WaitableEvent::ResetPolicy::MANUAL, - base::WaitableEvent::InitialState::NOT_SIGNALED); - if (sync_message_event_signaled_) - sync_message_event_->Signal(); - } - } - sync_watcher_.reset( - new SyncEventWatcher(&sync_message_event_.value(), - base::Bind(&InterfaceEndpoint::OnSyncEventSignaled, - base::Unretained(this)))); + MayAutoLock locker(&router_->lock_); + sync_watcher_ = + std::make_unique<SequenceLocalSyncEventWatcher>(base::BindRepeating( + &InterfaceEndpoint::OnSyncEventSignaled, base::Unretained(this))); + if (sync_message_event_signaled_) + sync_watcher_->SignalEvent(); } // --------------------------------------------------------------------------- - // The following members are safe to access from any threads. + // The following members are safe to access from any sequence. MultiplexRouter* const router_; const InterfaceId id_; @@ -225,30 +211,22 @@ class MultiplexRouter::InterfaceEndpoint base::Optional<DisconnectReason> disconnect_reason_; // The task runner on which |client_|'s methods can be called. - scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; // Not owned. It is null if no client is attached to this endpoint. InterfaceEndpointClient* client_; - // An event used to signal that sync messages are available. The event is - // initialized under the router's lock and remains unchanged afterwards. It - // may be accessed outside of the router's lock later. - base::Optional<base::WaitableEvent> sync_message_event_; + // Indicates whether the sync watcher should be signaled for this endpoint. bool sync_message_event_signaled_ = false; - // --------------------------------------------------------------------------- - // The following members are only valid while a client is attached. They are - // used exclusively on the client's thread. They may be accessed outside of - // the router's lock. - - std::unique_ptr<SyncEventWatcher> sync_watcher_; + // Guarded by the router's lock. Used to synchronously wait on replies. + std::unique_ptr<SequenceLocalSyncEventWatcher> sync_watcher_; DISALLOW_COPY_AND_ASSIGN(InterfaceEndpoint); }; // MessageWrapper objects are always destroyed under the router's lock. On -// destruction, if the message it wrappers contains -// ScopedInterfaceEndpointHandles (which cannot be destructed under the -// router's lock), the wrapper unlocks to clean them up. +// destruction, if the message it wrappers contains interface IDs, the wrapper +// closes the corresponding endpoints. class MultiplexRouter::MessageWrapper { public: MessageWrapper() = default; @@ -260,14 +238,14 @@ class MultiplexRouter::MessageWrapper { : router_(other.router_), value_(std::move(other.value_)) {} ~MessageWrapper() { - if (value_.associated_endpoint_handles()->empty()) + if (!router_ || value_.IsNull()) return; router_->AssertLockAcquired(); - { - MayAutoUnlock unlocker(&router_->lock_); - value_.mutable_associated_endpoint_handles()->clear(); - } + // Don't try to close the endpoints if at this point the router is already + // half-destructed. + if (!router_->being_destructed_) + router_->CloseEndpointsForMessage(value_); } MessageWrapper& operator=(MessageWrapper&& other) { @@ -276,7 +254,21 @@ class MultiplexRouter::MessageWrapper { return *this; } - Message& value() { return value_; } + const Message& value() const { return value_; } + + // Must be called outside of the router's lock. + // Returns a null message if it fails to deseralize the associated endpoint + // handles. + Message DeserializeEndpointHandlesAndTake() { + if (!value_.DeserializeAssociatedEndpointHandles(router_)) { + // The previous call may have deserialized part of the associated + // interface endpoint handles. They must be destroyed outside of the + // router's lock, so we cannot wait until destruction of MessageWrapper. + value_.Reset(); + return Message(); + } + return std::move(value_); + } private: MultiplexRouter* router_ = nullptr; @@ -322,23 +314,17 @@ MultiplexRouter::MultiplexRouter( ScopedMessagePipeHandle message_pipe, Config config, bool set_interface_id_namesapce_bit, - scoped_refptr<base::SingleThreadTaskRunner> runner) + scoped_refptr<base::SequencedTaskRunner> runner) : set_interface_id_namespace_bit_(set_interface_id_namesapce_bit), task_runner_(runner), - header_validator_(nullptr), filters_(this), connector_(std::move(message_pipe), config == MULTI_INTERFACE ? Connector::MULTI_THREADED_SEND : Connector::SINGLE_THREADED_SEND, std::move(runner)), control_message_handler_(this), - control_message_proxy_(&connector_), - next_interface_id_value_(1), - posted_to_process_tasks_(false), - encountered_error_(false), - paused_(false), - testing_mode_(false) { - DCHECK(task_runner_->BelongsToCurrentThread()); + control_message_proxy_(&connector_) { + DCHECK(task_runner_->RunsTasksInCurrentSequence()); if (config == MULTI_INTERFACE) lock_.emplace(); @@ -348,16 +334,15 @@ MultiplexRouter::MultiplexRouter( // Always participate in sync handle watching in multi-interface mode, // because even if it doesn't expect sync requests during sync handle // watching, it may still need to dispatch messages to associated endpoints - // on a different thread. + // on a different sequence. connector_.AllowWokenUpBySyncWatchOnSameThread(); } connector_.set_incoming_receiver(&filters_); - connector_.set_connection_error_handler( - base::Bind(&MultiplexRouter::OnPipeConnectionError, - base::Unretained(this))); + connector_.set_connection_error_handler(base::Bind( + &MultiplexRouter::OnPipeConnectionError, base::Unretained(this))); std::unique_ptr<MessageHeaderValidator> header_validator = - base::MakeUnique<MessageHeaderValidator>(); + std::make_unique<MessageHeaderValidator>(); header_validator_ = header_validator.get(); filters_.Append(std::move(header_validator)); } @@ -365,33 +350,22 @@ MultiplexRouter::MultiplexRouter( MultiplexRouter::~MultiplexRouter() { MayAutoLock locker(&lock_); + being_destructed_ = true; + sync_message_tasks_.clear(); tasks_.clear(); + endpoints_.clear(); +} - for (auto iter = endpoints_.begin(); iter != endpoints_.end();) { - InterfaceEndpoint* endpoint = iter->second.get(); - // Increment the iterator before calling UpdateEndpointStateMayRemove() - // because it may remove the corresponding value from the map. - ++iter; - - if (!endpoint->closed()) { - // This happens when a NotifyPeerEndpointClosed message been received, but - // the interface ID hasn't been used to create local endpoint handle. - DCHECK(!endpoint->client()); - DCHECK(endpoint->peer_closed()); - UpdateEndpointStateMayRemove(endpoint, ENDPOINT_CLOSED); - } else { - UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); - } - } - - DCHECK(endpoints_.empty()); +void MultiplexRouter::AddIncomingMessageFilter( + std::unique_ptr<MessageReceiver> filter) { + filters_.Append(std::move(filter)); } void MultiplexRouter::SetMasterInterfaceName(const char* name) { - DCHECK(thread_checker_.CalledOnValidThread()); - header_validator_->SetDescription( - std::string(name) + " [master] MessageHeaderValidator"); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + header_validator_->SetDescription(std::string(name) + + " [master] MessageHeaderValidator"); control_message_handler_.SetDescription( std::string(name) + " [master] PipeControlMessageHandler"); connector_.SetWatcherHeapProfilerTag(name); @@ -445,17 +419,10 @@ ScopedInterfaceEndpointHandle MultiplexRouter::CreateLocalEndpointHandle( bool inserted = false; InterfaceEndpoint* endpoint = FindOrInsertEndpoint(id, &inserted); if (inserted) { - DCHECK(!endpoint->handle_created()); - if (encountered_error_) UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); } else { - // If the endpoint already exist, it is because we have received a - // notification that the peer endpoint has closed. - CHECK(!endpoint->closed()); - CHECK(endpoint->peer_closed()); - - if (endpoint->handle_created()) + if (endpoint->handle_created() || endpoint->closed()) return ScopedInterfaceEndpointHandle(); } @@ -487,7 +454,7 @@ void MultiplexRouter::CloseEndpointHandle( InterfaceEndpointController* MultiplexRouter::AttachEndpointClient( const ScopedInterfaceEndpointHandle& handle, InterfaceEndpointClient* client, - scoped_refptr<base::SingleThreadTaskRunner> runner) { + scoped_refptr<base::SequencedTaskRunner> runner) { const InterfaceId id = handle.id(); DCHECK(IsValidInterfaceId(id)); @@ -520,7 +487,7 @@ void MultiplexRouter::DetachEndpointClient( } void MultiplexRouter::RaiseError() { - if (task_runner_->BelongsToCurrentThread()) { + if (task_runner_->RunsTasksInCurrentSequence()) { connector_.RaiseError(); } else { task_runner_->PostTask(FROM_HERE, @@ -528,8 +495,13 @@ void MultiplexRouter::RaiseError() { } } +bool MultiplexRouter::PrefersSerializedMessages() { + MayAutoLock locker(&lock_); + return connector_.PrefersSerializedMessages(); +} + void MultiplexRouter::CloseMessagePipe() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); connector_.CloseMessagePipe(); // CloseMessagePipe() above won't trigger connection error handler. // Explicitly call OnPipeConnectionError() so that associated endpoints will @@ -538,7 +510,7 @@ void MultiplexRouter::CloseMessagePipe() { } void MultiplexRouter::PauseIncomingMethodCallProcessing() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); connector_.PauseIncomingMethodCallProcessing(); MayAutoLock locker(&lock_); @@ -549,7 +521,7 @@ void MultiplexRouter::PauseIncomingMethodCallProcessing() { } void MultiplexRouter::ResumeIncomingMethodCallProcessing() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); connector_.ResumeIncomingMethodCallProcessing(); MayAutoLock locker(&lock_); @@ -568,7 +540,7 @@ void MultiplexRouter::ResumeIncomingMethodCallProcessing() { } bool MultiplexRouter::HasAssociatedEndpoints() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); MayAutoLock locker(&lock_); if (endpoints_.size() > 1) @@ -580,7 +552,7 @@ bool MultiplexRouter::HasAssociatedEndpoints() const { } void MultiplexRouter::EnableTestingMode() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); MayAutoLock locker(&lock_); testing_mode_ = true; @@ -588,9 +560,19 @@ void MultiplexRouter::EnableTestingMode() { } bool MultiplexRouter::Accept(Message* message) { - DCHECK(thread_checker_.CalledOnValidThread()); - - if (!message->DeserializeAssociatedEndpointHandles(this)) + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + // Insert endpoints for the payload interface IDs as soon as the message + // arrives, instead of waiting till the message is dispatched. Consider the + // following sequence: + // 1) Async message msg1 arrives, containing interface ID x. Msg1 is not + // dispatched because a sync call is blocking the thread. + // 2) Sync message msg2 arrives targeting interface ID x. + // + // If we don't insert endpoint for interface ID x, when trying to dispatch + // msg2 we don't know whether it is an unexpected message or it is just + // because the message containing x hasn't been dispatched. + if (!InsertEndpointsForMessage(*message)) return false; scoped_refptr<MultiplexRouter> protector(this); @@ -603,15 +585,15 @@ bool MultiplexRouter::Accept(Message* message) { ? ALLOW_DIRECT_CLIENT_CALLS_FOR_SYNC_MESSAGES : ALLOW_DIRECT_CLIENT_CALLS; - bool processed = - tasks_.empty() && ProcessIncomingMessage(message, client_call_behavior, - connector_.task_runner()); + MessageWrapper message_wrapper(this, std::move(*message)); + bool processed = tasks_.empty() && ProcessIncomingMessage( + &message_wrapper, client_call_behavior, + connector_.task_runner()); if (!processed) { // Either the task queue is not empty or we cannot process the message // directly. In both cases, there is no need to call ProcessTasks(). - tasks_.push_back( - Task::CreateMessageTask(MessageWrapper(this, std::move(*message)))); + tasks_.push_back(Task::CreateMessageTask(std::move(message_wrapper))); Task* task = tasks_.back().get(); if (task->message_wrapper.value().has_flag(Message::kFlagIsSync)) { @@ -636,8 +618,6 @@ bool MultiplexRouter::Accept(Message* message) { bool MultiplexRouter::OnPeerAssociatedEndpointClosed( InterfaceId id, const base::Optional<DisconnectReason>& reason) { - DCHECK(!IsMasterInterfaceId(id) || reason); - MayAutoLock locker(&lock_); InterfaceEndpoint* endpoint = FindOrInsertEndpoint(id, nullptr); @@ -662,23 +642,26 @@ bool MultiplexRouter::OnPeerAssociatedEndpointClosed( } void MultiplexRouter::OnPipeConnectionError() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); scoped_refptr<MultiplexRouter> protector(this); MayAutoLock locker(&lock_); encountered_error_ = true; - for (auto iter = endpoints_.begin(); iter != endpoints_.end();) { - InterfaceEndpoint* endpoint = iter->second.get(); - // Increment the iterator before calling UpdateEndpointStateMayRemove() - // because it may remove the corresponding value from the map. - ++iter; + // Calling UpdateEndpointStateMayRemove() may remove the corresponding value + // from |endpoints_| and invalidate any iterator of |endpoints_|. Therefore, + // copy the endpoint pointers to a vector and iterate over it instead. + std::vector<scoped_refptr<InterfaceEndpoint>> endpoint_vector; + endpoint_vector.reserve(endpoints_.size()); + for (const auto& pair : endpoints_) + endpoint_vector.push_back(pair.second); + for (const auto& endpoint : endpoint_vector) { if (endpoint->client()) - tasks_.push_back(Task::CreateNotifyErrorTask(endpoint)); + tasks_.push_back(Task::CreateNotifyErrorTask(endpoint.get())); - UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); + UpdateEndpointStateMayRemove(endpoint.get(), PEER_ENDPOINT_CLOSED); } ProcessTasks(connector_.during_sync_handle_watcher_callback() @@ -689,7 +672,7 @@ void MultiplexRouter::OnPipeConnectionError() { void MultiplexRouter::ProcessTasks( ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner) { + base::SequencedTaskRunner* current_task_runner) { AssertLockAcquired(); if (posted_to_process_tasks_) @@ -714,7 +697,7 @@ void MultiplexRouter::ProcessTasks( task->IsNotifyErrorTask() ? ProcessNotifyErrorTask(task.get(), client_call_behavior, current_task_runner) - : ProcessIncomingMessage(&task->message_wrapper.value(), + : ProcessIncomingMessage(&task->message_wrapper, client_call_behavior, current_task_runner); if (!processed) { @@ -752,8 +735,7 @@ bool MultiplexRouter::ProcessFirstSyncMessageForEndpoint(InterfaceId id) { // Note: after this call, |task| and |iter| may be invalidated. bool processed = ProcessIncomingMessage( - &message_wrapper.value(), ALLOW_DIRECT_CLIENT_CALLS_FOR_SYNC_MESSAGES, - nullptr); + &message_wrapper, ALLOW_DIRECT_CLIENT_CALLS_FOR_SYNC_MESSAGES, nullptr); DCHECK(processed); iter = sync_message_tasks_.find(id); @@ -771,8 +753,9 @@ bool MultiplexRouter::ProcessFirstSyncMessageForEndpoint(InterfaceId id) { bool MultiplexRouter::ProcessNotifyErrorTask( Task* task, ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner) { - DCHECK(!current_task_runner || current_task_runner->BelongsToCurrentThread()); + base::SequencedTaskRunner* current_task_runner) { + DCHECK(!current_task_runner || + current_task_runner->RunsTasksInCurrentSequence()); DCHECK(!paused_); AssertLockAcquired(); @@ -786,7 +769,7 @@ bool MultiplexRouter::ProcessNotifyErrorTask( return false; } - DCHECK(endpoint->task_runner()->BelongsToCurrentThread()); + DCHECK(endpoint->task_runner()->RunsTasksInCurrentSequence()); InterfaceEndpointClient* client = endpoint->client(); base::Optional<DisconnectReason> disconnect_reason( @@ -797,7 +780,7 @@ bool MultiplexRouter::ProcessNotifyErrorTask( // object within NotifyError(). Holding the lock will lead to deadlock. // // It is safe to call into |client| without the lock. Because |client| is - // always accessed on the same thread, including DetachEndpointClient(). + // always accessed on the same sequence, including DetachEndpointClient(). MayAutoUnlock unlocker(&lock_); client->NotifyError(disconnect_reason); } @@ -805,14 +788,16 @@ bool MultiplexRouter::ProcessNotifyErrorTask( } bool MultiplexRouter::ProcessIncomingMessage( - Message* message, + MessageWrapper* message_wrapper, ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner) { - DCHECK(!current_task_runner || current_task_runner->BelongsToCurrentThread()); + base::SequencedTaskRunner* current_task_runner) { + DCHECK(!current_task_runner || + current_task_runner->RunsTasksInCurrentSequence()); DCHECK(!paused_); - DCHECK(message); + DCHECK(message_wrapper); AssertLockAcquired(); + const Message* message = &message_wrapper->value(); if (message->IsNull()) { // This is a sync message and has been processed during sync handle // watching. @@ -824,7 +809,10 @@ bool MultiplexRouter::ProcessIncomingMessage( { MayAutoUnlock unlocker(&lock_); - result = control_message_handler_.Accept(message); + Message tmp_message = + message_wrapper->DeserializeEndpointHandlesAndTake(); + result = !tmp_message.IsNull() && + control_message_handler_.Accept(&tmp_message); } if (!result) @@ -849,7 +837,7 @@ bool MultiplexRouter::ProcessIncomingMessage( bool can_direct_call; if (message->has_flag(Message::kFlagIsSync)) { can_direct_call = client_call_behavior != NO_DIRECT_CLIENT_CALLS && - endpoint->task_runner()->BelongsToCurrentThread(); + endpoint->task_runner()->RunsTasksInCurrentSequence(); } else { can_direct_call = client_call_behavior == ALLOW_DIRECT_CLIENT_CALLS && endpoint->task_runner() == current_task_runner; @@ -860,7 +848,7 @@ bool MultiplexRouter::ProcessIncomingMessage( return false; } - DCHECK(endpoint->task_runner()->BelongsToCurrentThread()); + DCHECK(endpoint->task_runner()->RunsTasksInCurrentSequence()); InterfaceEndpointClient* client = endpoint->client(); bool result = false; @@ -870,9 +858,11 @@ bool MultiplexRouter::ProcessIncomingMessage( // deadlock. // // It is safe to call into |client| without the lock. Because |client| is - // always accessed on the same thread, including DetachEndpointClient(). + // always accessed on the same sequence, including DetachEndpointClient(). MayAutoUnlock unlocker(&lock_); - result = client->HandleIncomingMessage(message); + Message tmp_message = message_wrapper->DeserializeEndpointHandlesAndTake(); + result = + !tmp_message.IsNull() && client->HandleIncomingMessage(&tmp_message); } if (!result) RaiseErrorInNonTestingMode(); @@ -881,7 +871,7 @@ bool MultiplexRouter::ProcessIncomingMessage( } void MultiplexRouter::MaybePostToProcessTasks( - base::SingleThreadTaskRunner* task_runner) { + base::SequencedTaskRunner* task_runner) { AssertLockAcquired(); if (posted_to_process_tasks_) return; @@ -897,7 +887,7 @@ void MultiplexRouter::LockAndCallProcessTasks() { // always called using base::Bind(), which holds a ref. MayAutoLock locker(&lock_); posted_to_process_tasks_ = false; - scoped_refptr<base::SingleThreadTaskRunner> runner( + scoped_refptr<base::SequencedTaskRunner> runner( std::move(posted_to_task_runner_)); ProcessTasks(ALLOW_DIRECT_CLIENT_CALLS, runner.get()); } @@ -956,5 +946,67 @@ void MultiplexRouter::AssertLockAcquired() { #endif } +bool MultiplexRouter::InsertEndpointsForMessage(const Message& message) { + if (!message.is_serialized()) + return true; + + uint32_t num_ids = message.payload_num_interface_ids(); + if (num_ids == 0) + return true; + + const uint32_t* ids = message.payload_interface_ids(); + + MayAutoLock locker(&lock_); + for (uint32_t i = 0; i < num_ids; ++i) { + // Message header validation already ensures that the IDs are valid and not + // the master ID. + // The IDs are from the remote side and therefore their namespace bit is + // supposed to be different than the value that this router would use. + if (set_interface_id_namespace_bit_ == + HasInterfaceIdNamespaceBitSet(ids[i])) { + return false; + } + + // It is possible that the endpoint already exists even when the remote side + // is well-behaved: it might have notified us that the peer endpoint has + // closed. + bool inserted = false; + InterfaceEndpoint* endpoint = FindOrInsertEndpoint(ids[i], &inserted); + if (endpoint->closed() || endpoint->handle_created()) + return false; + } + + return true; +} + +void MultiplexRouter::CloseEndpointsForMessage(const Message& message) { + AssertLockAcquired(); + + if (!message.is_serialized()) + return; + + uint32_t num_ids = message.payload_num_interface_ids(); + if (num_ids == 0) + return; + + const uint32_t* ids = message.payload_interface_ids(); + for (uint32_t i = 0; i < num_ids; ++i) { + InterfaceEndpoint* endpoint = FindEndpoint(ids[i]); + // If the remote side maliciously sends the same interface ID in another + // message which has been dispatched, we could get here with no endpoint + // for the ID, a closed endpoint, or an endpoint with handle created. + if (!endpoint || endpoint->closed() || endpoint->handle_created()) { + RaiseErrorInNonTestingMode(); + continue; + } + + UpdateEndpointStateMayRemove(endpoint, ENDPOINT_CLOSED); + MayAutoUnlock unlocker(&lock_); + control_message_proxy_.NotifyPeerEndpointClosed(ids[i], base::nullopt); + } + + ProcessTasks(NO_DIRECT_CLIENT_CALLS, nullptr); +} + } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/multiplex_router.h b/mojo/public/cpp/bindings/lib/multiplex_router.h index cac138bcb7..8c2e7c8b0f 100644 --- a/mojo/public/cpp/bindings/lib/multiplex_router.h +++ b/mojo/public/cpp/bindings/lib/multiplex_router.h @@ -7,20 +7,21 @@ #include <stdint.h> -#include <deque> #include <map> #include <memory> #include <string> #include "base/compiler_specific.h" +#include "base/containers/queue.h" +#include "base/containers/small_map.h" #include "base/logging.h" #include "base/macros.h" #include "base/memory/ref_counted.h" #include "base/memory/weak_ptr.h" #include "base/optional.h" -#include "base/single_thread_task_runner.h" +#include "base/sequence_checker.h" +#include "base/sequenced_task_runner.h" #include "base/synchronization/lock.h" -#include "base/threading/thread_checker.h" #include "mojo/public/cpp/bindings/associated_group_controller.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/connector.h" @@ -33,7 +34,7 @@ #include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" namespace base { -class SingleThreadTaskRunner; +class SequencedTaskRunner; } namespace mojo { @@ -43,19 +44,19 @@ namespace internal { // MultiplexRouter supports routing messages for multiple interfaces over a // single message pipe. // -// It is created on the thread where the master interface of the message pipe +// It is created on the sequence where the master interface of the message pipe // lives. Although it is ref-counted, it is guarateed to be destructed on the -// same thread. -// Some public methods are only allowed to be called on the creating thread; -// while the others are safe to call from any threads. Please see the method +// same sequence. +// Some public methods are only allowed to be called on the creating sequence; +// while the others are safe to call from any sequence. Please see the method // comments for more details. // // NOTE: CloseMessagePipe() or PassMessagePipe() MUST be called on |runner|'s -// thread before this object is destroyed. +// sequence before this object is destroyed. class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter - : NON_EXPORTED_BASE(public MessageReceiver), + : public MessageReceiver, public AssociatedGroupController, - NON_EXPORTED_BASE(public PipeControlMessageHandlerDelegate) { + public PipeControlMessageHandlerDelegate { public: enum Config { // There is only the master interface running on this router. Please note @@ -76,7 +77,11 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter MultiplexRouter(ScopedMessagePipeHandle message_pipe, Config config, bool set_interface_id_namespace_bit, - scoped_refptr<base::SingleThreadTaskRunner> runner); + scoped_refptr<base::SequencedTaskRunner> runner); + + // Adds a MessageReceiver which can filter a message after validation but + // before dispatch. + void AddIncomingMessageFilter(std::unique_ptr<MessageReceiver> filter); // Sets the master interface name for this router. Only used when reporting // message header or control message validation errors. @@ -84,7 +89,7 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter void SetMasterInterfaceName(const char* name); // --------------------------------------------------------------------------- - // The following public methods are safe to call from any threads. + // The following public methods are safe to call from any sequence. // AssociatedGroupController implementation: InterfaceId AssociateInterface( @@ -97,13 +102,14 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter InterfaceEndpointController* AttachEndpointClient( const ScopedInterfaceEndpointHandle& handle, InterfaceEndpointClient* endpoint_client, - scoped_refptr<base::SingleThreadTaskRunner> runner) override; + scoped_refptr<base::SequencedTaskRunner> runner) override; void DetachEndpointClient( const ScopedInterfaceEndpointHandle& handle) override; void RaiseError() override; + bool PrefersSerializedMessages() override; // --------------------------------------------------------------------------- - // The following public methods are called on the creating thread. + // The following public methods are called on the creating sequence. // Please note that this method shouldn't be called unless it results from an // explicit request of the user of bindings (e.g., the user sets an @@ -112,14 +118,15 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter // Extracts the underlying message pipe. ScopedMessagePipeHandle PassMessagePipe() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(!HasAssociatedEndpoints()); return connector_.PassMessagePipe(); } - // Blocks the current thread until the first incoming message, or |deadline|. + // Blocks the current sequence until the first incoming message, or + // |deadline|. bool WaitForIncomingMessage(MojoDeadline deadline) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return connector_.WaitForIncomingMessage(deadline); } @@ -137,13 +144,13 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter // Is the router bound to a message pipe handle? bool is_valid() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return connector_.is_valid(); } // TODO(yzshen): consider removing this getter. MessagePipeHandle handle() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return connector_.handle(); } @@ -169,7 +176,7 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter void OnPipeConnectionError(); // Specifies whether we are allowed to directly call into - // InterfaceEndpointClient (given that we are already on the same thread as + // InterfaceEndpointClient (given that we are already on the same sequence as // the client). enum ClientCallBehavior { // Don't call any InterfaceEndpointClient methods directly. @@ -191,7 +198,7 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter // of this object, if direct calls are allowed, the caller needs to hold on to // a ref outside of |lock_| before calling this method. void ProcessTasks(ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner); + base::SequencedTaskRunner* current_task_runner); // Processes the first queued sync message for the endpoint corresponding to // |id|; returns whether there are more sync messages for that endpoint in the @@ -202,16 +209,14 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter bool ProcessFirstSyncMessageForEndpoint(InterfaceId id); // Returns true to indicate that |task|/|message| has been processed. - bool ProcessNotifyErrorTask( - Task* task, - ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner); - bool ProcessIncomingMessage( - Message* message, - ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner); - - void MaybePostToProcessTasks(base::SingleThreadTaskRunner* task_runner); + bool ProcessNotifyErrorTask(Task* task, + ClientCallBehavior client_call_behavior, + base::SequencedTaskRunner* current_task_runner); + bool ProcessIncomingMessage(MessageWrapper* message_wrapper, + ClientCallBehavior client_call_behavior, + base::SequencedTaskRunner* current_task_runner); + + void MaybePostToProcessTasks(base::SequencedTaskRunner* task_runner); void LockAndCallProcessTasks(); // Updates the state of |endpoint|. If both the endpoint and its peer have @@ -226,21 +231,25 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter InterfaceEndpoint* FindOrInsertEndpoint(InterfaceId id, bool* inserted); InterfaceEndpoint* FindEndpoint(InterfaceId id); + // Returns false if some interface IDs are invalid or have been used. + bool InsertEndpointsForMessage(const Message& message); + void CloseEndpointsForMessage(const Message& message); + void AssertLockAcquired(); // Whether to set the namespace bit when generating interface IDs. Please see // comments of kInterfaceIdNamespaceMask. const bool set_interface_id_namespace_bit_; - scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; // Owned by |filters_| below. - MessageHeaderValidator* header_validator_; + MessageHeaderValidator* header_validator_ = nullptr; FilterChain filters_; Connector connector_; - base::ThreadChecker thread_checker_; + SEQUENCE_CHECKER(sequence_checker_); // Protects the following members. // Not set in Config::SINGLE_INTERFACE* mode. @@ -250,21 +259,24 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter // NOTE: It is unsafe to call into this object while holding |lock_|. PipeControlMessageProxy control_message_proxy_; - std::map<InterfaceId, scoped_refptr<InterfaceEndpoint>> endpoints_; - uint32_t next_interface_id_value_; + base::small_map<std::map<InterfaceId, scoped_refptr<InterfaceEndpoint>>, 1> + endpoints_; + uint32_t next_interface_id_value_ = 1; - std::deque<std::unique_ptr<Task>> tasks_; + base::circular_deque<std::unique_ptr<Task>> tasks_; // It refers to tasks in |tasks_| and doesn't own any of them. - std::map<InterfaceId, std::deque<Task*>> sync_message_tasks_; + std::map<InterfaceId, base::circular_deque<Task*>> sync_message_tasks_; + + bool posted_to_process_tasks_ = false; + scoped_refptr<base::SequencedTaskRunner> posted_to_task_runner_; - bool posted_to_process_tasks_; - scoped_refptr<base::SingleThreadTaskRunner> posted_to_task_runner_; + bool encountered_error_ = false; - bool encountered_error_; + bool paused_ = false; - bool paused_; + bool testing_mode_ = false; - bool testing_mode_; + bool being_destructed_ = false; DISALLOW_COPY_AND_ASSIGN(MultiplexRouter); }; diff --git a/mojo/public/cpp/bindings/lib/native_struct.cc b/mojo/public/cpp/bindings/lib/native_struct.cc deleted file mode 100644 index 7b1a1a6c59..0000000000 --- a/mojo/public/cpp/bindings/lib/native_struct.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "mojo/public/cpp/bindings/native_struct.h" - -#include "mojo/public/cpp/bindings/lib/hash_util.h" - -namespace mojo { - -// static -NativeStructPtr NativeStruct::New() { - return NativeStructPtr(base::in_place); -} - -NativeStruct::NativeStruct() {} - -NativeStruct::~NativeStruct() {} - -NativeStructPtr NativeStruct::Clone() const { - NativeStructPtr rv(New()); - rv->data = data; - return rv; -} - -bool NativeStruct::Equals(const NativeStruct& other) const { - return data == other.data; -} - -size_t NativeStruct::Hash(size_t seed) const { - return internal::Hash(seed, data); -} - -} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/native_struct_data.cc b/mojo/public/cpp/bindings/lib/native_struct_data.cc deleted file mode 100644 index 0e5d245692..0000000000 --- a/mojo/public/cpp/bindings/lib/native_struct_data.cc +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "mojo/public/cpp/bindings/lib/native_struct_data.h" - -#include "mojo/public/cpp/bindings/lib/buffer.h" -#include "mojo/public/cpp/bindings/lib/validation_context.h" - -namespace mojo { -namespace internal { - -// static -bool NativeStruct_Data::Validate(const void* data, - ValidationContext* validation_context) { - const ContainerValidateParams data_validate_params(0, false, nullptr); - return Array_Data<uint8_t>::Validate(data, validation_context, - &data_validate_params); -} - -} // namespace internal -} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/native_struct_data.h b/mojo/public/cpp/bindings/lib/native_struct_data.h deleted file mode 100644 index 1c7cd81c77..0000000000 --- a/mojo/public/cpp/bindings/lib/native_struct_data.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_DATA_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_DATA_H_ - -#include <vector> - -#include "mojo/public/cpp/bindings/bindings_export.h" -#include "mojo/public/cpp/bindings/lib/array_internal.h" -#include "mojo/public/cpp/system/handle.h" - -namespace mojo { -namespace internal { - -class ValidationContext; - -class MOJO_CPP_BINDINGS_EXPORT NativeStruct_Data { - public: - static bool Validate(const void* data, ValidationContext* validation_context); - - // Unlike normal structs, the memory layout is exactly the same as an array - // of uint8_t. - Array_Data<uint8_t> data; - - private: - NativeStruct_Data() = delete; - ~NativeStruct_Data() = delete; -}; - -static_assert(sizeof(Array_Data<uint8_t>) == sizeof(NativeStruct_Data), - "Mismatched NativeStruct_Data and Array_Data<uint8_t> size"); - -} // namespace internal -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_DATA_H_ diff --git a/mojo/public/cpp/bindings/lib/native_struct_serialization.cc b/mojo/public/cpp/bindings/lib/native_struct_serialization.cc index fa0dbf3803..283080089f 100644 --- a/mojo/public/cpp/bindings/lib/native_struct_serialization.cc +++ b/mojo/public/cpp/bindings/lib/native_struct_serialization.cc @@ -4,56 +4,120 @@ #include "mojo/public/cpp/bindings/lib/native_struct_serialization.h" +#include "ipc/ipc_message_attachment.h" +#include "ipc/ipc_message_attachment_set.h" +#include "ipc/native_handle_type_converters.h" #include "mojo/public/cpp/bindings/lib/serialization.h" +#include "mojo/public/cpp/bindings/lib/serialization_forward.h" namespace mojo { namespace internal { // static -size_t UnmappedNativeStructSerializerImpl::PrepareToSerialize( - const NativeStructPtr& input, +void UnmappedNativeStructSerializerImpl::Serialize( + const native::NativeStructPtr& input, + Buffer* buffer, + native::internal::NativeStruct_Data::BufferWriter* writer, SerializationContext* context) { if (!input) - return 0; - return internal::PrepareToSerialize<ArrayDataView<uint8_t>>(input->data, - context); + return; + + writer->Allocate(buffer); + + Array_Data<uint8_t>::BufferWriter data_writer; + const mojo::internal::ContainerValidateParams data_validate_params(0, false, + nullptr); + mojo::internal::Serialize<ArrayDataView<uint8_t>>( + input->data, buffer, &data_writer, &data_validate_params, context); + writer->data()->data.Set(data_writer.data()); + + mojo::internal::Array_Data<mojo::internal::Pointer< + native::internal::SerializedHandle_Data>>::BufferWriter handles_writer; + const mojo::internal::ContainerValidateParams handles_validate_params( + 0, false, nullptr); + mojo::internal::Serialize< + mojo::ArrayDataView<::mojo::native::SerializedHandleDataView>>( + input->handles, buffer, &handles_writer, &handles_validate_params, + context); + writer->data()->handles.Set(handles_writer.is_null() ? nullptr + : handles_writer.data()); } // static -void UnmappedNativeStructSerializerImpl::Serialize( - const NativeStructPtr& input, - Buffer* buffer, - NativeStruct_Data** output, +bool UnmappedNativeStructSerializerImpl::Deserialize( + native::internal::NativeStruct_Data* input, + native::NativeStructPtr* output, SerializationContext* context) { if (!input) { - *output = nullptr; - return; + output->reset(); + return true; } - Array_Data<uint8_t>* data = nullptr; - const ContainerValidateParams params(0, false, nullptr); - internal::Serialize<ArrayDataView<uint8_t>>(input->data, buffer, &data, - ¶ms, context); - *output = reinterpret_cast<NativeStruct_Data*>(data); + native::NativeStructDataView data_view(input, context); + return StructTraits<::mojo::native::NativeStructDataView, + native::NativeStructPtr>::Read(data_view, output); } // static -bool UnmappedNativeStructSerializerImpl::Deserialize( - NativeStruct_Data* input, - NativeStructPtr* output, +void UnmappedNativeStructSerializerImpl::SerializeMessageContents( + IPC::Message* message, + Buffer* buffer, + native::internal::NativeStruct_Data::BufferWriter* writer, SerializationContext* context) { - Array_Data<uint8_t>* data = reinterpret_cast<Array_Data<uint8_t>*>(input); + writer->Allocate(buffer); + + // Allocate a uint8 array, initialize its header, and copy the Pickle in. + Array_Data<uint8_t>::BufferWriter data_writer; + data_writer.Allocate(message->payload_size(), buffer); + memcpy(data_writer->storage(), message->payload(), message->payload_size()); + writer->data()->data.Set(data_writer.data()); + + if (message->attachment_set()->empty()) { + writer->data()->handles.Set(nullptr); + return; + } + + mojo::internal::Array_Data<mojo::internal::Pointer< + native::internal::SerializedHandle_Data>>::BufferWriter handles_writer; + auto* attachments = message->attachment_set(); + handles_writer.Allocate(attachments->size(), buffer); + for (unsigned i = 0; i < attachments->size(); ++i) { + native::internal::SerializedHandle_Data::BufferWriter handle_writer; + handle_writer.Allocate(buffer); + + auto attachment = attachments->GetAttachmentAt(i); + ScopedHandle handle = attachment->TakeMojoHandle(); + internal::Serializer<ScopedHandle, ScopedHandle>::Serialize( + handle, &handle_writer->the_handle, context); + handle_writer->type = static_cast<int32_t>( + mojo::ConvertTo<native::SerializedHandle::Type>(attachment->GetType())); + handles_writer.data()->at(i).Set(handle_writer.data()); + } + writer->data()->handles.Set(handles_writer.data()); +} + +// static +bool UnmappedNativeStructSerializerImpl::DeserializeMessageAttachments( + native::internal::NativeStruct_Data* data, + SerializationContext* context, + IPC::Message* message) { + if (data->handles.is_null()) + return true; - NativeStructPtr result(NativeStruct::New()); - if (!internal::Deserialize<ArrayDataView<uint8_t>>(data, &result->data, - context)) { - output = nullptr; - return false; + auto* handles_data = data->handles.Get(); + for (size_t i = 0; i < handles_data->size(); ++i) { + auto* handle_data = handles_data->at(i).Get(); + if (!handle_data) + return false; + ScopedHandle handle; + internal::Serializer<ScopedHandle, ScopedHandle>::Deserialize( + &handle_data->the_handle, &handle, context); + auto attachment = IPC::MessageAttachment::CreateFromMojoHandle( + std::move(handle), + mojo::ConvertTo<IPC::MessageAttachment::Type>( + static_cast<native::SerializedHandle::Type>(handle_data->type))); + message->attachment_set()->AddAttachment(std::move(attachment)); } - if (!result->data) - *output = nullptr; - else - result.Swap(output); return true; } diff --git a/mojo/public/cpp/bindings/lib/native_struct_serialization.h b/mojo/public/cpp/bindings/lib/native_struct_serialization.h index 457435b955..6aa4c3a4a8 100644 --- a/mojo/public/cpp/bindings/lib/native_struct_serialization.h +++ b/mojo/public/cpp/bindings/lib/native_struct_serialization.h @@ -12,59 +12,63 @@ #include "base/logging.h" #include "base/pickle.h" +#include "ipc/ipc_message.h" #include "ipc/ipc_param_traits.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/array_internal.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" -#include "mojo/public/cpp/bindings/lib/native_struct_data.h" #include "mojo/public/cpp/bindings/lib/serialization_forward.h" #include "mojo/public/cpp/bindings/lib/serialization_util.h" -#include "mojo/public/cpp/bindings/native_struct.h" -#include "mojo/public/cpp/bindings/native_struct_data_view.h" +#include "mojo/public/interfaces/bindings/native_struct.mojom.h" namespace mojo { namespace internal { +// Base class for the templated native struct serialization interface below, +// used to consolidated some shared logic and provide a basic +// Serialize/Deserialize for [Native] mojom structs which do not have a +// registered typemap in the current configuration (i.e. structs that are +// represented by a raw native::NativeStruct mojom struct in C++ bindings.) +struct MOJO_CPP_BINDINGS_EXPORT UnmappedNativeStructSerializerImpl { + static void Serialize( + const native::NativeStructPtr& input, + Buffer* buffer, + native::internal::NativeStruct_Data::BufferWriter* writer, + SerializationContext* context); + + static bool Deserialize(native::internal::NativeStruct_Data* input, + native::NativeStructPtr* output, + SerializationContext* context); + + static void SerializeMessageContents( + IPC::Message* message, + Buffer* buffer, + native::internal::NativeStruct_Data::BufferWriter* writer, + SerializationContext* context); + + static bool DeserializeMessageAttachments( + native::internal::NativeStruct_Data* data, + SerializationContext* context, + IPC::Message* message); +}; + template <typename MaybeConstUserType> struct NativeStructSerializerImpl { using UserType = typename std::remove_const<MaybeConstUserType>::type; using Traits = IPC::ParamTraits<UserType>; - static size_t PrepareToSerialize(MaybeConstUserType& value, - SerializationContext* context) { - base::PickleSizer sizer; - Traits::GetSize(&sizer, value); - return Align(sizer.payload_size() + sizeof(ArrayHeader)); - } - - static void Serialize(MaybeConstUserType& value, - Buffer* buffer, - NativeStruct_Data** out, - SerializationContext* context) { - base::Pickle pickle; - Traits::Write(&pickle, value); - -#if DCHECK_IS_ON() - base::PickleSizer sizer; - Traits::GetSize(&sizer, value); - DCHECK_EQ(sizer.payload_size(), pickle.payload_size()); -#endif - - size_t total_size = pickle.payload_size() + sizeof(ArrayHeader); - DCHECK_LT(total_size, std::numeric_limits<uint32_t>::max()); - - // Allocate a uint8 array, initialize its header, and copy the Pickle in. - ArrayHeader* header = - reinterpret_cast<ArrayHeader*>(buffer->Allocate(total_size)); - header->num_bytes = static_cast<uint32_t>(total_size); - header->num_elements = static_cast<uint32_t>(pickle.payload_size()); - memcpy(reinterpret_cast<char*>(header) + sizeof(ArrayHeader), - pickle.payload(), pickle.payload_size()); - - *out = reinterpret_cast<NativeStruct_Data*>(header); + static void Serialize( + MaybeConstUserType& value, + Buffer* buffer, + native::internal::NativeStruct_Data::BufferWriter* writer, + SerializationContext* context) { + IPC::Message message; + Traits::Write(&message, value); + UnmappedNativeStructSerializerImpl::SerializeMessageContents( + &message, buffer, writer, context); } - static bool Deserialize(NativeStruct_Data* data, + static bool Deserialize(native::internal::NativeStruct_Data* data, UserType* out, SerializationContext* context) { if (!data) @@ -82,7 +86,7 @@ struct NativeStructSerializerImpl { // Because ArrayHeader's num_bytes includes the length of the header and // Pickle's payload_size does not, we need to adjust the stored value // momentarily so Pickle can view the data. - ArrayHeader* header = reinterpret_cast<ArrayHeader*>(data); + ArrayHeader* header = reinterpret_cast<ArrayHeader*>(data->data.Get()); DCHECK_GE(header->num_bytes, sizeof(ArrayHeader)); header->num_bytes -= sizeof(ArrayHeader); @@ -90,10 +94,15 @@ struct NativeStructSerializerImpl { // Construct a view over the full Array_Data, including our hacked up // header. Pickle will infer from this that the header is 8 bytes long, // and the payload will contain all of the pickled bytes. - base::Pickle pickle_view(reinterpret_cast<const char*>(header), - header->num_bytes + sizeof(ArrayHeader)); - base::PickleIterator iter(pickle_view); - if (!Traits::Read(&pickle_view, &iter, out)) + IPC::Message message_view(reinterpret_cast<const char*>(header), + header->num_bytes + sizeof(ArrayHeader)); + base::PickleIterator iter(message_view); + if (!UnmappedNativeStructSerializerImpl::DeserializeMessageAttachments( + data, context, &message_view)) { + return false; + } + + if (!Traits::Read(&message_view, &iter, out)) return false; } @@ -104,28 +113,16 @@ struct NativeStructSerializerImpl { } }; -struct MOJO_CPP_BINDINGS_EXPORT UnmappedNativeStructSerializerImpl { - static size_t PrepareToSerialize(const NativeStructPtr& input, - SerializationContext* context); - static void Serialize(const NativeStructPtr& input, - Buffer* buffer, - NativeStruct_Data** output, - SerializationContext* context); - static bool Deserialize(NativeStruct_Data* input, - NativeStructPtr* output, - SerializationContext* context); -}; - template <> -struct NativeStructSerializerImpl<NativeStructPtr> +struct NativeStructSerializerImpl<native::NativeStructPtr> : public UnmappedNativeStructSerializerImpl {}; template <> -struct NativeStructSerializerImpl<const NativeStructPtr> +struct NativeStructSerializerImpl<const native::NativeStructPtr> : public UnmappedNativeStructSerializerImpl {}; template <typename MaybeConstUserType> -struct Serializer<NativeStructDataView, MaybeConstUserType> +struct Serializer<native::NativeStructDataView, MaybeConstUserType> : public NativeStructSerializerImpl<MaybeConstUserType> {}; } // namespace internal diff --git a/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc b/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc index d451c05a5f..d39b991e20 100644 --- a/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc +++ b/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc @@ -6,7 +6,6 @@ #include "base/logging.h" #include "mojo/public/cpp/bindings/interface_id.h" -#include "mojo/public/cpp/bindings/lib/message_builder.h" #include "mojo/public/cpp/bindings/lib/serialization.h" #include "mojo/public/cpp/bindings/lib/serialization_context.h" #include "mojo/public/cpp/bindings/lib/validation_context.h" diff --git a/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc b/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc index 1029c2c491..f218892db5 100644 --- a/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc +++ b/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc @@ -9,8 +9,8 @@ #include "base/logging.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/lib/message_builder.h" #include "mojo/public/cpp/bindings/lib/serialization.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/interfaces/bindings/pipe_control_messages.mojom.h" namespace mojo { @@ -18,21 +18,16 @@ namespace { Message ConstructRunOrClosePipeMessage( pipe_control::RunOrClosePipeInputPtr input_ptr) { - internal::SerializationContext context; - auto params_ptr = pipe_control::RunOrClosePipeMessageParams::New(); params_ptr->input = std::move(input_ptr); - size_t size = internal::PrepareToSerialize< - pipe_control::RunOrClosePipeMessageParamsDataView>(params_ptr, &context); - internal::MessageBuilder builder(pipe_control::kRunOrClosePipeMessageId, 0, - size, 0); - - pipe_control::internal::RunOrClosePipeMessageParams_Data* params = nullptr; + Message message(pipe_control::kRunOrClosePipeMessageId, 0, 0, 0, nullptr); + internal::SerializationContext context; + pipe_control::internal::RunOrClosePipeMessageParams_Data::BufferWriter params; internal::Serialize<pipe_control::RunOrClosePipeMessageParamsDataView>( - params_ptr, builder.buffer(), ¶ms, &context); - builder.message()->set_interface_id(kInvalidInterfaceId); - return std::move(*builder.message()); + params_ptr, message.payload_buffer(), ¶ms, &context); + message.set_interface_id(kInvalidInterfaceId); + return message; } } // namespace diff --git a/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc b/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc index c1345079a5..2e5559ce0d 100644 --- a/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc +++ b/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc @@ -7,6 +7,7 @@ #include "base/bind.h" #include "base/logging.h" #include "base/synchronization/lock.h" +#include "base/threading/sequenced_task_runner_handle.h" #include "mojo/public/cpp/bindings/associated_group_controller.h" #include "mojo/public/cpp/bindings/lib/may_auto_lock.h" @@ -14,7 +15,7 @@ namespace mojo { // ScopedInterfaceEndpointHandle::State ---------------------------------------- -// State could be called from multiple threads. +// State could be called from multiple sequences. class ScopedInterfaceEndpointHandle::State : public base::RefCountedThreadSafe<State> { public: @@ -51,7 +52,7 @@ class ScopedInterfaceEndpointHandle::State // Intentionally keep |group_controller_| unchanged. // That is because the callback created by // CreateGroupControllerGetter() could still be used after this point, - // potentially from another thread. We would like it to continue + // potentially from another sequence. We would like it to continue // returning the same group controller. // // Imagine there is a ThreadSafeForwarder A: @@ -103,7 +104,7 @@ class ScopedInterfaceEndpointHandle::State return; } - runner_ = base::ThreadTaskRunnerHandle::Get(); + runner_ = base::SequencedTaskRunnerHandle::Get(); if (!pending_association_) { runner_->PostTask( FROM_HERE, @@ -171,7 +172,7 @@ class ScopedInterfaceEndpointHandle::State DCHECK(!IsValidInterfaceId(id_)); } - // Called by the peer, maybe from a different thread. + // Called by the peer, maybe from a different sequence. void OnAssociated(InterfaceId id, scoped_refptr<AssociatedGroupController> group_controller) { AssociationEventCallback handler; @@ -179,7 +180,7 @@ class ScopedInterfaceEndpointHandle::State internal::MayAutoLock locker(&lock_); // There may be race between Close() of endpoint A and - // NotifyPeerAssociation() of endpoint A_peer on different threads. + // NotifyPeerAssociation() of endpoint A_peer on different sequences. // Therefore, it is possible that endpoint A has been closed but it // still gets OnAssociated() call from its peer. if (!pending_association_) @@ -191,7 +192,7 @@ class ScopedInterfaceEndpointHandle::State group_controller_ = std::move(group_controller); if (!association_event_handler_.is_null()) { - if (runner_->BelongsToCurrentThread()) { + if (runner_->RunsTasksInCurrentSequence()) { handler = std::move(association_event_handler_); runner_ = nullptr; } else { @@ -207,7 +208,7 @@ class ScopedInterfaceEndpointHandle::State std::move(handler).Run(ASSOCIATED); } - // Called by the peer, maybe from a different thread. + // Called by the peer, maybe from a different sequence. void OnPeerClosedBeforeAssociation( const base::Optional<DisconnectReason>& reason) { AssociationEventCallback handler; @@ -215,7 +216,7 @@ class ScopedInterfaceEndpointHandle::State internal::MayAutoLock locker(&lock_); // There may be race between Close()/NotifyPeerAssociation() of endpoint - // A and Close() of endpoint A_peer on different threads. + // A and Close() of endpoint A_peer on different sequences. // Therefore, it is possible that endpoint A is not in pending association // state but still gets OnPeerClosedBeforeAssociation() call from its // peer. @@ -227,7 +228,7 @@ class ScopedInterfaceEndpointHandle::State peer_state_ = nullptr; if (!association_event_handler_.is_null()) { - if (runner_->BelongsToCurrentThread()) { + if (runner_->RunsTasksInCurrentSequence()) { handler = std::move(association_event_handler_); runner_ = nullptr; } else { @@ -245,7 +246,7 @@ class ScopedInterfaceEndpointHandle::State } void RunAssociationEventHandler( - scoped_refptr<base::SingleThreadTaskRunner> posted_to_runner, + scoped_refptr<base::SequencedTaskRunner> posted_to_runner, AssociationEvent event) { AssociationEventCallback handler; @@ -271,7 +272,7 @@ class ScopedInterfaceEndpointHandle::State scoped_refptr<State> peer_state_; AssociationEventCallback association_event_handler_; - scoped_refptr<base::SingleThreadTaskRunner> runner_; + scoped_refptr<base::SequencedTaskRunner> runner_; InterfaceId id_ = kInvalidInterfaceId; scoped_refptr<AssociatedGroupController> group_controller_; @@ -373,7 +374,7 @@ void ScopedInterfaceEndpointHandle::ResetInternal( base::Callback<AssociatedGroupController*()> ScopedInterfaceEndpointHandle::CreateGroupControllerGetter() const { - // We allow this callback to be run on any thread. If this handle is created + // We allow this callback to be run on any sequence. If this handle is created // in non-pending state, we don't have a lock but it should still be safe // because the group controller never changes. return base::Bind(&State::group_controller, state_); diff --git a/mojo/public/cpp/bindings/lib/sequence_local_sync_event_watcher.cc b/mojo/public/cpp/bindings/lib/sequence_local_sync_event_watcher.cc new file mode 100644 index 0000000000..f4618ffbe8 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/sequence_local_sync_event_watcher.cc @@ -0,0 +1,286 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/sequence_local_sync_event_watcher.h" + +#include <map> +#include <memory> +#include <set> + +#include "base/bind.h" +#include "base/containers/flat_set.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "base/no_destructor.h" +#include "base/synchronization/lock.h" +#include "base/synchronization/waitable_event.h" +#include "base/threading/sequence_local_storage_slot.h" +#include "mojo/public/cpp/bindings/sync_event_watcher.h" + +namespace mojo { + +namespace { + +struct WatcherState; + +using WatcherStateMap = + std::map<const SequenceLocalSyncEventWatcher*, scoped_refptr<WatcherState>>; + +// Ref-counted watcher state which may outlive the watcher to which it pertains. +// This is necessary to store outside of the SequenceLocalSyncEventWatcher +// itself in order to support nested sync operations where an inner operation +// may destroy the watcher. +struct WatcherState : public base::RefCounted<WatcherState> { + WatcherState() = default; + + bool watcher_was_destroyed = false; + + private: + friend class base::RefCounted<WatcherState>; + + ~WatcherState() = default; + + DISALLOW_COPY_AND_ASSIGN(WatcherState); +}; + +} // namespace + +// Owns the WaitableEvent and SyncEventWatcher shared by all +// SequenceLocalSyncEventWatchers on a single sequence, and coordinates the +// multiplexing of those shared objects to support an arbitrary number of +// SequenceLocalSyncEventWatchers waiting and signaling potentially while +// nested. +class SequenceLocalSyncEventWatcher::SequenceLocalState { + public: + SequenceLocalState() + : event_(base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::NOT_SIGNALED), + event_watcher_(&event_, + base::BindRepeating(&SequenceLocalState::OnEventSignaled, + base::Unretained(this))), + weak_ptr_factory_(this) { + // We always allow this event handler to be awoken during any sync event on + // the sequence. Individual watchers still must opt into having such + // wake-ups propagated to them. + event_watcher_.AllowWokenUpBySyncWatchOnSameThread(); + } + + ~SequenceLocalState() {} + + // Initializes a SequenceLocalState instance in sequence-local storage if + // not already initialized. Returns a WeakPtr to the stored state object. + static base::WeakPtr<SequenceLocalState> GetOrCreate() { + auto& state_ptr = GetStorageSlot().Get(); + if (!state_ptr) + state_ptr = std::make_unique<SequenceLocalState>(); + return state_ptr->weak_ptr_factory_.GetWeakPtr(); + } + + // Registers a new watcher and returns an iterator into the WatcherStateMap to + // be used for fast access with other methods. + WatcherStateMap::iterator RegisterWatcher( + const SequenceLocalSyncEventWatcher* watcher) { + auto result = registered_watchers_.emplace( + watcher, base::MakeRefCounted<WatcherState>()); + DCHECK(result.second); + return result.first; + } + + void UnregisterWatcher(WatcherStateMap::iterator iter) { + if (top_watcher_ == iter->first) { + // If the watcher being unregistered is currently blocking in a + // |SyncWatch()| operation, we need to unblock it. Setting this flag does + // that. + top_watcher_state_->watcher_was_destroyed = true; + top_watcher_state_ = nullptr; + top_watcher_ = nullptr; + } + + { + base::AutoLock lock(ready_watchers_lock_); + ready_watchers_.erase(iter->first); + } + + registered_watchers_.erase(iter); + if (registered_watchers_.empty()) { + // If no more watchers are registered, clear our sequence-local storage. + // Deletes |this|. + GetStorageSlot().Get().reset(); + } + } + + void SignalForWatcher(const SequenceLocalSyncEventWatcher* watcher) { + bool must_signal = false; + { + base::AutoLock lock(ready_watchers_lock_); + must_signal = ready_watchers_.empty(); + ready_watchers_.insert(watcher); + } + + // If we didn't have any ready watchers before, the event may not have + // been signaled. Signal it to ensure that |OnEventSignaled()| is run. + if (must_signal) + event_.Signal(); + } + + void ResetForWatcher(const SequenceLocalSyncEventWatcher* watcher) { + base::AutoLock lock(ready_watchers_lock_); + ready_watchers_.erase(watcher); + + // No more watchers are ready, so we can reset the event. The next watcher + // to call |SignalForWatcher()| will re-signal the event. + if (ready_watchers_.empty()) + event_.Reset(); + } + + bool SyncWatch(const SequenceLocalSyncEventWatcher* watcher, + WatcherState* watcher_state, + const bool* should_stop) { + // |SyncWatch()| calls may nest arbitrarily deep on the same sequence. We + // preserve the outer watcher state on the stack and restore it once the + // innermost watch is complete. + const SequenceLocalSyncEventWatcher* outer_watcher = top_watcher_; + WatcherState* outer_watcher_state = top_watcher_state_; + + // Keep a ref on the stack so the state stays alive even if the watcher is + // destroyed. + scoped_refptr<WatcherState> top_watcher_state(watcher_state); + top_watcher_state_ = watcher_state; + top_watcher_ = watcher; + + // In addition to the caller's own stop condition, we need to interrupt the + // SyncEventWatcher if |watcher| is destroyed while we're waiting. + const bool* stop_flags[] = {should_stop, + &top_watcher_state_->watcher_was_destroyed}; + + // |SyncWatch()| may delete |this|. + auto weak_self = weak_ptr_factory_.GetWeakPtr(); + bool result = event_watcher_.SyncWatch(stop_flags, 2); + if (!weak_self) + return false; + + top_watcher_state_ = outer_watcher_state; + top_watcher_ = outer_watcher; + return result; + } + + private: + using StorageSlotType = + base::SequenceLocalStorageSlot<std::unique_ptr<SequenceLocalState>>; + static StorageSlotType& GetStorageSlot() { + static base::NoDestructor<StorageSlotType> storage; + return *storage; + } + + void OnEventSignaled(); + + // The shared event and watcher used for this sequence. + base::WaitableEvent event_; + mojo::SyncEventWatcher event_watcher_; + + // All SequenceLocalSyncEventWatchers on the current sequence have some state + // registered here. + WatcherStateMap registered_watchers_; + + // Tracks state of the top-most |SyncWatch()| invocation on the stack. + const SequenceLocalSyncEventWatcher* top_watcher_ = nullptr; + WatcherState* top_watcher_state_ = nullptr; + + // Set of all SequenceLocalSyncEventWatchers in a signaled state, guarded by + // a lock for sequence-safe signaling. + base::Lock ready_watchers_lock_; + base::flat_set<const SequenceLocalSyncEventWatcher*> ready_watchers_; + + base::WeakPtrFactory<SequenceLocalState> weak_ptr_factory_; + + DISALLOW_COPY_AND_ASSIGN(SequenceLocalState); +}; + +void SequenceLocalSyncEventWatcher::SequenceLocalState::OnEventSignaled() { + for (;;) { + base::flat_set<const SequenceLocalSyncEventWatcher*> ready_watchers; + { + base::AutoLock lock(ready_watchers_lock_); + std::swap(ready_watchers_, ready_watchers); + } + if (ready_watchers.empty()) + return; + + auto weak_self = weak_ptr_factory_.GetWeakPtr(); + for (auto* watcher : ready_watchers) { + if (top_watcher_ == watcher || watcher->can_wake_up_during_any_watch_) { + watcher->callback_.Run(); + + // The callback may have deleted |this|. + if (!weak_self) + return; + } + } + } +} + +// Manages a watcher's reference to the sequence-local state. This hides +// implementation details from the SequenceLocalSyncEventWatcher interface. +class SequenceLocalSyncEventWatcher::Registration { + public: + explicit Registration(const SequenceLocalSyncEventWatcher* watcher) + : weak_shared_state_(SequenceLocalState::GetOrCreate()), + shared_state_(weak_shared_state_.get()), + watcher_state_iterator_(shared_state_->RegisterWatcher(watcher)), + watcher_state_(watcher_state_iterator_->second) {} + + ~Registration() { + if (weak_shared_state_) { + // Because |this| may itself be owned by sequence- or thread-local storage + // (e.g. if an interface binding lives there) we have no guarantee that + // our SequenceLocalState's storage slot will still be alive during our + // own destruction; so we have to guard against any access to it. Note + // that this uncertainty only exists within the destructor and does not + // apply to other methods on SequenceLocalSyncEventWatcher. + // + // May delete |shared_state_|. + shared_state_->UnregisterWatcher(watcher_state_iterator_); + } + } + + SequenceLocalState* shared_state() const { return shared_state_; } + WatcherState* watcher_state() { return watcher_state_.get(); } + + private: + const base::WeakPtr<SequenceLocalState> weak_shared_state_; + SequenceLocalState* const shared_state_; + WatcherStateMap::iterator watcher_state_iterator_; + const scoped_refptr<WatcherState> watcher_state_; + + DISALLOW_COPY_AND_ASSIGN(Registration); +}; + +SequenceLocalSyncEventWatcher::SequenceLocalSyncEventWatcher( + const base::RepeatingClosure& callback) + : registration_(std::make_unique<Registration>(this)), + callback_(callback) {} + +SequenceLocalSyncEventWatcher::~SequenceLocalSyncEventWatcher() = default; + +void SequenceLocalSyncEventWatcher::SignalEvent() { + registration_->shared_state()->SignalForWatcher(this); +} + +void SequenceLocalSyncEventWatcher::ResetEvent() { + registration_->shared_state()->ResetForWatcher(this); +} + +void SequenceLocalSyncEventWatcher::AllowWokenUpBySyncWatchOnSameSequence() { + can_wake_up_during_any_watch_ = true; +} + +bool SequenceLocalSyncEventWatcher::SyncWatch(const bool* should_stop) { + // NOTE: |SyncWatch()| may delete |this|. + return registration_->shared_state()->SyncWatch( + this, registration_->watcher_state(), should_stop); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/serialization.h b/mojo/public/cpp/bindings/lib/serialization.h index 2a7d288d55..8ced91ea53 100644 --- a/mojo/public/cpp/bindings/lib/serialization.h +++ b/mojo/public/cpp/bindings/lib/serialization.h @@ -7,96 +7,122 @@ #include <string.h> -#include "mojo/public/cpp/bindings/array_traits_carray.h" +#include <type_traits> + +#include "base/numerics/safe_math.h" +#include "mojo/public/cpp/bindings/array_traits_span.h" #include "mojo/public/cpp/bindings/array_traits_stl.h" #include "mojo/public/cpp/bindings/lib/array_serialization.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" #include "mojo/public/cpp/bindings/lib/buffer.h" -#include "mojo/public/cpp/bindings/lib/handle_interface_serialization.h" +#include "mojo/public/cpp/bindings/lib/handle_serialization.h" #include "mojo/public/cpp/bindings/lib/map_serialization.h" -#include "mojo/public/cpp/bindings/lib/native_enum_serialization.h" -#include "mojo/public/cpp/bindings/lib/native_struct_serialization.h" #include "mojo/public/cpp/bindings/lib/string_serialization.h" #include "mojo/public/cpp/bindings/lib/template_util.h" +#include "mojo/public/cpp/bindings/map_traits_flat_map.h" #include "mojo/public/cpp/bindings/map_traits_stl.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/cpp/bindings/string_traits_stl.h" -#include "mojo/public/cpp/bindings/string_traits_string16.h" #include "mojo/public/cpp/bindings/string_traits_string_piece.h" namespace mojo { namespace internal { -template <typename MojomType, typename DataArrayType, typename UserType> -DataArrayType StructSerializeImpl(UserType* input) { - static_assert(BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value, - "Unexpected type."); +template <typename MojomType, typename EnableType = void> +struct MojomSerializationImplTraits; + +template <typename MojomType> +struct MojomSerializationImplTraits< + MojomType, + typename std::enable_if< + BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value>::type> { + template <typename MaybeConstUserType, typename WriterType> + static void Serialize(MaybeConstUserType& input, + Buffer* buffer, + WriterType* writer, + SerializationContext* context) { + mojo::internal::Serialize<MojomType>(input, buffer, writer, context); + } +}; + +template <typename MojomType> +struct MojomSerializationImplTraits< + MojomType, + typename std::enable_if< + BelongsTo<MojomType, MojomTypeCategory::UNION>::value>::type> { + template <typename MaybeConstUserType, typename WriterType> + static void Serialize(MaybeConstUserType& input, + Buffer* buffer, + WriterType* writer, + SerializationContext* context) { + mojo::internal::Serialize<MojomType>(input, buffer, writer, + false /* inline */, context); + } +}; +template <typename MojomType, typename UserType> +mojo::Message SerializeAsMessageImpl(UserType* input) { SerializationContext context; - size_t size = PrepareToSerialize<MojomType>(*input, &context); - DCHECK_EQ(size, Align(size)); + mojo::Message message(0, 0, 0, 0, nullptr); + typename MojomTypeTraits<MojomType>::Data::BufferWriter writer; + MojomSerializationImplTraits<MojomType>::Serialize( + *input, message.payload_buffer(), &writer, &context); + message.AttachHandlesFromSerializationContext(&context); + return message; +} +template <typename MojomType, typename DataArrayType, typename UserType> +DataArrayType SerializeImpl(UserType* input) { + static_assert(BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value || + BelongsTo<MojomType, MojomTypeCategory::UNION>::value, + "Unexpected type."); + Message message = SerializeAsMessageImpl<MojomType>(input); + uint32_t size = message.payload_num_bytes(); DataArrayType result(size); - if (size == 0) - return result; - - void* result_buffer = &result.front(); - // The serialization logic requires that the buffer is 8-byte aligned. If the - // result buffer is not properly aligned, we have to do an extra copy. In - // practice, this should never happen for std::vector. - bool need_copy = !IsAligned(result_buffer); - - if (need_copy) { - // calloc sets the memory to all zero. - result_buffer = calloc(size, 1); - DCHECK(IsAligned(result_buffer)); - } - - Buffer buffer; - buffer.Initialize(result_buffer, size); - typename MojomTypeTraits<MojomType>::Data* data = nullptr; - Serialize<MojomType>(*input, &buffer, &data, &context); - - if (need_copy) { - memcpy(&result.front(), result_buffer, size); - free(result_buffer); - } - + if (size) + memcpy(&result.front(), message.payload(), size); return result; } -template <typename MojomType, typename DataArrayType, typename UserType> -bool StructDeserializeImpl(const DataArrayType& input, - UserType* output, - bool (*validate_func)(const void*, - ValidationContext*)) { - static_assert(BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value, +template <typename MojomType, typename UserType> +bool DeserializeImpl(const void* data, + size_t data_num_bytes, + std::vector<mojo::ScopedHandle> handles, + UserType* output, + bool (*validate_func)(const void*, ValidationContext*)) { + static_assert(BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value || + BelongsTo<MojomType, MojomTypeCategory::UNION>::value, "Unexpected type."); using DataType = typename MojomTypeTraits<MojomType>::Data; - // TODO(sammc): Use DataArrayType::empty() once WTF::Vector::empty() exists. - void* input_buffer = - input.size() == 0 - ? nullptr - : const_cast<void*>(reinterpret_cast<const void*>(&input.front())); + const void* input_buffer = data_num_bytes == 0 ? nullptr : data; + void* aligned_input_buffer = nullptr; - // Please see comments in StructSerializeImpl. + // Validation code will insist that the input buffer is aligned, so we ensure + // that here. If the input data is not aligned, we (sadly) copy into an + // aligned buffer. In practice this should happen only rarely if ever. bool need_copy = !IsAligned(input_buffer); - if (need_copy) { - input_buffer = malloc(input.size()); - DCHECK(IsAligned(input_buffer)); - memcpy(input_buffer, &input.front(), input.size()); + aligned_input_buffer = malloc(data_num_bytes); + DCHECK(IsAligned(aligned_input_buffer)); + memcpy(aligned_input_buffer, data, data_num_bytes); + input_buffer = aligned_input_buffer; } - ValidationContext validation_context(input_buffer, input.size(), 0, 0); + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(data_num_bytes)); + ValidationContext validation_context( + input_buffer, static_cast<uint32_t>(data_num_bytes), handles.size(), 0); bool result = false; if (validate_func(input_buffer, &validation_context)) { - auto data = reinterpret_cast<DataType*>(input_buffer); SerializationContext context; - result = Deserialize<MojomType>(data, output, &context); + *context.mutable_handles() = std::move(handles); + result = Deserialize<MojomType>( + reinterpret_cast<DataType*>(const_cast<void*>(input_buffer)), output, + &context); } - if (need_copy) - free(input_buffer); + if (aligned_input_buffer) + free(aligned_input_buffer); return result; } diff --git a/mojo/public/cpp/bindings/lib/serialization_context.cc b/mojo/public/cpp/bindings/lib/serialization_context.cc index e2fd5c6e18..267b54154b 100644 --- a/mojo/public/cpp/bindings/lib/serialization_context.cc +++ b/mojo/public/cpp/bindings/lib/serialization_context.cc @@ -7,50 +7,78 @@ #include <limits> #include "base/logging.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/cpp/system/core.h" namespace mojo { namespace internal { -SerializedHandleVector::SerializedHandleVector() {} +SerializationContext::SerializationContext() = default; -SerializedHandleVector::~SerializedHandleVector() { - for (auto handle : handles_) { - if (handle.is_valid()) { - MojoResult rv = MojoClose(handle.value()); - DCHECK_EQ(rv, MOJO_RESULT_OK); - } +SerializationContext::~SerializationContext() = default; + +void SerializationContext::AddHandle(mojo::ScopedHandle handle, + Handle_Data* out_data) { + if (!handle.is_valid()) { + out_data->value = kEncodedInvalidHandleValue; + } else { + DCHECK_LT(handles_.size(), std::numeric_limits<uint32_t>::max()); + out_data->value = static_cast<uint32_t>(handles_.size()); + handles_.emplace_back(std::move(handle)); } } -Handle_Data SerializedHandleVector::AddHandle(mojo::Handle handle) { - Handle_Data data; +void SerializationContext::AddInterfaceInfo( + mojo::ScopedMessagePipeHandle handle, + uint32_t version, + Interface_Data* out_data) { + AddHandle(ScopedHandle::From(std::move(handle)), &out_data->handle); + out_data->version = version; +} + +void SerializationContext::AddAssociatedEndpoint( + ScopedInterfaceEndpointHandle handle, + AssociatedEndpointHandle_Data* out_data) { if (!handle.is_valid()) { - data.value = kEncodedInvalidHandleValue; + out_data->value = kEncodedInvalidHandleValue; } else { - DCHECK_LT(handles_.size(), std::numeric_limits<uint32_t>::max()); - data.value = static_cast<uint32_t>(handles_.size()); - handles_.push_back(handle); + DCHECK_LT(associated_endpoint_handles_.size(), + std::numeric_limits<uint32_t>::max()); + out_data->value = + static_cast<uint32_t>(associated_endpoint_handles_.size()); + associated_endpoint_handles_.emplace_back(std::move(handle)); } - return data; } -mojo::Handle SerializedHandleVector::TakeHandle( - const Handle_Data& encoded_handle) { - if (!encoded_handle.is_valid()) - return mojo::Handle(); - DCHECK_LT(encoded_handle.value, handles_.size()); - return FetchAndReset(&handles_[encoded_handle.value]); +void SerializationContext::AddAssociatedInterfaceInfo( + ScopedInterfaceEndpointHandle handle, + uint32_t version, + AssociatedInterface_Data* out_data) { + AddAssociatedEndpoint(std::move(handle), &out_data->handle); + out_data->version = version; } -void SerializedHandleVector::Swap(std::vector<mojo::Handle>* other) { - handles_.swap(*other); +void SerializationContext::TakeHandlesFromMessage(Message* message) { + handles_.swap(*message->mutable_handles()); + associated_endpoint_handles_.swap( + *message->mutable_associated_endpoint_handles()); } -SerializationContext::SerializationContext() {} +mojo::ScopedHandle SerializationContext::TakeHandle( + const Handle_Data& encoded_handle) { + if (!encoded_handle.is_valid()) + return mojo::ScopedHandle(); + DCHECK_LT(encoded_handle.value, handles_.size()); + return std::move(handles_[encoded_handle.value]); +} -SerializationContext::~SerializationContext() { - DCHECK(!custom_contexts || custom_contexts->empty()); +mojo::ScopedInterfaceEndpointHandle +SerializationContext::TakeAssociatedEndpointHandle( + const AssociatedEndpointHandle_Data& encoded_handle) { + if (!encoded_handle.is_valid()) + return mojo::ScopedInterfaceEndpointHandle(); + DCHECK_LT(encoded_handle.value, associated_endpoint_handles_.size()); + return std::move(associated_endpoint_handles_[encoded_handle.value]); } } // namespace internal diff --git a/mojo/public/cpp/bindings/lib/serialization_context.h b/mojo/public/cpp/bindings/lib/serialization_context.h index a34fe3d4ed..0e3c0788dc 100644 --- a/mojo/public/cpp/bindings/lib/serialization_context.h +++ b/mojo/public/cpp/bindings/lib/serialization_context.h @@ -8,67 +8,88 @@ #include <stddef.h> #include <memory> -#include <queue> #include <vector> +#include "base/component_export.h" +#include "base/containers/stack_container.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" #include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" #include "mojo/public/cpp/system/handle.h" namespace mojo { + +class Message; + namespace internal { -// A container for handles during serialization/deserialization. -class MOJO_CPP_BINDINGS_EXPORT SerializedHandleVector { +// Context information for serialization/deserialization routines. +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) SerializationContext { public: - SerializedHandleVector(); - ~SerializedHandleVector(); + SerializationContext(); + ~SerializationContext(); - size_t size() const { return handles_.size(); } + // Adds a handle to the handle list and outputs its serialized form in + // |*out_data|. + void AddHandle(mojo::ScopedHandle handle, Handle_Data* out_data); + + // Adds an interface info to the handle list and outputs its serialized form + // in |*out_data|. + void AddInterfaceInfo(mojo::ScopedMessagePipeHandle handle, + uint32_t version, + Interface_Data* out_data); + + // Adds an associated interface endpoint (for e.g. an + // AssociatedInterfaceRequest) to this context and outputs its serialized form + // in |*out_data|. + void AddAssociatedEndpoint(ScopedInterfaceEndpointHandle handle, + AssociatedEndpointHandle_Data* out_data); + + // Adds an associated interface info to associated endpoint handle and version + // data lists and outputs its serialized form in |*out_data|. + void AddAssociatedInterfaceInfo(ScopedInterfaceEndpointHandle handle, + uint32_t version, + AssociatedInterface_Data* out_data); + + const std::vector<mojo::ScopedHandle>* handles() { return &handles_; } + std::vector<mojo::ScopedHandle>* mutable_handles() { return &handles_; } + + const std::vector<ScopedInterfaceEndpointHandle>* + associated_endpoint_handles() const { + return &associated_endpoint_handles_; + } + std::vector<ScopedInterfaceEndpointHandle>* + mutable_associated_endpoint_handles() { + return &associated_endpoint_handles_; + } - // Adds a handle to the handle list and returns its index for encoding. - Handle_Data AddHandle(mojo::Handle handle); + // Takes handles from a received Message object and assumes ownership of them. + // Individual handles can be extracted using Take* methods below. + void TakeHandlesFromMessage(Message* message); // Takes a handle from the list of serialized handle data. - mojo::Handle TakeHandle(const Handle_Data& encoded_handle); + mojo::ScopedHandle TakeHandle(const Handle_Data& encoded_handle); // Takes a handle from the list of serialized handle data and returns it in // |*out_handle| as a specific scoped handle type. template <typename T> ScopedHandleBase<T> TakeHandleAs(const Handle_Data& encoded_handle) { - return MakeScopedHandle(T(TakeHandle(encoded_handle).value())); + return ScopedHandleBase<T>::From(TakeHandle(encoded_handle)); } - // Swaps all owned handles out with another Handle vector. - void Swap(std::vector<mojo::Handle>* other); + mojo::ScopedInterfaceEndpointHandle TakeAssociatedEndpointHandle( + const AssociatedEndpointHandle_Data& encoded_handle); private: - // Handles are owned by this object. - std::vector<mojo::Handle> handles_; - - DISALLOW_COPY_AND_ASSIGN(SerializedHandleVector); -}; - -// Context information for serialization/deserialization routines. -struct MOJO_CPP_BINDINGS_EXPORT SerializationContext { - SerializationContext(); - - ~SerializationContext(); - - // Opaque context pointers returned by StringTraits::SetUpContext(). - std::unique_ptr<std::queue<void*>> custom_contexts; - - // Stashes handles encoded in a message by index. - SerializedHandleVector handles; - - // The number of ScopedInterfaceEndpointHandles that need to be serialized. - // It is calculated by PrepareToSerialize(). - uint32_t associated_endpoint_count = 0; + // Handles owned by this object. Used during serialization to hold onto + // handles accumulated during pre-serialization, and used during + // deserialization to hold onto handles extracted from a message. + std::vector<mojo::ScopedHandle> handles_; // Stashes ScopedInterfaceEndpointHandles encoded in a message by index. - std::vector<ScopedInterfaceEndpointHandle> associated_endpoint_handles; + std::vector<ScopedInterfaceEndpointHandle> associated_endpoint_handles_; + + DISALLOW_COPY_AND_ASSIGN(SerializationContext); }; } // namespace internal diff --git a/mojo/public/cpp/bindings/lib/serialization_forward.h b/mojo/public/cpp/bindings/lib/serialization_forward.h index 55c9982ccc..562951ee4a 100644 --- a/mojo/public/cpp/bindings/lib/serialization_forward.h +++ b/mojo/public/cpp/bindings/lib/serialization_forward.h @@ -33,22 +33,6 @@ struct IsOptionalWrapper { typename std::remove_reference<T>::type>::type>::value; }; -// PrepareToSerialize() must be matched by a Serialize() for the same input -// later. Moreover, within the same SerializationContext if PrepareToSerialize() -// is called for |input_1|, ..., |input_n|, Serialize() must be called for -// those objects in the exact same order. -template <typename MojomType, - typename InputUserType, - typename... Args, - typename std::enable_if< - !IsOptionalWrapper<InputUserType>::value>::type* = nullptr> -size_t PrepareToSerialize(InputUserType&& input, Args&&... args) { - return Serializer<MojomType, - typename std::remove_reference<InputUserType>::type>:: - PrepareToSerialize(std::forward<InputUserType>(input), - std::forward<Args>(args)...); -} - template <typename MojomType, typename InputUserType, typename... Args, @@ -71,33 +55,19 @@ bool Deserialize(DataType&& input, InputUserType* output, Args&&... args) { std::forward<DataType>(input), output, std::forward<Args>(args)...); } -// Specialization that unwraps base::Optional<>. template <typename MojomType, typename InputUserType, - typename... Args, - typename std::enable_if< - IsOptionalWrapper<InputUserType>::value>::type* = nullptr> -size_t PrepareToSerialize(InputUserType&& input, Args&&... args) { - if (!input) - return 0; - return PrepareToSerialize<MojomType>(*input, std::forward<Args>(args)...); -} - -template <typename MojomType, - typename InputUserType, - typename DataType, + typename BufferWriterType, typename... Args, typename std::enable_if< IsOptionalWrapper<InputUserType>::value>::type* = nullptr> void Serialize(InputUserType&& input, Buffer* buffer, - DataType** output, + BufferWriterType* writer, Args&&... args) { - if (!input) { - *output = nullptr; + if (!input) return; - } - Serialize<MojomType>(*input, buffer, output, std::forward<Args>(args)...); + Serialize<MojomType>(*input, buffer, writer, std::forward<Args>(args)...); } template <typename MojomType, diff --git a/mojo/public/cpp/bindings/lib/serialization_util.h b/mojo/public/cpp/bindings/lib/serialization_util.h index 4820a014ec..a7a99b3bb7 100644 --- a/mojo/public/cpp/bindings/lib/serialization_util.h +++ b/mojo/public/cpp/bindings/lib/serialization_util.h @@ -21,7 +21,7 @@ namespace internal { template <typename T> struct HasIsNullMethod { template <typename U> - static char Test(decltype(U::IsNull)*); + static char Test(decltype(U::IsNull) *); template <typename U> static int Test(...); static const bool value = sizeof(Test<T>(0)) == sizeof(char); @@ -48,7 +48,7 @@ bool CallIsNullIfExists(const UserType& input) { template <typename T> struct HasSetToNullMethod { template <typename U> - static char Test(decltype(U::SetToNull)*); + static char Test(decltype(U::SetToNull) *); template <typename U> static int Test(...); static const bool value = sizeof(Test<T>(0)) == sizeof(char); @@ -80,7 +80,7 @@ bool CallSetToNullIfExists(UserType* output) { template <typename T> struct HasSetUpContextMethod { template <typename U> - static char Test(decltype(U::SetUpContext)*); + static char Test(decltype(U::SetUpContext) *); template <typename U> static int Test(...); static const bool value = sizeof(Test<T>(0)) == sizeof(char); @@ -97,17 +97,7 @@ template <typename Traits> struct CustomContextHelper<Traits, true> { template <typename MaybeConstUserType> static void* SetUp(MaybeConstUserType& input, SerializationContext* context) { - void* custom_context = Traits::SetUpContext(input); - if (!context->custom_contexts) - context->custom_contexts.reset(new std::queue<void*>()); - context->custom_contexts->push(custom_context); - return custom_context; - } - - static void* GetNext(SerializationContext* context) { - void* custom_context = context->custom_contexts->front(); - context->custom_contexts->pop(); - return custom_context; + return Traits::SetUpContext(input); } template <typename MaybeConstUserType> @@ -123,8 +113,6 @@ struct CustomContextHelper<Traits, false> { return nullptr; } - static void* GetNext(SerializationContext* context) { return nullptr; } - template <typename MaybeConstUserType> static void TearDown(MaybeConstUserType& input, void* custom_context) { DCHECK(!custom_context); @@ -148,7 +136,8 @@ ReturnType CallWithContext(ReturnType (*f)(ParamType), template <typename T, typename MaybeConstUserType> struct HasGetBeginMethod { template <typename U> - static char Test(decltype(U::GetBegin(std::declval<MaybeConstUserType&>()))*); + static char Test( + decltype(U::GetBegin(std::declval<MaybeConstUserType&>())) *); template <typename U> static int Test(...); static const bool value = sizeof(Test<T>(0)) == sizeof(char); @@ -179,7 +168,7 @@ size_t CallGetBeginIfExists(MaybeConstUserType& input) { template <typename T, typename MaybeConstUserType> struct HasGetDataMethod { template <typename U> - static char Test(decltype(U::GetData(std::declval<MaybeConstUserType&>()))*); + static char Test(decltype(U::GetData(std::declval<MaybeConstUserType&>())) *); template <typename U> static int Test(...); static const bool value = sizeof(Test<T>(0)) == sizeof(char); diff --git a/mojo/public/cpp/bindings/lib/string_serialization.h b/mojo/public/cpp/bindings/lib/string_serialization.h index 6e0c758576..1fe6b87af7 100644 --- a/mojo/public/cpp/bindings/lib/string_serialization.h +++ b/mojo/public/cpp/bindings/lib/string_serialization.h @@ -22,36 +22,18 @@ struct Serializer<StringDataView, MaybeConstUserType> { using UserType = typename std::remove_const<MaybeConstUserType>::type; using Traits = StringTraits<UserType>; - static size_t PrepareToSerialize(MaybeConstUserType& input, - SerializationContext* context) { - if (CallIsNullIfExists<Traits>(input)) - return 0; - - void* custom_context = CustomContextHelper<Traits>::SetUp(input, context); - return Align(sizeof(String_Data) + - CallWithContext(Traits::GetSize, input, custom_context)); - } - static void Serialize(MaybeConstUserType& input, Buffer* buffer, - String_Data** output, + String_Data::BufferWriter* writer, SerializationContext* context) { - if (CallIsNullIfExists<Traits>(input)) { - *output = nullptr; + if (CallIsNullIfExists<Traits>(input)) return; - } - - void* custom_context = CustomContextHelper<Traits>::GetNext(context); - - String_Data* result = String_Data::New( - CallWithContext(Traits::GetSize, input, custom_context), buffer); - if (result) { - memcpy(result->storage(), - CallWithContext(Traits::GetData, input, custom_context), - CallWithContext(Traits::GetSize, input, custom_context)); - } - *output = result; + void* custom_context = CustomContextHelper<Traits>::SetUp(input, context); + const size_t size = CallWithContext(Traits::GetSize, input, custom_context); + writer->Allocate(size, buffer); + memcpy((*writer)->storage(), + CallWithContext(Traits::GetData, input, custom_context), size); CustomContextHelper<Traits>::TearDown(input, custom_context); } diff --git a/mojo/public/cpp/bindings/lib/string_traits_string16.cc b/mojo/public/cpp/bindings/lib/string_traits_string16.cc deleted file mode 100644 index 95ff6ccf25..0000000000 --- a/mojo/public/cpp/bindings/lib/string_traits_string16.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "mojo/public/cpp/bindings/string_traits_string16.h" - -#include <string> - -#include "base/strings/utf_string_conversions.h" - -namespace mojo { - -// static -void* StringTraits<base::string16>::SetUpContext(const base::string16& input) { - return new std::string(base::UTF16ToUTF8(input)); -} - -// static -void StringTraits<base::string16>::TearDownContext(const base::string16& input, - void* context) { - delete static_cast<std::string*>(context); -} - -// static -size_t StringTraits<base::string16>::GetSize(const base::string16& input, - void* context) { - return static_cast<std::string*>(context)->size(); -} - -// static -const char* StringTraits<base::string16>::GetData(const base::string16& input, - void* context) { - return static_cast<std::string*>(context)->data(); -} - -// static -bool StringTraits<base::string16>::Read(StringDataView input, - base::string16* output) { - return base::UTF8ToUTF16(input.storage(), input.size(), output); -} - -} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/string_traits_wtf.cc b/mojo/public/cpp/bindings/lib/string_traits_wtf.cc index 203f6f5903..71b758c49c 100644 --- a/mojo/public/cpp/bindings/lib/string_traits_wtf.cc +++ b/mojo/public/cpp/bindings/lib/string_traits_wtf.cc @@ -8,7 +8,8 @@ #include "base/logging.h" #include "mojo/public/cpp/bindings/lib/array_internal.h" -#include "third_party/WebKit/Source/wtf/text/StringUTF8Adaptor.h" +#include "mojo/public/cpp/bindings/string_data_view.h" +#include "third_party/blink/renderer/platform/wtf/text/string_utf8_adaptor.h" namespace mojo { namespace { @@ -16,7 +17,7 @@ namespace { struct UTF8AdaptorInfo { explicit UTF8AdaptorInfo(const WTF::String& input) : utf8_adaptor(input) { #if DCHECK_IS_ON() - original_size_in_bytes = input.charactersSizeInBytes(); + original_size_in_bytes = input.CharactersSizeInBytes(); #endif } @@ -34,7 +35,7 @@ UTF8AdaptorInfo* ToAdaptor(const WTF::String& input, void* context) { UTF8AdaptorInfo* adaptor = static_cast<UTF8AdaptorInfo*>(context); #if DCHECK_IS_ON() - DCHECK_EQ(adaptor->original_size_in_bytes, input.charactersSizeInBytes()); + DCHECK_EQ(adaptor->original_size_in_bytes, input.CharactersSizeInBytes()); #endif return adaptor; } @@ -43,7 +44,7 @@ UTF8AdaptorInfo* ToAdaptor(const WTF::String& input, void* context) { // static void StringTraits<WTF::String>::SetToNull(WTF::String* output) { - if (output->isNull()) + if (output->IsNull()) return; WTF::String result; @@ -70,13 +71,13 @@ size_t StringTraits<WTF::String>::GetSize(const WTF::String& input, // static const char* StringTraits<WTF::String>::GetData(const WTF::String& input, void* context) { - return ToAdaptor(input, context)->utf8_adaptor.data(); + return ToAdaptor(input, context)->utf8_adaptor.Data(); } // static bool StringTraits<WTF::String>::Read(StringDataView input, WTF::String* output) { - WTF::String result = WTF::String::fromUTF8(input.storage(), input.size()); + WTF::String result = WTF::String::FromUTF8(input.storage(), input.size()); output->swap(result); return true; } diff --git a/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc b/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc index 585a8f094c..2b359861d7 100644 --- a/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc +++ b/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc @@ -7,85 +7,79 @@ #if ENABLE_SYNC_CALL_RESTRICTIONS #include "base/debug/leak_annotations.h" -#include "base/lazy_instance.h" #include "base/logging.h" -#include "base/threading/thread_local.h" +#include "base/macros.h" +#include "base/no_destructor.h" +#include "base/synchronization/lock.h" +#include "base/threading/sequence_local_storage_slot.h" #include "mojo/public/c/system/core.h" namespace mojo { namespace { -class SyncCallSettings { +class GlobalSyncCallSettings { public: - static SyncCallSettings* current(); + GlobalSyncCallSettings() = default; + ~GlobalSyncCallSettings() = default; - bool allowed() const { - return scoped_allow_count_ > 0 || system_defined_value_; + bool sync_call_allowed_by_default() const { + base::AutoLock lock(lock_); + return sync_call_allowed_by_default_; } - void IncreaseScopedAllowCount() { scoped_allow_count_++; } - void DecreaseScopedAllowCount() { - DCHECK_LT(0u, scoped_allow_count_); - scoped_allow_count_--; + void DisallowSyncCallByDefault() { + base::AutoLock lock(lock_); + sync_call_allowed_by_default_ = false; } private: - SyncCallSettings(); - ~SyncCallSettings(); + mutable base::Lock lock_; + bool sync_call_allowed_by_default_ = true; - bool system_defined_value_ = true; - size_t scoped_allow_count_ = 0; + DISALLOW_COPY_AND_ASSIGN(GlobalSyncCallSettings); }; -base::LazyInstance<base::ThreadLocalPointer<SyncCallSettings>>::DestructorAtExit - g_sync_call_settings = LAZY_INSTANCE_INITIALIZER; - -// static -SyncCallSettings* SyncCallSettings::current() { - SyncCallSettings* result = g_sync_call_settings.Pointer()->Get(); - if (!result) { - result = new SyncCallSettings(); - ANNOTATE_LEAKING_OBJECT_PTR(result); - DCHECK_EQ(result, g_sync_call_settings.Pointer()->Get()); - } - return result; -} - -SyncCallSettings::SyncCallSettings() { - MojoResult result = MojoGetProperty(MOJO_PROPERTY_TYPE_SYNC_CALL_ALLOWED, - &system_defined_value_); - DCHECK_EQ(MOJO_RESULT_OK, result); - - DCHECK(!g_sync_call_settings.Pointer()->Get()); - g_sync_call_settings.Pointer()->Set(this); +GlobalSyncCallSettings& GetGlobalSettings() { + static base::NoDestructor<GlobalSyncCallSettings> global_settings; + return *global_settings; } -SyncCallSettings::~SyncCallSettings() { - g_sync_call_settings.Pointer()->Set(nullptr); +size_t& GetSequenceLocalScopedAllowCount() { + static base::NoDestructor<base::SequenceLocalStorageSlot<size_t>> count; + return count->Get(); } } // namespace // static void SyncCallRestrictions::AssertSyncCallAllowed() { - if (!SyncCallSettings::current()->allowed()) { - LOG(FATAL) << "Mojo sync calls are not allowed in this process because " - << "they can lead to jank and deadlock. If you must make an " - << "exception, please see " - << "SyncCallRestrictions::ScopedAllowSyncCall and consult " - << "mojo/OWNERS."; - } + if (GetGlobalSettings().sync_call_allowed_by_default()) + return; + if (GetSequenceLocalScopedAllowCount() > 0) + return; + + LOG(FATAL) << "Mojo sync calls are not allowed in this process because " + << "they can lead to jank and deadlock. If you must make an " + << "exception, please see " + << "SyncCallRestrictions::ScopedAllowSyncCall and consult " + << "mojo/OWNERS."; +} + +// static +void SyncCallRestrictions::DisallowSyncCall() { + GetGlobalSettings().DisallowSyncCallByDefault(); } // static void SyncCallRestrictions::IncreaseScopedAllowCount() { - SyncCallSettings::current()->IncreaseScopedAllowCount(); + ++GetSequenceLocalScopedAllowCount(); } // static void SyncCallRestrictions::DecreaseScopedAllowCount() { - SyncCallSettings::current()->DecreaseScopedAllowCount(); + DCHECK_GT(GetSequenceLocalScopedAllowCount(), 0u); + --GetSequenceLocalScopedAllowCount(); } } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/sync_event_watcher.cc b/mojo/public/cpp/bindings/lib/sync_event_watcher.cc index b1c97e3691..17165912fc 100644 --- a/mojo/public/cpp/bindings/lib/sync_event_watcher.cc +++ b/mojo/public/cpp/bindings/lib/sync_event_watcher.cc @@ -4,6 +4,9 @@ #include "mojo/public/cpp/bindings/sync_event_watcher.h" +#include <algorithm> + +#include "base/containers/stack_container.h" #include "base/logging.h" namespace mojo { @@ -16,19 +19,20 @@ SyncEventWatcher::SyncEventWatcher(base::WaitableEvent* event, destroyed_(new base::RefCountedData<bool>(false)) {} SyncEventWatcher::~SyncEventWatcher() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (registered_) - registry_->UnregisterEvent(event_); + registry_->UnregisterEvent(event_, callback_); destroyed_->data = true; } void SyncEventWatcher::AllowWokenUpBySyncWatchOnSameThread() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); IncrementRegisterCount(); } -bool SyncEventWatcher::SyncWatch(const bool* should_stop) { - DCHECK(thread_checker_.CalledOnValidThread()); +bool SyncEventWatcher::SyncWatch(const bool** stop_flags, + size_t num_stop_flags) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); IncrementRegisterCount(); if (!registered_) { DecrementRegisterCount(); @@ -38,8 +42,14 @@ bool SyncEventWatcher::SyncWatch(const bool* should_stop) { // This object may be destroyed during the Wait() call. So we have to preserve // the boolean that Wait uses. auto destroyed = destroyed_; - const bool* should_stop_array[] = {should_stop, &destroyed->data}; - bool result = registry_->Wait(should_stop_array, 2); + + constexpr size_t kFlagStackCapacity = 4; + base::StackVector<const bool*, kFlagStackCapacity> should_stop_array; + should_stop_array.container().push_back(&destroyed->data); + std::copy(stop_flags, stop_flags + num_stop_flags, + std::back_inserter(should_stop_array.container())); + bool result = registry_->Wait(should_stop_array.container().data(), + should_stop_array.container().size()); // This object has been destroyed. if (destroyed->data) @@ -51,15 +61,17 @@ bool SyncEventWatcher::SyncWatch(const bool* should_stop) { void SyncEventWatcher::IncrementRegisterCount() { register_request_count_++; - if (!registered_) - registered_ = registry_->RegisterEvent(event_, callback_); + if (!registered_) { + registry_->RegisterEvent(event_, callback_); + registered_ = true; + } } void SyncEventWatcher::DecrementRegisterCount() { DCHECK_GT(register_request_count_, 0u); register_request_count_--; if (register_request_count_ == 0 && registered_) { - registry_->UnregisterEvent(event_); + registry_->UnregisterEvent(event_, callback_); registered_ = false; } } diff --git a/mojo/public/cpp/bindings/lib/sync_handle_registry.cc b/mojo/public/cpp/bindings/lib/sync_handle_registry.cc index fd3df396ec..2ac4833445 100644 --- a/mojo/public/cpp/bindings/lib/sync_handle_registry.cc +++ b/mojo/public/cpp/bindings/lib/sync_handle_registry.cc @@ -4,27 +4,36 @@ #include "mojo/public/cpp/bindings/sync_handle_registry.h" +#include <algorithm> + #include "base/lazy_instance.h" #include "base/logging.h" #include "base/stl_util.h" -#include "base/threading/thread_local.h" +#include "base/threading/sequence_local_storage_slot.h" +#include "base/threading/sequenced_task_runner_handle.h" #include "mojo/public/c/system/core.h" namespace mojo { namespace { -base::LazyInstance<base::ThreadLocalPointer<SyncHandleRegistry>>::Leaky +base::LazyInstance< + base::SequenceLocalStorageSlot<scoped_refptr<SyncHandleRegistry>>>::Leaky g_current_sync_handle_watcher = LAZY_INSTANCE_INITIALIZER; } // namespace // static scoped_refptr<SyncHandleRegistry> SyncHandleRegistry::current() { - scoped_refptr<SyncHandleRegistry> result( - g_current_sync_handle_watcher.Pointer()->Get()); + // SyncMessageFilter can be used on threads without sequence-local storage + // being available. Those receive a unique, standalone SyncHandleRegistry. + if (!base::SequencedTaskRunnerHandle::IsSet()) + return new SyncHandleRegistry(); + + scoped_refptr<SyncHandleRegistry> result = + g_current_sync_handle_watcher.Get().Get(); if (!result) { result = new SyncHandleRegistry(); - DCHECK_EQ(result.get(), g_current_sync_handle_watcher.Pointer()->Get()); + g_current_sync_handle_watcher.Get().Set(result); } return result; } @@ -32,7 +41,7 @@ scoped_refptr<SyncHandleRegistry> SyncHandleRegistry::current() { bool SyncHandleRegistry::RegisterHandle(const Handle& handle, MojoHandleSignals handle_signals, const HandleCallback& callback) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (base::ContainsKey(handles_, handle)) return false; @@ -46,7 +55,7 @@ bool SyncHandleRegistry::RegisterHandle(const Handle& handle, } void SyncHandleRegistry::UnregisterHandle(const Handle& handle) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!base::ContainsKey(handles_, handle)) return; @@ -55,27 +64,63 @@ void SyncHandleRegistry::UnregisterHandle(const Handle& handle) { handles_.erase(handle); } -bool SyncHandleRegistry::RegisterEvent(base::WaitableEvent* event, +void SyncHandleRegistry::RegisterEvent(base::WaitableEvent* event, const base::Closure& callback) { - auto result = events_.insert({event, callback}); - DCHECK(result.second); - MojoResult rv = wait_set_.AddEvent(event); - if (rv == MOJO_RESULT_OK) - return true; - DCHECK_EQ(MOJO_RESULT_ALREADY_EXISTS, rv); - return false; + auto it = events_.find(event); + if (it == events_.end()) { + auto result = events_.emplace(event, EventCallbackList{}); + it = result.first; + } + + // The event may already be in the WaitSet, but we don't care. This will be a + // no-op in that case, which is more efficient than scanning the list of + // callbacks to see if any are valid. + wait_set_.AddEvent(event); + + it->second.container().push_back(callback); } -void SyncHandleRegistry::UnregisterEvent(base::WaitableEvent* event) { +void SyncHandleRegistry::UnregisterEvent(base::WaitableEvent* event, + const base::Closure& callback) { auto it = events_.find(event); - DCHECK(it != events_.end()); - events_.erase(it); - MojoResult rv = wait_set_.RemoveEvent(event); - DCHECK_EQ(MOJO_RESULT_OK, rv); + if (it == events_.end()) + return; + + bool has_valid_callbacks = false; + auto& callbacks = it->second.container(); + if (is_dispatching_event_callbacks_) { + // Not safe to remove any elements from |callbacks| here since an outer + // stack frame is currently iterating over it in Wait(). + for (auto& cb : callbacks) { + if (cb.Equals(callback)) + cb.Reset(); + else if (cb) + has_valid_callbacks = true; + } + remove_invalid_event_callbacks_after_dispatch_ = true; + } else { + callbacks.erase(std::remove_if(callbacks.begin(), callbacks.end(), + [&callback](const base::Closure& cb) { + return cb.Equals(callback); + }), + callbacks.end()); + if (callbacks.empty()) + events_.erase(it); + else + has_valid_callbacks = true; + } + + if (!has_valid_callbacks) { + // Regardless of whether or not we're nested within a Wait(), we need to + // ensure that |event| is removed from the WaitSet before returning if this + // was the last callback registered for it. + MojoResult rv = wait_set_.RemoveEvent(event); + DCHECK_EQ(MOJO_RESULT_OK, rv); + } } bool SyncHandleRegistry::Wait(const bool* should_stop[], size_t count) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); size_t num_ready_handles; Handle ready_handle; @@ -83,9 +128,10 @@ bool SyncHandleRegistry::Wait(const bool* should_stop[], size_t count) { scoped_refptr<SyncHandleRegistry> preserver(this); while (true) { - for (size_t i = 0; i < count; ++i) + for (size_t i = 0; i < count; ++i) { if (*should_stop[i]) return true; + } // TODO(yzshen): Theoretically it can reduce sync call re-entrancy if we // give priority to the handle that is waiting for sync response. @@ -102,34 +148,51 @@ bool SyncHandleRegistry::Wait(const bool* should_stop[], size_t count) { if (ready_event) { const auto iter = events_.find(ready_event); DCHECK(iter != events_.end()); - iter->second.Run(); + bool was_dispatching_event_callbacks = is_dispatching_event_callbacks_; + is_dispatching_event_callbacks_ = true; + + // NOTE: It's possible for the container to be extended by any of these + // callbacks if they call RegisterEvent, so we are careful to iterate by + // index. Also note that conversely, elements cannot be *removed* from the + // container, by any of these callbacks, so it is safe to assume the size + // only stays the same or increases, with no elements changing position. + auto& callbacks = iter->second.container(); + for (size_t i = 0; i < callbacks.size(); ++i) { + auto& callback = callbacks[i]; + if (callback) + callback.Run(); + } + + is_dispatching_event_callbacks_ = was_dispatching_event_callbacks; + if (!was_dispatching_event_callbacks && + remove_invalid_event_callbacks_after_dispatch_) { + // If we've had events unregistered within any callback dispatch, now is + // a good time to prune them from the map. + RemoveInvalidEventCallbacks(); + remove_invalid_event_callbacks_after_dispatch_ = false; + } } }; return false; } -SyncHandleRegistry::SyncHandleRegistry() { - DCHECK(!g_current_sync_handle_watcher.Pointer()->Get()); - g_current_sync_handle_watcher.Pointer()->Set(this); -} - -SyncHandleRegistry::~SyncHandleRegistry() { - DCHECK(thread_checker_.CalledOnValidThread()); - - // This object may be destructed after the thread local storage slot used by - // |g_current_sync_handle_watcher| is reset during thread shutdown. - // For example, another slot in the thread local storage holds a referrence to - // this object, and that slot is cleaned up after - // |g_current_sync_handle_watcher|. - if (!g_current_sync_handle_watcher.Pointer()->Get()) - return; - - // If this breaks, it is likely that the global variable is bulit into and - // accessed from multiple modules. - DCHECK_EQ(this, g_current_sync_handle_watcher.Pointer()->Get()); - - g_current_sync_handle_watcher.Pointer()->Set(nullptr); +SyncHandleRegistry::SyncHandleRegistry() = default; + +SyncHandleRegistry::~SyncHandleRegistry() = default; + +void SyncHandleRegistry::RemoveInvalidEventCallbacks() { + for (auto it = events_.begin(); it != events_.end();) { + auto& callbacks = it->second.container(); + callbacks.erase( + std::remove_if(callbacks.begin(), callbacks.end(), + [](const base::Closure& callback) { return !callback; }), + callbacks.end()); + if (callbacks.empty()) + events_.erase(it++); + else + ++it; + } } } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc b/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc index f20af56b20..294b8a1a4b 100644 --- a/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc +++ b/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc @@ -21,7 +21,7 @@ SyncHandleWatcher::SyncHandleWatcher( destroyed_(new base::RefCountedData<bool>(false)) {} SyncHandleWatcher::~SyncHandleWatcher() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (registered_) registry_->UnregisterHandle(handle_); @@ -29,12 +29,12 @@ SyncHandleWatcher::~SyncHandleWatcher() { } void SyncHandleWatcher::AllowWokenUpBySyncWatchOnSameThread() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); IncrementRegisterCount(); } bool SyncHandleWatcher::SyncWatch(const bool* should_stop) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); IncrementRegisterCount(); if (!registered_) { DecrementRegisterCount(); diff --git a/mojo/public/cpp/bindings/lib/task_runner_helper.cc b/mojo/public/cpp/bindings/lib/task_runner_helper.cc new file mode 100644 index 0000000000..6104a9740e --- /dev/null +++ b/mojo/public/cpp/bindings/lib/task_runner_helper.cc @@ -0,0 +1,24 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" + +#include "base/sequenced_task_runner.h" +#include "base/threading/sequenced_task_runner_handle.h" + +namespace mojo { +namespace internal { + +scoped_refptr<base::SequencedTaskRunner> +GetTaskRunnerToUseFromUserProvidedTaskRunner( + scoped_refptr<base::SequencedTaskRunner> runner) { + if (runner) { + DCHECK(runner->RunsTasksInCurrentSequence()); + return runner; + } + return base::SequencedTaskRunnerHandle::Get(); +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/task_runner_helper.h b/mojo/public/cpp/bindings/lib/task_runner_helper.h new file mode 100644 index 0000000000..d34d179675 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/task_runner_helper.h @@ -0,0 +1,28 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_TASK_RUNNER_HELPER_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_TASK_RUNNER_HELPER_H_ + +#include "base/memory/ref_counted.h" + +namespace base { +class SequencedTaskRunner; +} // namespace base + +namespace mojo { +namespace internal { + +// Returns the SequencedTaskRunner to use from the optional user-provided +// SequencedTaskRunner. If |runner| is provided non-null, it is returned. +// Otherwise, SequencedTaskRunnerHandle::Get() is returned. If |runner| is non- +// null, it must run tasks on the current sequence. +scoped_refptr<base::SequencedTaskRunner> +GetTaskRunnerToUseFromUserProvidedTaskRunner( + scoped_refptr<base::SequencedTaskRunner> runner); + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_TASK_RUNNER_HELPER_H_ diff --git a/mojo/public/cpp/bindings/lib/template_util.h b/mojo/public/cpp/bindings/lib/template_util.h index 5151123ac0..383eb91593 100644 --- a/mojo/public/cpp/bindings/lib/template_util.h +++ b/mojo/public/cpp/bindings/lib/template_util.h @@ -114,6 +114,11 @@ struct Conditional<false, T, F> { typedef F type; }; +template <typename T> +struct AlwaysFalse { + static const bool value = false; +}; + } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/union_accessor.h b/mojo/public/cpp/bindings/lib/union_accessor.h deleted file mode 100644 index 821aede595..0000000000 --- a/mojo/public/cpp/bindings/lib/union_accessor.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2015 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_UNION_ACCESSOR_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_LIB_UNION_ACCESSOR_H_ - -namespace mojo { -namespace internal { - -// When serializing and deserializing Unions, it is necessary to access -// the private fields and methods of the Union. This allows us to do that -// without leaking those same fields and methods in the Union interface. -// All Union wrappers are friends of this class allowing such access. -template <typename U> -class UnionAccessor { - public: - explicit UnionAccessor(U* u) : u_(u) {} - - typename U::Union_* data() { return &(u_->data_); } - - typename U::Tag* tag() { return &(u_->tag_); } - - void SwitchActive(typename U::Tag new_tag) { u_->SwitchActive(new_tag); } - - private: - U* u_; -}; - -} // namespace internal -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_UNION_ACCESSOR_H_ diff --git a/mojo/public/cpp/bindings/lib/unserialized_message_context.cc b/mojo/public/cpp/bindings/lib/unserialized_message_context.cc new file mode 100644 index 0000000000..b029f4ef00 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/unserialized_message_context.cc @@ -0,0 +1,24 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/unserialized_message_context.h" + +namespace mojo { +namespace internal { + +UnserializedMessageContext::UnserializedMessageContext(const Tag* tag, + uint32_t message_name, + uint32_t message_flags) + : tag_(tag) { + header_.interface_id = 0; + header_.version = 1; + header_.name = message_name; + header_.flags = message_flags; + header_.num_bytes = sizeof(header_); +} + +UnserializedMessageContext::~UnserializedMessageContext() = default; + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/unserialized_message_context.h b/mojo/public/cpp/bindings/lib/unserialized_message_context.h new file mode 100644 index 0000000000..4886a981dc --- /dev/null +++ b/mojo/public/cpp/bindings/lib/unserialized_message_context.h @@ -0,0 +1,63 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_UNSERIALIZED_MESSAGE_CONTEXT_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_UNSERIALIZED_MESSAGE_CONTEXT_H_ + +#include <stdint.h> + +#include "base/component_export.h" +#include "base/macros.h" +#include "base/optional.h" +#include "mojo/public/c/system/types.h" +#include "mojo/public/cpp/bindings/lib/buffer.h" +#include "mojo/public/cpp/bindings/lib/message_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_context.h" + +namespace mojo { +namespace internal { + +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) UnserializedMessageContext { + public: + struct Tag {}; + + UnserializedMessageContext(const Tag* tag, + uint32_t message_name, + uint32_t message_flags); + virtual ~UnserializedMessageContext(); + + template <typename MessageType> + MessageType* SafeCast() { + if (&MessageType::kMessageTag != tag_) + return nullptr; + return static_cast<MessageType*>(this); + } + + const Tag* tag() const { return tag_; } + uint32_t message_name() const { return header_.name; } + uint32_t message_flags() const { return header_.flags; } + + MessageHeaderV1* header() { return &header_; } + + virtual void Serialize(SerializationContext* serialization_context, + Buffer* buffer) = 0; + + private: + // The |tag_| is used for run-time type identification of specific + // unserialized message types, e.g. messages generated by mojom bindings. This + // allows opaque message objects to be safely downcast once pulled off a pipe. + const Tag* const tag_; + + // We store message metadata in a serialized header structure to simplify + // Message implementation which needs to query such metadata for both + // serialized and unserialized message objects. + MessageHeaderV1 header_; + + DISALLOW_COPY_AND_ASSIGN(UnserializedMessageContext); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_UNSERIALIZED_MESSAGE_CONTEXT_H_ diff --git a/mojo/public/cpp/bindings/lib/validation_context.h b/mojo/public/cpp/bindings/lib/validation_context.h index ed6c6542e7..7c4de47327 100644 --- a/mojo/public/cpp/bindings/lib/validation_context.h +++ b/mojo/public/cpp/bindings/lib/validation_context.h @@ -9,9 +9,9 @@ #include <stdint.h> #include "base/compiler_specific.h" +#include "base/component_export.h" #include "base/macros.h" #include "base/strings/string_piece.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" static const int kMaxRecursionDepth = 100; @@ -24,7 +24,7 @@ namespace internal { // ValidationContext is used when validating object sizes, pointers and handle // indices in the payload of incoming messages. -class MOJO_CPP_BINDINGS_EXPORT ValidationContext { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) ValidationContext { public: // [data, data + data_num_bytes) specifies the initial valid memory range. // [0, num_handles) specifies the initial valid range of handle indices. diff --git a/mojo/public/cpp/bindings/lib/validation_errors.h b/mojo/public/cpp/bindings/lib/validation_errors.h index 122418d9e3..e48e37c6b6 100644 --- a/mojo/public/cpp/bindings/lib/validation_errors.h +++ b/mojo/public/cpp/bindings/lib/validation_errors.h @@ -6,9 +6,9 @@ #define MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_ERRORS_H_ #include "base/callback.h" +#include "base/component_export.h" #include "base/logging.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/validation_context.h" namespace mojo { @@ -76,23 +76,24 @@ enum ValidationError { VALIDATION_ERROR_MAX_RECURSION_DEPTH, }; -MOJO_CPP_BINDINGS_EXPORT const char* ValidationErrorToString( - ValidationError error); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +const char* ValidationErrorToString(ValidationError error); -MOJO_CPP_BINDINGS_EXPORT void ReportValidationError( - ValidationContext* context, - ValidationError error, - const char* description = nullptr); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +void ReportValidationError(ValidationContext* context, + ValidationError error, + const char* description = nullptr); -MOJO_CPP_BINDINGS_EXPORT void ReportValidationErrorForMessage( - mojo::Message* message, - ValidationError error, - const char* description = nullptr); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +void ReportValidationErrorForMessage(mojo::Message* message, + ValidationError error, + const char* description = nullptr); // This class may be used by tests to suppress validation error logging. This is // not thread-safe and must only be instantiated on the main thread with no // other threads using Mojo bindings at the time of construction or destruction. -class MOJO_CPP_BINDINGS_EXPORT ScopedSuppressValidationErrorLoggingForTests { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) + ScopedSuppressValidationErrorLoggingForTests { public: ScopedSuppressValidationErrorLoggingForTests(); ~ScopedSuppressValidationErrorLoggingForTests(); @@ -105,7 +106,8 @@ class MOJO_CPP_BINDINGS_EXPORT ScopedSuppressValidationErrorLoggingForTests { // Only used by validation tests and when there is only one thread doing message // validation. -class MOJO_CPP_BINDINGS_EXPORT ValidationErrorObserverForTesting { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) + ValidationErrorObserverForTesting { public: explicit ValidationErrorObserverForTesting(const base::Closure& callback); ~ValidationErrorObserverForTesting(); @@ -127,11 +129,13 @@ class MOJO_CPP_BINDINGS_EXPORT ValidationErrorObserverForTesting { // // The function returns true if the error is recorded (by a // SerializationWarningObserverForTesting object), false otherwise. -MOJO_CPP_BINDINGS_EXPORT bool ReportSerializationWarning(ValidationError error); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ReportSerializationWarning(ValidationError error); // Only used by serialization tests and when there is only one thread doing // message serialization. -class MOJO_CPP_BINDINGS_EXPORT SerializationWarningObserverForTesting { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) + SerializationWarningObserverForTesting { public: SerializationWarningObserverForTesting(); ~SerializationWarningObserverForTesting(); diff --git a/mojo/public/cpp/bindings/lib/validation_util.cc b/mojo/public/cpp/bindings/lib/validation_util.cc index 7614df5cbc..4b414c4e3b 100644 --- a/mojo/public/cpp/bindings/lib/validation_util.cc +++ b/mojo/public/cpp/bindings/lib/validation_util.cc @@ -8,14 +8,25 @@ #include <limits> +#include "base/strings/stringprintf.h" #include "mojo/public/cpp/bindings/lib/message_internal.h" #include "mojo/public/cpp/bindings/lib/serialization_util.h" #include "mojo/public/cpp/bindings/lib/validation_errors.h" -#include "mojo/public/interfaces/bindings/interface_control_messages.mojom.h" namespace mojo { namespace internal { +void ReportNonNullableValidationError(ValidationContext* validation_context, + ValidationError error, + int field_index) { + const char* null_or_invalid = + error == VALIDATION_ERROR_UNEXPECTED_NULL_POINTER ? "null" : "invalid"; + + std::string error_message = + base::StringPrintf("%s field %d", null_or_invalid, field_index); + ReportValidationError(validation_context, error, error_message.c_str()); +} + bool ValidateStructHeaderAndClaimMemory(const void* data, ValidationContext* validation_context) { if (!IsAligned(data)) { @@ -118,53 +129,53 @@ bool IsHandleOrInterfaceValid(const Handle_Data& input) { bool ValidateHandleOrInterfaceNonNullable( const AssociatedInterface_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (IsHandleOrInterfaceValid(input)) return true; - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, - error_message); + ReportNonNullableValidationError( + validation_context, VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, + field_index); return false; } bool ValidateHandleOrInterfaceNonNullable( const AssociatedEndpointHandle_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (IsHandleOrInterfaceValid(input)) return true; - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, - error_message); + ReportNonNullableValidationError( + validation_context, VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, + field_index); return false; } bool ValidateHandleOrInterfaceNonNullable( const Interface_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (IsHandleOrInterfaceValid(input)) return true; - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, - error_message); + ReportNonNullableValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, + field_index); return false; } bool ValidateHandleOrInterfaceNonNullable( const Handle_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (IsHandleOrInterfaceValid(input)) return true; - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, - error_message); + ReportNonNullableValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, + field_index); return false; } diff --git a/mojo/public/cpp/bindings/lib/validation_util.h b/mojo/public/cpp/bindings/lib/validation_util.h index ea5a991668..3b88956f7a 100644 --- a/mojo/public/cpp/bindings/lib/validation_util.h +++ b/mojo/public/cpp/bindings/lib/validation_util.h @@ -7,7 +7,7 @@ #include <stdint.h> -#include "mojo/public/cpp/bindings/bindings_export.h" +#include "base/component_export.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" #include "mojo/public/cpp/bindings/lib/serialization_util.h" #include "mojo/public/cpp/bindings/lib/validate_params.h" @@ -18,6 +18,12 @@ namespace mojo { namespace internal { +// Calls ReportValidationError() with a constructed error string. +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +void ReportNonNullableValidationError(ValidationContext* validation_context, + ValidationError error, + int field_index); + // Checks whether decoding the pointer will overflow and produce a pointer // smaller than |offset|. inline bool ValidateEncodedPointer(const uint64_t* offset) { @@ -47,32 +53,35 @@ bool ValidatePointer(const Pointer<T>& input, // |validation_context|. On success, the memory range is marked as occupied. // Note: Does not verify |version| or that |num_bytes| is correct for the // claimed version. -MOJO_CPP_BINDINGS_EXPORT bool ValidateStructHeaderAndClaimMemory( - const void* data, - ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateStructHeaderAndClaimMemory(const void* data, + ValidationContext* validation_context); // Validates that |data| contains a valid union header, in terms of alignment // and size. It checks that the memory range [data, data + kUnionDataSize) is // not marked as occupied by other objects in |validation_context|. On success, // the memory range is marked as occupied. -MOJO_CPP_BINDINGS_EXPORT bool ValidateNonInlinedUnionHeaderAndClaimMemory( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateNonInlinedUnionHeaderAndClaimMemory( const void* data, ValidationContext* validation_context); // Validates that the message is a request which doesn't expect a response. -MOJO_CPP_BINDINGS_EXPORT bool ValidateMessageIsRequestWithoutResponse( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateMessageIsRequestWithoutResponse( const Message* message, ValidationContext* validation_context); // Validates that the message is a request expecting a response. -MOJO_CPP_BINDINGS_EXPORT bool ValidateMessageIsRequestExpectingResponse( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateMessageIsRequestExpectingResponse( const Message* message, ValidationContext* validation_context); // Validates that the message is a response. -MOJO_CPP_BINDINGS_EXPORT bool ValidateMessageIsResponse( - const Message* message, - ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateMessageIsResponse(const Message* message, + ValidationContext* validation_context); // Validates that the message payload is a valid struct of type ParamsType. template <typename ParamsType> @@ -85,54 +94,56 @@ bool ValidateMessagePayload(const Message* message, // |input| is not null/invalid. template <typename T> bool ValidatePointerNonNullable(const T& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (input.offset) return true; - - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, - error_message); + ReportNonNullableValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + field_index); return false; } template <typename T> bool ValidateInlinedUnionNonNullable(const T& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (!input.is_null()) return true; - - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, - error_message); + ReportNonNullableValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + field_index); return false; } -MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( - const AssociatedInterface_Data& input); -MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( - const AssociatedEndpointHandle_Data& input); -MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( - const Interface_Data& input); -MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( - const Handle_Data& input); - -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool IsHandleOrInterfaceValid(const AssociatedInterface_Data& input); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool IsHandleOrInterfaceValid(const AssociatedEndpointHandle_Data& input); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool IsHandleOrInterfaceValid(const Interface_Data& input); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool IsHandleOrInterfaceValid(const Handle_Data& input); + +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterfaceNonNullable( const AssociatedInterface_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterfaceNonNullable( const AssociatedEndpointHandle_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterfaceNonNullable( const Interface_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterfaceNonNullable( const Handle_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context); template <typename T> @@ -187,18 +198,18 @@ bool ValidateNonInlinedUnion(const Pointer<T>& input, T::Validate(input.Get(), validation_context, false); } -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( - const AssociatedInterface_Data& input, - ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( - const AssociatedEndpointHandle_Data& input, - ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( - const Interface_Data& input, - ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( - const Handle_Data& input, - ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterface(const AssociatedInterface_Data& input, + ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterface(const AssociatedEndpointHandle_Data& input, + ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterface(const Interface_Data& input, + ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterface(const Handle_Data& input, + ValidationContext* validation_context); } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h b/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h index cb24bc46ee..bb0ee531f5 100644 --- a/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h +++ b/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h @@ -8,11 +8,10 @@ #include <type_traits> #include "mojo/public/cpp/bindings/clone_traits.h" -#include "mojo/public/cpp/bindings/lib/equals_traits.h" -#include "third_party/WebKit/Source/wtf/HashMap.h" -#include "third_party/WebKit/Source/wtf/Optional.h" -#include "third_party/WebKit/Source/wtf/Vector.h" -#include "third_party/WebKit/Source/wtf/text/WTFString.h" +#include "mojo/public/cpp/bindings/equals_traits.h" +#include "third_party/blink/renderer/platform/wtf/hash_map.h" +#include "third_party/blink/renderer/platform/wtf/text/wtf_string.h" +#include "third_party/blink/renderer/platform/wtf/vector.h" namespace mojo { @@ -20,7 +19,7 @@ template <typename T> struct CloneTraits<WTF::Vector<T>, false> { static WTF::Vector<T> Clone(const WTF::Vector<T>& input) { WTF::Vector<T> result; - result.reserveCapacity(input.size()); + result.ReserveCapacity(input.size()); for (const auto& element : input) result.push_back(mojo::Clone(element)); @@ -32,22 +31,20 @@ template <typename K, typename V> struct CloneTraits<WTF::HashMap<K, V>, false> { static WTF::HashMap<K, V> Clone(const WTF::HashMap<K, V>& input) { WTF::HashMap<K, V> result; - auto input_end = input.end(); - for (auto it = input.begin(); it != input_end; ++it) - result.add(mojo::Clone(it->key), mojo::Clone(it->value)); + for (const auto& element : input) + result.insert(mojo::Clone(element.key), mojo::Clone(element.value)); + return result; } }; -namespace internal { - template <typename T> struct EqualsTraits<WTF::Vector<T>, false> { static bool Equals(const WTF::Vector<T>& a, const WTF::Vector<T>& b) { if (a.size() != b.size()) return false; for (size_t i = 0; i < a.size(); ++i) { - if (!internal::Equals(a[i], b[i])) + if (!mojo::Equals(a[i], b[i])) return false; } return true; @@ -65,14 +62,13 @@ struct EqualsTraits<WTF::HashMap<K, V>, false> { for (auto iter = a.begin(); iter != a_end; ++iter) { auto b_iter = b.find(iter->key); - if (b_iter == b_end || !internal::Equals(iter->value, b_iter->value)) + if (b_iter == b_end || !mojo::Equals(iter->value, b_iter->value)) return false; } return true; } }; -} // namespace internal } // namespace mojo #endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_CLONE_EQUALS_UTIL_H_ diff --git a/mojo/public/cpp/bindings/lib/wtf_hash_util.h b/mojo/public/cpp/bindings/lib/wtf_hash_util.h index cc590da67a..fa02262e8e 100644 --- a/mojo/public/cpp/bindings/lib/wtf_hash_util.h +++ b/mojo/public/cpp/bindings/lib/wtf_hash_util.h @@ -9,9 +9,9 @@ #include "mojo/public/cpp/bindings/lib/hash_util.h" #include "mojo/public/cpp/bindings/struct_ptr.h" -#include "third_party/WebKit/Source/wtf/HashFunctions.h" -#include "third_party/WebKit/Source/wtf/text/StringHash.h" -#include "third_party/WebKit/Source/wtf/text/WTFString.h" +#include "third_party/blink/renderer/platform/wtf/hash_functions.h" +#include "third_party/blink/renderer/platform/wtf/text/string_hash.h" +#include "third_party/blink/renderer/platform/wtf/text/wtf_string.h" namespace mojo { namespace internal { @@ -48,7 +48,7 @@ struct WTFHashTraits<T, false> { template <> struct WTFHashTraits<WTF::String, false> { static size_t Hash(size_t seed, const WTF::String& value) { - return HashCombine(seed, WTF::StringHash::hash(value)); + return HashCombine(seed, WTF::StringHash::GetHash(value)); } }; @@ -59,25 +59,25 @@ size_t WTFHash(size_t seed, const T& value) { template <typename T> struct StructPtrHashFn { - static unsigned hash(const StructPtr<T>& value) { + static unsigned GetHash(const StructPtr<T>& value) { return value.Hash(kHashSeed); } - static bool equal(const StructPtr<T>& left, const StructPtr<T>& right) { + static bool Equal(const StructPtr<T>& left, const StructPtr<T>& right) { return left.Equals(right); } - static const bool safeToCompareToEmptyOrDeleted = false; + static const bool safe_to_compare_to_empty_or_deleted = false; }; template <typename T> struct InlinedStructPtrHashFn { - static unsigned hash(const InlinedStructPtr<T>& value) { + static unsigned GetHash(const InlinedStructPtr<T>& value) { return value.Hash(kHashSeed); } - static bool equal(const InlinedStructPtr<T>& left, + static bool Equal(const InlinedStructPtr<T>& left, const InlinedStructPtr<T>& right) { return left.Equals(right); } - static const bool safeToCompareToEmptyOrDeleted = false; + static const bool safe_to_compare_to_empty_or_deleted = false; }; } // namespace internal @@ -93,14 +93,14 @@ struct DefaultHash<mojo::StructPtr<T>> { template <typename T> struct HashTraits<mojo::StructPtr<T>> : public GenericHashTraits<mojo::StructPtr<T>> { - static const bool hasIsEmptyValueFunction = true; - static bool isEmptyValue(const mojo::StructPtr<T>& value) { + static const bool kHasIsEmptyValueFunction = true; + static bool IsEmptyValue(const mojo::StructPtr<T>& value) { return value.is_null(); } - static void constructDeletedValue(mojo::StructPtr<T>& slot, bool) { + static void ConstructDeletedValue(mojo::StructPtr<T>& slot, bool) { mojo::internal::StructPtrWTFHelper<T>::ConstructDeletedValue(slot); } - static bool isDeletedValue(const mojo::StructPtr<T>& value) { + static bool IsDeletedValue(const mojo::StructPtr<T>& value) { return mojo::internal::StructPtrWTFHelper<T>::IsHashTableDeletedValue( value); } @@ -114,14 +114,14 @@ struct DefaultHash<mojo::InlinedStructPtr<T>> { template <typename T> struct HashTraits<mojo::InlinedStructPtr<T>> : public GenericHashTraits<mojo::InlinedStructPtr<T>> { - static const bool hasIsEmptyValueFunction = true; - static bool isEmptyValue(const mojo::InlinedStructPtr<T>& value) { + static const bool kHasIsEmptyValueFunction = true; + static bool IsEmptyValue(const mojo::InlinedStructPtr<T>& value) { return value.is_null(); } - static void constructDeletedValue(mojo::InlinedStructPtr<T>& slot, bool) { + static void ConstructDeletedValue(mojo::InlinedStructPtr<T>& slot, bool) { mojo::internal::InlinedStructPtrWTFHelper<T>::ConstructDeletedValue(slot); } - static bool isDeletedValue(const mojo::InlinedStructPtr<T>& value) { + static bool IsDeletedValue(const mojo::InlinedStructPtr<T>& value) { return mojo::internal::InlinedStructPtrWTFHelper< T>::IsHashTableDeletedValue(value); } |