diff options
Diffstat (limited to 'mojo/public/cpp/bindings/lib')
71 files changed, 9758 insertions, 0 deletions
diff --git a/mojo/public/cpp/bindings/lib/array_internal.cc b/mojo/public/cpp/bindings/lib/array_internal.cc new file mode 100644 index 0000000000..dd24eac470 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/array_internal.cc @@ -0,0 +1,59 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/array_internal.h" + +#include <stddef.h> +#include <stdint.h> + +#include <sstream> + +namespace mojo { +namespace internal { + +std::string MakeMessageWithArrayIndex(const char* message, + size_t size, + size_t index) { + std::ostringstream stream; + stream << message << ": array size - " << size << "; index - " << index; + return stream.str(); +} + +std::string MakeMessageWithExpectedArraySize(const char* message, + size_t size, + size_t expected_size) { + std::ostringstream stream; + stream << message << ": array size - " << size << "; expected size - " + << expected_size; + return stream.str(); +} + +ArrayDataTraits<bool>::BitRef::~BitRef() { +} + +ArrayDataTraits<bool>::BitRef::BitRef(uint8_t* storage, uint8_t mask) + : storage_(storage), mask_(mask) { +} + +ArrayDataTraits<bool>::BitRef& ArrayDataTraits<bool>::BitRef::operator=( + bool value) { + if (value) { + *storage_ |= mask_; + } else { + *storage_ &= ~mask_; + } + return *this; +} + +ArrayDataTraits<bool>::BitRef& ArrayDataTraits<bool>::BitRef::operator=( + const BitRef& value) { + return (*this) = static_cast<bool>(value); +} + +ArrayDataTraits<bool>::BitRef::operator bool() const { + return (*storage_ & mask_) != 0; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/array_internal.h b/mojo/public/cpp/bindings/lib/array_internal.h new file mode 100644 index 0000000000..eecfcfbc28 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/array_internal.h @@ -0,0 +1,368 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_ARRAY_INTERNAL_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_ARRAY_INTERNAL_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <limits> +#include <new> + +#include "base/logging.h" +#include "mojo/public/c/system/macros.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/bindings/lib/buffer.h" +#include "mojo/public/cpp/bindings/lib/serialization_util.h" +#include "mojo/public/cpp/bindings/lib/template_util.h" +#include "mojo/public/cpp/bindings/lib/validate_params.h" +#include "mojo/public/cpp/bindings/lib/validation_context.h" +#include "mojo/public/cpp/bindings/lib/validation_errors.h" +#include "mojo/public/cpp/bindings/lib/validation_util.h" + +namespace mojo { +namespace internal { + +template <typename K, typename V> +class Map_Data; + +MOJO_CPP_BINDINGS_EXPORT std::string +MakeMessageWithArrayIndex(const char* message, size_t size, size_t index); + +MOJO_CPP_BINDINGS_EXPORT std::string MakeMessageWithExpectedArraySize( + const char* message, + size_t size, + size_t expected_size); + +template <typename T> +struct ArrayDataTraits { + using StorageType = T; + using Ref = T&; + using ConstRef = const T&; + + static const uint32_t kMaxNumElements = + (std::numeric_limits<uint32_t>::max() - sizeof(ArrayHeader)) / + sizeof(StorageType); + + static uint32_t GetStorageSize(uint32_t num_elements) { + DCHECK(num_elements <= kMaxNumElements); + return sizeof(ArrayHeader) + sizeof(StorageType) * num_elements; + } + static Ref ToRef(StorageType* storage, size_t offset) { + return storage[offset]; + } + static ConstRef ToConstRef(const StorageType* storage, size_t offset) { + return storage[offset]; + } +}; + +// Specialization of Arrays for bools, optimized for space. It has the +// following differences from a generalized Array: +// * Each element takes up a single bit of memory. +// * Accessing a non-const single element uses a helper class |BitRef|, which +// emulates a reference to a bool. +template <> +struct ArrayDataTraits<bool> { + // Helper class to emulate a reference to a bool, used for direct element + // access. + class MOJO_CPP_BINDINGS_EXPORT BitRef { + public: + ~BitRef(); + BitRef& operator=(bool value); + BitRef& operator=(const BitRef& value); + operator bool() const; + + private: + friend struct ArrayDataTraits<bool>; + BitRef(uint8_t* storage, uint8_t mask); + BitRef(); + uint8_t* storage_; + uint8_t mask_; + }; + + // Because each element consumes only 1/8 byte. + static const uint32_t kMaxNumElements = std::numeric_limits<uint32_t>::max(); + + using StorageType = uint8_t; + using Ref = BitRef; + using ConstRef = bool; + + static uint32_t GetStorageSize(uint32_t num_elements) { + return sizeof(ArrayHeader) + ((num_elements + 7) / 8); + } + static BitRef ToRef(StorageType* storage, size_t offset) { + return BitRef(&storage[offset / 8], 1 << (offset % 8)); + } + static bool ToConstRef(const StorageType* storage, size_t offset) { + return (storage[offset / 8] & (1 << (offset % 8))) != 0; + } +}; + +// What follows is code to support the serialization/validation of +// Array_Data<T>. There are four interesting cases: arrays of primitives, +// arrays of handles/interfaces, arrays of objects and arrays of unions. +// Arrays of objects are represented as arrays of pointers to objects. Arrays +// of unions are inlined so they are not pointers, but comparing with primitives +// they require more work for serialization/validation. +// +// TODO(yzshen): Validation code should be organzied in a way similar to +// Serializer<>, or merged into it. It should be templatized with the mojo +// data view type instead of the data type, that way we can use MojomTypeTraits +// to determine the categories. + +template <typename T, bool is_union, bool is_handle_or_interface> +struct ArraySerializationHelper; + +template <typename T> +struct ArraySerializationHelper<T, false, false> { + using ElementType = typename ArrayDataTraits<T>::StorageType; + + static bool ValidateElements(const ArrayHeader* header, + const ElementType* elements, + ValidationContext* validation_context, + const ContainerValidateParams* validate_params) { + DCHECK(!validate_params->element_is_nullable) + << "Primitive type should be non-nullable"; + DCHECK(!validate_params->element_validate_params) + << "Primitive type should not have array validate params"; + + if (!validate_params->validate_enum_func) + return true; + + // Enum validation. + for (uint32_t i = 0; i < header->num_elements; ++i) { + if (!validate_params->validate_enum_func(elements[i], validation_context)) + return false; + } + return true; + } +}; + +template <typename T> +struct ArraySerializationHelper<T, false, true> { + using ElementType = typename ArrayDataTraits<T>::StorageType; + + static bool ValidateElements(const ArrayHeader* header, + const ElementType* elements, + ValidationContext* validation_context, + const ContainerValidateParams* validate_params) { + DCHECK(!validate_params->element_validate_params) + << "Handle or interface type should not have array validate params"; + + for (uint32_t i = 0; i < header->num_elements; ++i) { + if (!validate_params->element_is_nullable && + !IsHandleOrInterfaceValid(elements[i])) { + static const ValidationError kError = + std::is_same<T, Interface_Data>::value || + std::is_same<T, Handle_Data>::value + ? VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE + : VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID; + ReportValidationError( + validation_context, kError, + MakeMessageWithArrayIndex( + "invalid handle or interface ID in array expecting valid " + "handles or interface IDs", + header->num_elements, i) + .c_str()); + return false; + } + if (!ValidateHandleOrInterface(elements[i], validation_context)) + return false; + } + return true; + } +}; + +template <typename T> +struct ArraySerializationHelper<Pointer<T>, false, false> { + using ElementType = typename ArrayDataTraits<Pointer<T>>::StorageType; + + static bool ValidateElements(const ArrayHeader* header, + const ElementType* elements, + ValidationContext* validation_context, + const ContainerValidateParams* validate_params) { + for (uint32_t i = 0; i < header->num_elements; ++i) { + if (!validate_params->element_is_nullable && !elements[i].offset) { + ReportValidationError( + validation_context, + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + MakeMessageWithArrayIndex("null in array expecting valid pointers", + header->num_elements, + i).c_str()); + return false; + } + if (!ValidateCaller<T>::Run(elements[i], validation_context, + validate_params->element_validate_params)) { + return false; + } + } + return true; + } + + private: + template <typename U, + bool is_array_or_map = IsSpecializationOf<Array_Data, U>::value || + IsSpecializationOf<Map_Data, U>::value> + struct ValidateCaller { + static bool Run(const Pointer<U>& data, + ValidationContext* validation_context, + const ContainerValidateParams* validate_params) { + DCHECK(!validate_params) + << "Struct type should not have array validate params"; + + return ValidateStruct(data, validation_context); + } + }; + + template <typename U> + struct ValidateCaller<U, true> { + static bool Run(const Pointer<U>& data, + ValidationContext* validation_context, + const ContainerValidateParams* validate_params) { + return ValidateContainer(data, validation_context, validate_params); + } + }; +}; + +template <typename U> +struct ArraySerializationHelper<U, true, false> { + using ElementType = typename ArrayDataTraits<U>::StorageType; + + static bool ValidateElements(const ArrayHeader* header, + const ElementType* elements, + ValidationContext* validation_context, + const ContainerValidateParams* validate_params) { + for (uint32_t i = 0; i < header->num_elements; ++i) { + if (!validate_params->element_is_nullable && elements[i].is_null()) { + ReportValidationError( + validation_context, + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + MakeMessageWithArrayIndex("null in array expecting valid unions", + header->num_elements, i) + .c_str()); + return false; + } + if (!ValidateInlinedUnion(elements[i], validation_context)) + return false; + } + return true; + } +}; + +template <typename T> +class Array_Data { + public: + using Traits = ArrayDataTraits<T>; + using StorageType = typename Traits::StorageType; + using Ref = typename Traits::Ref; + using ConstRef = typename Traits::ConstRef; + using Helper = ArraySerializationHelper< + T, + IsUnionDataType<T>::value, + std::is_same<T, AssociatedInterface_Data>::value || + std::is_same<T, AssociatedEndpointHandle_Data>::value || + std::is_same<T, Interface_Data>::value || + std::is_same<T, Handle_Data>::value>; + using Element = T; + + // Returns null if |num_elements| or the corresponding storage size cannot be + // stored in uint32_t. + static Array_Data<T>* New(size_t num_elements, Buffer* buf) { + if (num_elements > Traits::kMaxNumElements) + return nullptr; + + uint32_t num_bytes = + Traits::GetStorageSize(static_cast<uint32_t>(num_elements)); + return new (buf->Allocate(num_bytes)) + Array_Data<T>(num_bytes, static_cast<uint32_t>(num_elements)); + } + + static bool Validate(const void* data, + ValidationContext* validation_context, + const ContainerValidateParams* validate_params) { + if (!data) + return true; + if (!IsAligned(data)) { + ReportValidationError(validation_context, + VALIDATION_ERROR_MISALIGNED_OBJECT); + return false; + } + if (!validation_context->IsValidRange(data, sizeof(ArrayHeader))) { + ReportValidationError(validation_context, + VALIDATION_ERROR_ILLEGAL_MEMORY_RANGE); + return false; + } + const ArrayHeader* header = static_cast<const ArrayHeader*>(data); + if (header->num_elements > Traits::kMaxNumElements || + header->num_bytes < Traits::GetStorageSize(header->num_elements)) { + ReportValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER); + return false; + } + if (validate_params->expected_num_elements != 0 && + header->num_elements != validate_params->expected_num_elements) { + ReportValidationError( + validation_context, + VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER, + MakeMessageWithExpectedArraySize( + "fixed-size array has wrong number of elements", + header->num_elements, + validate_params->expected_num_elements).c_str()); + return false; + } + if (!validation_context->ClaimMemory(data, header->num_bytes)) { + ReportValidationError(validation_context, + VALIDATION_ERROR_ILLEGAL_MEMORY_RANGE); + return false; + } + + const Array_Data<T>* object = static_cast<const Array_Data<T>*>(data); + return Helper::ValidateElements(&object->header_, object->storage(), + validation_context, validate_params); + } + + size_t size() const { return header_.num_elements; } + + Ref at(size_t offset) { + DCHECK(offset < static_cast<size_t>(header_.num_elements)); + return Traits::ToRef(storage(), offset); + } + + ConstRef at(size_t offset) const { + DCHECK(offset < static_cast<size_t>(header_.num_elements)); + return Traits::ToConstRef(storage(), offset); + } + + StorageType* storage() { + return reinterpret_cast<StorageType*>(reinterpret_cast<char*>(this) + + sizeof(*this)); + } + + const StorageType* storage() const { + return reinterpret_cast<const StorageType*>( + reinterpret_cast<const char*>(this) + sizeof(*this)); + } + + private: + Array_Data(uint32_t num_bytes, uint32_t num_elements) { + header_.num_bytes = num_bytes; + header_.num_elements = num_elements; + } + ~Array_Data() = delete; + + internal::ArrayHeader header_; + + // Elements of type internal::ArrayDataTraits<T>::StorageType follow. +}; +static_assert(sizeof(Array_Data<char>) == 8, "Bad sizeof(Array_Data)"); + +// UTF-8 encoded +using String_Data = Array_Data<char>; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_ARRAY_INTERNAL_H_ diff --git a/mojo/public/cpp/bindings/lib/array_serialization.h b/mojo/public/cpp/bindings/lib/array_serialization.h new file mode 100644 index 0000000000..d2f8ecfd72 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/array_serialization.h @@ -0,0 +1,555 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_ARRAY_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_ARRAY_SERIALIZATION_H_ + +#include <stddef.h> +#include <string.h> // For |memcpy()|. + +#include <limits> +#include <type_traits> +#include <utility> +#include <vector> + +#include "base/logging.h" +#include "mojo/public/cpp/bindings/array_data_view.h" +#include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_forward.h" +#include "mojo/public/cpp/bindings/lib/template_util.h" +#include "mojo/public/cpp/bindings/lib/validation_errors.h" + +namespace mojo { +namespace internal { + +template <typename Traits, + typename MaybeConstUserType, + bool HasGetBegin = + HasGetBeginMethod<Traits, MaybeConstUserType>::value> +class ArrayIterator {}; + +// Used as the UserTypeIterator template parameter of ArraySerializer. +template <typename Traits, typename MaybeConstUserType> +class ArrayIterator<Traits, MaybeConstUserType, true> { + public: + using IteratorType = decltype( + CallGetBeginIfExists<Traits>(std::declval<MaybeConstUserType&>())); + + explicit ArrayIterator(MaybeConstUserType& input) + : input_(input), iter_(CallGetBeginIfExists<Traits>(input)) {} + ~ArrayIterator() {} + + size_t GetSize() const { return Traits::GetSize(input_); } + + using GetNextResult = + decltype(Traits::GetValue(std::declval<IteratorType&>())); + GetNextResult GetNext() { + GetNextResult value = Traits::GetValue(iter_); + Traits::AdvanceIterator(iter_); + return value; + } + + using GetDataIfExistsResult = decltype( + CallGetDataIfExists<Traits>(std::declval<MaybeConstUserType&>())); + GetDataIfExistsResult GetDataIfExists() { + return CallGetDataIfExists<Traits>(input_); + } + + private: + MaybeConstUserType& input_; + IteratorType iter_; +}; + +// Used as the UserTypeIterator template parameter of ArraySerializer. +template <typename Traits, typename MaybeConstUserType> +class ArrayIterator<Traits, MaybeConstUserType, false> { + public: + explicit ArrayIterator(MaybeConstUserType& input) : input_(input), iter_(0) {} + ~ArrayIterator() {} + + size_t GetSize() const { return Traits::GetSize(input_); } + + using GetNextResult = + decltype(Traits::GetAt(std::declval<MaybeConstUserType&>(), 0)); + GetNextResult GetNext() { + DCHECK_LT(iter_, Traits::GetSize(input_)); + return Traits::GetAt(input_, iter_++); + } + + using GetDataIfExistsResult = decltype( + CallGetDataIfExists<Traits>(std::declval<MaybeConstUserType&>())); + GetDataIfExistsResult GetDataIfExists() { + return CallGetDataIfExists<Traits>(input_); + } + + private: + MaybeConstUserType& input_; + size_t iter_; +}; + +// ArraySerializer is also used to serialize map keys and values. Therefore, it +// has a UserTypeIterator parameter which is an adaptor for reading to hide the +// difference between ArrayTraits and MapTraits. +template <typename MojomType, + typename MaybeConstUserType, + typename UserTypeIterator, + typename EnableType = void> +struct ArraySerializer; + +// Handles serialization and deserialization of arrays of pod types. +template <typename MojomType, + typename MaybeConstUserType, + typename UserTypeIterator> +struct ArraySerializer< + MojomType, + MaybeConstUserType, + UserTypeIterator, + typename std::enable_if<BelongsTo<typename MojomType::Element, + MojomTypeCategory::POD>::value>::type> { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Data = typename MojomTypeTraits<MojomType>::Data; + using DataElement = typename Data::Element; + using Element = typename MojomType::Element; + using Traits = ArrayTraits<UserType>; + + static_assert(std::is_same<Element, DataElement>::value, + "Incorrect array serializer"); + static_assert(std::is_same<Element, typename Traits::Element>::value, + "Incorrect array serializer"); + + static size_t GetSerializedSize(UserTypeIterator* input, + SerializationContext* context) { + return sizeof(Data) + Align(input->GetSize() * sizeof(DataElement)); + } + + static void SerializeElements(UserTypeIterator* input, + Buffer* buf, + Data* output, + const ContainerValidateParams* validate_params, + SerializationContext* context) { + DCHECK(!validate_params->element_is_nullable) + << "Primitive type should be non-nullable"; + DCHECK(!validate_params->element_validate_params) + << "Primitive type should not have array validate params"; + + size_t size = input->GetSize(); + if (size == 0) + return; + + auto data = input->GetDataIfExists(); + if (data) { + memcpy(output->storage(), data, size * sizeof(DataElement)); + } else { + for (size_t i = 0; i < size; ++i) + output->at(i) = input->GetNext(); + } + } + + static bool DeserializeElements(Data* input, + UserType* output, + SerializationContext* context) { + if (!Traits::Resize(*output, input->size())) + return false; + ArrayIterator<Traits, UserType> iterator(*output); + if (input->size()) { + auto data = iterator.GetDataIfExists(); + if (data) { + memcpy(data, input->storage(), input->size() * sizeof(DataElement)); + } else { + for (size_t i = 0; i < input->size(); ++i) + iterator.GetNext() = input->at(i); + } + } + return true; + } +}; + +// Handles serialization and deserialization of arrays of enum types. +template <typename MojomType, + typename MaybeConstUserType, + typename UserTypeIterator> +struct ArraySerializer< + MojomType, + MaybeConstUserType, + UserTypeIterator, + typename std::enable_if<BelongsTo<typename MojomType::Element, + MojomTypeCategory::ENUM>::value>::type> { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Data = typename MojomTypeTraits<MojomType>::Data; + using DataElement = typename Data::Element; + using Element = typename MojomType::Element; + using Traits = ArrayTraits<UserType>; + + static_assert(sizeof(Element) == sizeof(DataElement), + "Incorrect array serializer"); + + static size_t GetSerializedSize(UserTypeIterator* input, + SerializationContext* context) { + return sizeof(Data) + Align(input->GetSize() * sizeof(DataElement)); + } + + static void SerializeElements(UserTypeIterator* input, + Buffer* buf, + Data* output, + const ContainerValidateParams* validate_params, + SerializationContext* context) { + DCHECK(!validate_params->element_is_nullable) + << "Primitive type should be non-nullable"; + DCHECK(!validate_params->element_validate_params) + << "Primitive type should not have array validate params"; + + size_t size = input->GetSize(); + for (size_t i = 0; i < size; ++i) + Serialize<Element>(input->GetNext(), output->storage() + i); + } + + static bool DeserializeElements(Data* input, + UserType* output, + SerializationContext* context) { + if (!Traits::Resize(*output, input->size())) + return false; + ArrayIterator<Traits, UserType> iterator(*output); + for (size_t i = 0; i < input->size(); ++i) { + if (!Deserialize<Element>(input->at(i), &iterator.GetNext())) + return false; + } + return true; + } +}; + +// Serializes and deserializes arrays of bools. +template <typename MojomType, + typename MaybeConstUserType, + typename UserTypeIterator> +struct ArraySerializer<MojomType, + MaybeConstUserType, + UserTypeIterator, + typename std::enable_if<BelongsTo< + typename MojomType::Element, + MojomTypeCategory::BOOLEAN>::value>::type> { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Traits = ArrayTraits<UserType>; + using Data = typename MojomTypeTraits<MojomType>::Data; + + static_assert(std::is_same<bool, typename Traits::Element>::value, + "Incorrect array serializer"); + + static size_t GetSerializedSize(UserTypeIterator* input, + SerializationContext* context) { + return sizeof(Data) + Align((input->GetSize() + 7) / 8); + } + + static void SerializeElements(UserTypeIterator* input, + Buffer* buf, + Data* output, + const ContainerValidateParams* validate_params, + SerializationContext* context) { + DCHECK(!validate_params->element_is_nullable) + << "Primitive type should be non-nullable"; + DCHECK(!validate_params->element_validate_params) + << "Primitive type should not have array validate params"; + + size_t size = input->GetSize(); + for (size_t i = 0; i < size; ++i) + output->at(i) = input->GetNext(); + } + static bool DeserializeElements(Data* input, + UserType* output, + SerializationContext* context) { + if (!Traits::Resize(*output, input->size())) + return false; + ArrayIterator<Traits, UserType> iterator(*output); + for (size_t i = 0; i < input->size(); ++i) + iterator.GetNext() = input->at(i); + return true; + } +}; + +// Serializes and deserializes arrays of handles or interfaces. +template <typename MojomType, + typename MaybeConstUserType, + typename UserTypeIterator> +struct ArraySerializer< + MojomType, + MaybeConstUserType, + UserTypeIterator, + typename std::enable_if< + BelongsTo<typename MojomType::Element, + MojomTypeCategory::ASSOCIATED_INTERFACE | + MojomTypeCategory::ASSOCIATED_INTERFACE_REQUEST | + MojomTypeCategory::HANDLE | + MojomTypeCategory::INTERFACE | + MojomTypeCategory::INTERFACE_REQUEST>::value>::type> { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Data = typename MojomTypeTraits<MojomType>::Data; + using Element = typename MojomType::Element; + using Traits = ArrayTraits<UserType>; + + static size_t GetSerializedSize(UserTypeIterator* input, + SerializationContext* context) { + size_t element_count = input->GetSize(); + if (BelongsTo<Element, + MojomTypeCategory::ASSOCIATED_INTERFACE | + MojomTypeCategory::ASSOCIATED_INTERFACE_REQUEST>::value) { + for (size_t i = 0; i < element_count; ++i) { + typename UserTypeIterator::GetNextResult next = input->GetNext(); + size_t size = PrepareToSerialize<Element>(next, context); + DCHECK_EQ(size, 0u); + } + } + return sizeof(Data) + Align(element_count * sizeof(typename Data::Element)); + } + + static void SerializeElements(UserTypeIterator* input, + Buffer* buf, + Data* output, + const ContainerValidateParams* validate_params, + SerializationContext* context) { + DCHECK(!validate_params->element_validate_params) + << "Handle or interface type should not have array validate params"; + + size_t size = input->GetSize(); + for (size_t i = 0; i < size; ++i) { + typename UserTypeIterator::GetNextResult next = input->GetNext(); + Serialize<Element>(next, &output->at(i), context); + + static const ValidationError kError = + BelongsTo<Element, + MojomTypeCategory::ASSOCIATED_INTERFACE | + MojomTypeCategory::ASSOCIATED_INTERFACE_REQUEST>::value + ? VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID + : VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE; + MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( + !validate_params->element_is_nullable && + !IsHandleOrInterfaceValid(output->at(i)), + kError, + MakeMessageWithArrayIndex("invalid handle or interface ID in array " + "expecting valid handles or interface IDs", + size, i)); + } + } + static bool DeserializeElements(Data* input, + UserType* output, + SerializationContext* context) { + if (!Traits::Resize(*output, input->size())) + return false; + ArrayIterator<Traits, UserType> iterator(*output); + for (size_t i = 0; i < input->size(); ++i) { + bool result = + Deserialize<Element>(&input->at(i), &iterator.GetNext(), context); + DCHECK(result); + } + return true; + } +}; + +// This template must only apply to pointer mojo entity (strings, structs, +// arrays and maps). +template <typename MojomType, + typename MaybeConstUserType, + typename UserTypeIterator> +struct ArraySerializer<MojomType, + MaybeConstUserType, + UserTypeIterator, + typename std::enable_if<BelongsTo< + typename MojomType::Element, + MojomTypeCategory::ARRAY | MojomTypeCategory::MAP | + MojomTypeCategory::STRING | + MojomTypeCategory::STRUCT>::value>::type> { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Data = typename MojomTypeTraits<MojomType>::Data; + using Element = typename MojomType::Element; + using DataElementPtr = typename MojomTypeTraits<Element>::Data*; + using Traits = ArrayTraits<UserType>; + + static size_t GetSerializedSize(UserTypeIterator* input, + SerializationContext* context) { + size_t element_count = input->GetSize(); + size_t size = sizeof(Data) + element_count * sizeof(typename Data::Element); + for (size_t i = 0; i < element_count; ++i) { + typename UserTypeIterator::GetNextResult next = input->GetNext(); + size += PrepareToSerialize<Element>(next, context); + } + return size; + } + + static void SerializeElements(UserTypeIterator* input, + Buffer* buf, + Data* output, + const ContainerValidateParams* validate_params, + SerializationContext* context) { + size_t size = input->GetSize(); + for (size_t i = 0; i < size; ++i) { + DataElementPtr data_ptr; + typename UserTypeIterator::GetNextResult next = input->GetNext(); + SerializeCaller<Element>::Run(next, buf, &data_ptr, + validate_params->element_validate_params, + context); + output->at(i).Set(data_ptr); + MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( + !validate_params->element_is_nullable && !data_ptr, + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + MakeMessageWithArrayIndex("null in array expecting valid pointers", + size, i)); + } + } + static bool DeserializeElements(Data* input, + UserType* output, + SerializationContext* context) { + if (!Traits::Resize(*output, input->size())) + return false; + ArrayIterator<Traits, UserType> iterator(*output); + for (size_t i = 0; i < input->size(); ++i) { + if (!Deserialize<Element>(input->at(i).Get(), &iterator.GetNext(), + context)) + return false; + } + return true; + } + + private: + template <typename T, + bool is_array_or_map = BelongsTo<T, + MojomTypeCategory::ARRAY | + MojomTypeCategory::MAP>::value> + struct SerializeCaller { + template <typename InputElementType> + static void Run(InputElementType&& input, + Buffer* buf, + DataElementPtr* output, + const ContainerValidateParams* validate_params, + SerializationContext* context) { + Serialize<T>(std::forward<InputElementType>(input), buf, output, context); + } + }; + + template <typename T> + struct SerializeCaller<T, true> { + template <typename InputElementType> + static void Run(InputElementType&& input, + Buffer* buf, + DataElementPtr* output, + const ContainerValidateParams* validate_params, + SerializationContext* context) { + Serialize<T>(std::forward<InputElementType>(input), buf, output, + validate_params, context); + } + }; +}; + +// Handles serialization and deserialization of arrays of unions. +template <typename MojomType, + typename MaybeConstUserType, + typename UserTypeIterator> +struct ArraySerializer< + MojomType, + MaybeConstUserType, + UserTypeIterator, + typename std::enable_if<BelongsTo<typename MojomType::Element, + MojomTypeCategory::UNION>::value>::type> { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Data = typename MojomTypeTraits<MojomType>::Data; + using Element = typename MojomType::Element; + using Traits = ArrayTraits<UserType>; + + static size_t GetSerializedSize(UserTypeIterator* input, + SerializationContext* context) { + size_t element_count = input->GetSize(); + size_t size = sizeof(Data); + for (size_t i = 0; i < element_count; ++i) { + // Call with |inlined| set to false, so that it will account for both the + // data in the union and the space in the array used to hold the union. + typename UserTypeIterator::GetNextResult next = input->GetNext(); + size += PrepareToSerialize<Element>(next, false, context); + } + return size; + } + + static void SerializeElements(UserTypeIterator* input, + Buffer* buf, + Data* output, + const ContainerValidateParams* validate_params, + SerializationContext* context) { + size_t size = input->GetSize(); + for (size_t i = 0; i < size; ++i) { + typename Data::Element* result = output->storage() + i; + typename UserTypeIterator::GetNextResult next = input->GetNext(); + Serialize<Element>(next, buf, &result, true, context); + MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( + !validate_params->element_is_nullable && output->at(i).is_null(), + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + MakeMessageWithArrayIndex("null in array expecting valid unions", + size, i)); + } + } + + static bool DeserializeElements(Data* input, + UserType* output, + SerializationContext* context) { + if (!Traits::Resize(*output, input->size())) + return false; + ArrayIterator<Traits, UserType> iterator(*output); + for (size_t i = 0; i < input->size(); ++i) { + if (!Deserialize<Element>(&input->at(i), &iterator.GetNext(), context)) + return false; + } + return true; + } +}; + +template <typename Element, typename MaybeConstUserType> +struct Serializer<ArrayDataView<Element>, MaybeConstUserType> { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Traits = ArrayTraits<UserType>; + using Impl = ArraySerializer<ArrayDataView<Element>, + MaybeConstUserType, + ArrayIterator<Traits, MaybeConstUserType>>; + using Data = typename MojomTypeTraits<ArrayDataView<Element>>::Data; + + static size_t PrepareToSerialize(MaybeConstUserType& input, + SerializationContext* context) { + if (CallIsNullIfExists<Traits>(input)) + return 0; + ArrayIterator<Traits, MaybeConstUserType> iterator(input); + return Impl::GetSerializedSize(&iterator, context); + } + + static void Serialize(MaybeConstUserType& input, + Buffer* buf, + Data** output, + const ContainerValidateParams* validate_params, + SerializationContext* context) { + if (!CallIsNullIfExists<Traits>(input)) { + MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( + validate_params->expected_num_elements != 0 && + Traits::GetSize(input) != validate_params->expected_num_elements, + internal::VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER, + internal::MakeMessageWithExpectedArraySize( + "fixed-size array has wrong number of elements", + Traits::GetSize(input), validate_params->expected_num_elements)); + Data* result = Data::New(Traits::GetSize(input), buf); + if (result) { + ArrayIterator<Traits, MaybeConstUserType> iterator(input); + Impl::SerializeElements(&iterator, buf, result, validate_params, + context); + } + *output = result; + } else { + *output = nullptr; + } + } + + static bool Deserialize(Data* input, + UserType* output, + SerializationContext* context) { + if (!input) + return CallSetToNullIfExists<Traits>(output); + return Impl::DeserializeElements(input, output, context); + } +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_ARRAY_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/associated_binding.cc b/mojo/public/cpp/bindings/lib/associated_binding.cc new file mode 100644 index 0000000000..6788e68e07 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/associated_binding.cc @@ -0,0 +1,62 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/associated_binding.h" + +namespace mojo { + +AssociatedBindingBase::AssociatedBindingBase() {} + +AssociatedBindingBase::~AssociatedBindingBase() {} + +void AssociatedBindingBase::AddFilter(std::unique_ptr<MessageReceiver> filter) { + DCHECK(endpoint_client_); + endpoint_client_->AddFilter(std::move(filter)); +} + +void AssociatedBindingBase::Close() { + endpoint_client_.reset(); +} + +void AssociatedBindingBase::CloseWithReason(uint32_t custom_reason, + const std::string& description) { + if (endpoint_client_) + endpoint_client_->CloseWithReason(custom_reason, description); + Close(); +} + +void AssociatedBindingBase::set_connection_error_handler( + const base::Closure& error_handler) { + DCHECK(is_bound()); + endpoint_client_->set_connection_error_handler(error_handler); +} + +void AssociatedBindingBase::set_connection_error_with_reason_handler( + const ConnectionErrorWithReasonCallback& error_handler) { + DCHECK(is_bound()); + endpoint_client_->set_connection_error_with_reason_handler(error_handler); +} + +void AssociatedBindingBase::FlushForTesting() { + endpoint_client_->FlushForTesting(); +} + +void AssociatedBindingBase::BindImpl( + ScopedInterfaceEndpointHandle handle, + MessageReceiverWithResponderStatus* receiver, + std::unique_ptr<MessageReceiver> payload_validator, + bool expect_sync_requests, + scoped_refptr<base::SingleThreadTaskRunner> runner, + uint32_t interface_version) { + if (!handle.is_valid()) { + endpoint_client_.reset(); + return; + } + + endpoint_client_.reset(new InterfaceEndpointClient( + std::move(handle), receiver, std::move(payload_validator), + expect_sync_requests, std::move(runner), interface_version)); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/associated_group.cc b/mojo/public/cpp/bindings/lib/associated_group.cc new file mode 100644 index 0000000000..3e95eeb027 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/associated_group.cc @@ -0,0 +1,34 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/associated_group.h" + +#include "mojo/public/cpp/bindings/associated_group_controller.h" + +namespace mojo { + +AssociatedGroup::AssociatedGroup() = default; + +AssociatedGroup::AssociatedGroup( + scoped_refptr<AssociatedGroupController> controller) + : controller_(std::move(controller)) {} + +AssociatedGroup::AssociatedGroup(const ScopedInterfaceEndpointHandle& handle) + : controller_getter_(handle.CreateGroupControllerGetter()) {} + +AssociatedGroup::AssociatedGroup(const AssociatedGroup& other) = default; + +AssociatedGroup::~AssociatedGroup() = default; + +AssociatedGroup& AssociatedGroup::operator=(const AssociatedGroup& other) = + default; + +AssociatedGroupController* AssociatedGroup::GetController() { + if (controller_) + return controller_.get(); + + return controller_getter_.Run(); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/associated_group_controller.cc b/mojo/public/cpp/bindings/lib/associated_group_controller.cc new file mode 100644 index 0000000000..f4a9aa2852 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/associated_group_controller.cc @@ -0,0 +1,24 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/associated_group_controller.h" + +#include "mojo/public/cpp/bindings/associated_group.h" + +namespace mojo { + +AssociatedGroupController::~AssociatedGroupController() {} + +ScopedInterfaceEndpointHandle +AssociatedGroupController::CreateScopedInterfaceEndpointHandle(InterfaceId id) { + return ScopedInterfaceEndpointHandle(id, this); +} + +bool AssociatedGroupController::NotifyAssociation( + ScopedInterfaceEndpointHandle* handle_to_send, + InterfaceId id) { + return handle_to_send->NotifyAssociation(id, this); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc b/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc new file mode 100644 index 0000000000..78281eda9a --- /dev/null +++ b/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc @@ -0,0 +1,18 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/associated_interface_ptr.h" + +namespace mojo { + +void GetIsolatedInterface(ScopedInterfaceEndpointHandle handle) { + MessagePipe pipe; + scoped_refptr<internal::MultiplexRouter> router = + new internal::MultiplexRouter(std::move(pipe.handle0), + internal::MultiplexRouter::MULTI_INTERFACE, + false, base::ThreadTaskRunnerHandle::Get()); + router->AssociateInterface(std::move(handle)); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h b/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h new file mode 100644 index 0000000000..a4b51882d2 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h @@ -0,0 +1,157 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_ASSOCIATED_INTERFACE_PTR_STATE_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_ASSOCIATED_INTERFACE_PTR_STATE_H_ + +#include <stdint.h> + +#include <algorithm> // For |std::swap()|. +#include <memory> +#include <string> +#include <utility> + +#include "base/bind.h" +#include "base/callback_forward.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" +#include "base/memory/ref_counted.h" +#include "base/single_thread_task_runner.h" +#include "mojo/public/cpp/bindings/associated_group.h" +#include "mojo/public/cpp/bindings/associated_interface_ptr_info.h" +#include "mojo/public/cpp/bindings/connection_error_callback.h" +#include "mojo/public/cpp/bindings/interface_endpoint_client.h" +#include "mojo/public/cpp/bindings/interface_id.h" +#include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" +#include "mojo/public/cpp/system/message_pipe.h" + +namespace mojo { +namespace internal { + +template <typename Interface> +class AssociatedInterfacePtrState { + public: + AssociatedInterfacePtrState() : version_(0u) {} + + ~AssociatedInterfacePtrState() { + endpoint_client_.reset(); + proxy_.reset(); + } + + Interface* instance() { + // This will be null if the object is not bound. + return proxy_.get(); + } + + uint32_t version() const { return version_; } + + void QueryVersion(const base::Callback<void(uint32_t)>& callback) { + // It is safe to capture |this| because the callback won't be run after this + // object goes away. + endpoint_client_->QueryVersion( + base::Bind(&AssociatedInterfacePtrState::OnQueryVersion, + base::Unretained(this), callback)); + } + + void RequireVersion(uint32_t version) { + if (version <= version_) + return; + + version_ = version; + endpoint_client_->RequireVersion(version); + } + + void FlushForTesting() { endpoint_client_->FlushForTesting(); } + + void CloseWithReason(uint32_t custom_reason, const std::string& description) { + endpoint_client_->CloseWithReason(custom_reason, description); + } + + void Swap(AssociatedInterfacePtrState* other) { + using std::swap; + swap(other->endpoint_client_, endpoint_client_); + swap(other->proxy_, proxy_); + swap(other->version_, version_); + } + + void Bind(AssociatedInterfacePtrInfo<Interface> info, + scoped_refptr<base::SingleThreadTaskRunner> runner) { + DCHECK(!endpoint_client_); + DCHECK(!proxy_); + DCHECK_EQ(0u, version_); + DCHECK(info.is_valid()); + + version_ = info.version(); + // The version is only queried from the client so the value passed here + // will not be used. + endpoint_client_.reset(new InterfaceEndpointClient( + info.PassHandle(), nullptr, + base::WrapUnique(new typename Interface::ResponseValidator_()), false, + std::move(runner), 0u)); + proxy_.reset(new Proxy(endpoint_client_.get())); + } + + // After this method is called, the object is in an invalid state and + // shouldn't be reused. + AssociatedInterfacePtrInfo<Interface> PassInterface() { + ScopedInterfaceEndpointHandle handle = endpoint_client_->PassHandle(); + endpoint_client_.reset(); + proxy_.reset(); + return AssociatedInterfacePtrInfo<Interface>(std::move(handle), version_); + } + + bool is_bound() const { return !!endpoint_client_; } + + bool encountered_error() const { + return endpoint_client_ ? endpoint_client_->encountered_error() : false; + } + + void set_connection_error_handler(const base::Closure& error_handler) { + DCHECK(endpoint_client_); + endpoint_client_->set_connection_error_handler(error_handler); + } + + void set_connection_error_with_reason_handler( + const ConnectionErrorWithReasonCallback& error_handler) { + DCHECK(endpoint_client_); + endpoint_client_->set_connection_error_with_reason_handler(error_handler); + } + + // Returns true if bound and awaiting a response to a message. + bool has_pending_callbacks() const { + return endpoint_client_ && endpoint_client_->has_pending_responders(); + } + + AssociatedGroup* associated_group() { + return endpoint_client_ ? endpoint_client_->associated_group() : nullptr; + } + + void ForwardMessage(Message message) { endpoint_client_->Accept(&message); } + + void ForwardMessageWithResponder(Message message, + std::unique_ptr<MessageReceiver> responder) { + endpoint_client_->AcceptWithResponder(&message, std::move(responder)); + } + + private: + using Proxy = typename Interface::Proxy_; + + void OnQueryVersion(const base::Callback<void(uint32_t)>& callback, + uint32_t version) { + version_ = version; + callback.Run(version); + } + + std::unique_ptr<InterfaceEndpointClient> endpoint_client_; + std::unique_ptr<Proxy> proxy_; + + uint32_t version_; + + DISALLOW_COPY_AND_ASSIGN(AssociatedInterfacePtrState); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_ASSOCIATED_INTERFACE_PTR_STATE_H_ diff --git a/mojo/public/cpp/bindings/lib/binding_state.cc b/mojo/public/cpp/bindings/lib/binding_state.cc new file mode 100644 index 0000000000..b34cb47e28 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/binding_state.cc @@ -0,0 +1,90 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/binding_state.h" + +namespace mojo { +namespace internal { + +BindingStateBase::BindingStateBase() = default; + +BindingStateBase::~BindingStateBase() = default; + +void BindingStateBase::AddFilter(std::unique_ptr<MessageReceiver> filter) { + DCHECK(endpoint_client_); + endpoint_client_->AddFilter(std::move(filter)); +} + +bool BindingStateBase::HasAssociatedInterfaces() const { + return router_ ? router_->HasAssociatedEndpoints() : false; +} + +void BindingStateBase::PauseIncomingMethodCallProcessing() { + DCHECK(router_); + router_->PauseIncomingMethodCallProcessing(); +} +void BindingStateBase::ResumeIncomingMethodCallProcessing() { + DCHECK(router_); + router_->ResumeIncomingMethodCallProcessing(); +} + +bool BindingStateBase::WaitForIncomingMethodCall(MojoDeadline deadline) { + DCHECK(router_); + return router_->WaitForIncomingMessage(deadline); +} + +void BindingStateBase::Close() { + if (!router_) + return; + + endpoint_client_.reset(); + router_->CloseMessagePipe(); + router_ = nullptr; +} + +void BindingStateBase::CloseWithReason(uint32_t custom_reason, + const std::string& description) { + if (endpoint_client_) + endpoint_client_->CloseWithReason(custom_reason, description); + + Close(); +} + +void BindingStateBase::FlushForTesting() { + endpoint_client_->FlushForTesting(); +} + +void BindingStateBase::EnableTestingMode() { + DCHECK(is_bound()); + router_->EnableTestingMode(); +} + +void BindingStateBase::BindInternal( + ScopedMessagePipeHandle handle, + scoped_refptr<base::SingleThreadTaskRunner> runner, + const char* interface_name, + std::unique_ptr<MessageReceiver> request_validator, + bool passes_associated_kinds, + bool has_sync_methods, + MessageReceiverWithResponderStatus* stub, + uint32_t interface_version) { + DCHECK(!router_); + + MultiplexRouter::Config config = + passes_associated_kinds + ? MultiplexRouter::MULTI_INTERFACE + : (has_sync_methods + ? MultiplexRouter::SINGLE_INTERFACE_WITH_SYNC_METHODS + : MultiplexRouter::SINGLE_INTERFACE); + router_ = new MultiplexRouter(std::move(handle), config, false, runner); + router_->SetMasterInterfaceName(interface_name); + + endpoint_client_.reset(new InterfaceEndpointClient( + router_->CreateLocalEndpointHandle(kMasterInterfaceId), stub, + std::move(request_validator), has_sync_methods, std::move(runner), + interface_version)); +} + +} // namesapce internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/binding_state.h b/mojo/public/cpp/bindings/lib/binding_state.h new file mode 100644 index 0000000000..0b0dbee002 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/binding_state.h @@ -0,0 +1,128 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_BINDING_STATE_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_BINDING_STATE_H_ + +#include <memory> +#include <string> +#include <utility> + +#include "base/bind.h" +#include "base/callback.h" +#include "base/logging.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" +#include "base/memory/ref_counted.h" +#include "base/single_thread_task_runner.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/connection_error_callback.h" +#include "mojo/public/cpp/bindings/filter_chain.h" +#include "mojo/public/cpp/bindings/interface_endpoint_client.h" +#include "mojo/public/cpp/bindings/interface_id.h" +#include "mojo/public/cpp/bindings/interface_ptr.h" +#include "mojo/public/cpp/bindings/interface_ptr_info.h" +#include "mojo/public/cpp/bindings/interface_request.h" +#include "mojo/public/cpp/bindings/lib/multiplex_router.h" +#include "mojo/public/cpp/bindings/message_header_validator.h" +#include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" +#include "mojo/public/cpp/system/core.h" + +namespace mojo { +namespace internal { + +class MOJO_CPP_BINDINGS_EXPORT BindingStateBase { + public: + BindingStateBase(); + ~BindingStateBase(); + + void AddFilter(std::unique_ptr<MessageReceiver> filter); + + bool HasAssociatedInterfaces() const; + + void PauseIncomingMethodCallProcessing(); + void ResumeIncomingMethodCallProcessing(); + + bool WaitForIncomingMethodCall( + MojoDeadline deadline = MOJO_DEADLINE_INDEFINITE); + + void Close(); + void CloseWithReason(uint32_t custom_reason, const std::string& description); + + void set_connection_error_handler(const base::Closure& error_handler) { + DCHECK(is_bound()); + endpoint_client_->set_connection_error_handler(error_handler); + } + + void set_connection_error_with_reason_handler( + const ConnectionErrorWithReasonCallback& error_handler) { + DCHECK(is_bound()); + endpoint_client_->set_connection_error_with_reason_handler(error_handler); + } + + bool is_bound() const { return !!router_; } + + MessagePipeHandle handle() const { + DCHECK(is_bound()); + return router_->handle(); + } + + void FlushForTesting(); + + void EnableTestingMode(); + + protected: + void BindInternal(ScopedMessagePipeHandle handle, + scoped_refptr<base::SingleThreadTaskRunner> runner, + const char* interface_name, + std::unique_ptr<MessageReceiver> request_validator, + bool passes_associated_kinds, + bool has_sync_methods, + MessageReceiverWithResponderStatus* stub, + uint32_t interface_version); + + scoped_refptr<internal::MultiplexRouter> router_; + std::unique_ptr<InterfaceEndpointClient> endpoint_client_; +}; + +template <typename Interface, typename ImplRefTraits> +class BindingState : public BindingStateBase { + public: + using ImplPointerType = typename ImplRefTraits::PointerType; + + explicit BindingState(ImplPointerType impl) { + stub_.set_sink(std::move(impl)); + } + + ~BindingState() { Close(); } + + void Bind(ScopedMessagePipeHandle handle, + scoped_refptr<base::SingleThreadTaskRunner> runner) { + BindingStateBase::BindInternal( + std::move(handle), runner, Interface::Name_, + base::MakeUnique<typename Interface::RequestValidator_>(), + Interface::PassesAssociatedKinds_, Interface::HasSyncMethods_, &stub_, + Interface::Version_); + } + + InterfaceRequest<Interface> Unbind() { + endpoint_client_.reset(); + InterfaceRequest<Interface> request = + MakeRequest<Interface>(router_->PassMessagePipe()); + router_ = nullptr; + return request; + } + + Interface* impl() { return ImplRefTraits::GetRawPointer(&stub_.sink()); } + + private: + typename Interface::template Stub_<ImplRefTraits> stub_; + + DISALLOW_COPY_AND_ASSIGN(BindingState); +}; + +} // namesapce internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_BINDING_STATE_H_ diff --git a/mojo/public/cpp/bindings/lib/bindings_internal.h b/mojo/public/cpp/bindings/lib/bindings_internal.h new file mode 100644 index 0000000000..631daec392 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/bindings_internal.h @@ -0,0 +1,336 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_BINDINGS_INTERNAL_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_BINDINGS_INTERNAL_H_ + +#include <stdint.h> + +#include <functional> + +#include "base/template_util.h" +#include "mojo/public/cpp/bindings/interface_id.h" +#include "mojo/public/cpp/bindings/lib/template_util.h" +#include "mojo/public/cpp/system/core.h" + +namespace mojo { + +template <typename T> +class ArrayDataView; + +template <typename T> +class AssociatedInterfacePtrInfoDataView; + +template <typename T> +class AssociatedInterfaceRequestDataView; + +template <typename T> +class InterfacePtrDataView; + +template <typename T> +class InterfaceRequestDataView; + +template <typename K, typename V> +class MapDataView; + +class NativeStructDataView; + +class StringDataView; + +namespace internal { + +// Please note that this is a different value than |mojo::kInvalidHandleValue|, +// which is the "decoded" invalid handle. +const uint32_t kEncodedInvalidHandleValue = static_cast<uint32_t>(-1); + +// A serialized union always takes 16 bytes: +// 4-byte size + 4-byte tag + 8-byte payload. +const uint32_t kUnionDataSize = 16; + +template <typename T> +class Array_Data; + +template <typename K, typename V> +class Map_Data; + +class NativeStruct_Data; + +using String_Data = Array_Data<char>; + +inline size_t Align(size_t size) { + return (size + 7) & ~0x7; +} + +inline bool IsAligned(const void* ptr) { + return !(reinterpret_cast<uintptr_t>(ptr) & 0x7); +} + +// Pointers are encoded as relative offsets. The offsets are relative to the +// address of where the offset value is stored, such that the pointer may be +// recovered with the expression: +// +// ptr = reinterpret_cast<char*>(offset) + *offset +// +// A null pointer is encoded as an offset value of 0. +// +inline void EncodePointer(const void* ptr, uint64_t* offset) { + if (!ptr) { + *offset = 0; + return; + } + + const char* p_obj = reinterpret_cast<const char*>(ptr); + const char* p_slot = reinterpret_cast<const char*>(offset); + DCHECK(p_obj > p_slot); + + *offset = static_cast<uint64_t>(p_obj - p_slot); +} + +// Note: This function doesn't validate the encoded pointer value. +inline const void* DecodePointer(const uint64_t* offset) { + if (!*offset) + return nullptr; + return reinterpret_cast<const char*>(offset) + *offset; +} + +#pragma pack(push, 1) + +struct StructHeader { + uint32_t num_bytes; + uint32_t version; +}; +static_assert(sizeof(StructHeader) == 8, "Bad sizeof(StructHeader)"); + +struct ArrayHeader { + uint32_t num_bytes; + uint32_t num_elements; +}; +static_assert(sizeof(ArrayHeader) == 8, "Bad_sizeof(ArrayHeader)"); + +template <typename T> +struct Pointer { + using BaseType = T; + + void Set(T* ptr) { EncodePointer(ptr, &offset); } + const T* Get() const { return static_cast<const T*>(DecodePointer(&offset)); } + T* Get() { + return static_cast<T*>(const_cast<void*>(DecodePointer(&offset))); + } + + bool is_null() const { return offset == 0; } + + uint64_t offset; +}; +static_assert(sizeof(Pointer<char>) == 8, "Bad_sizeof(Pointer)"); + +using GenericPointer = Pointer<void>; + +struct Handle_Data { + Handle_Data() = default; + explicit Handle_Data(uint32_t value) : value(value) {} + + bool is_valid() const { return value != kEncodedInvalidHandleValue; } + + uint32_t value; +}; +static_assert(sizeof(Handle_Data) == 4, "Bad_sizeof(Handle_Data)"); + +struct Interface_Data { + Handle_Data handle; + uint32_t version; +}; +static_assert(sizeof(Interface_Data) == 8, "Bad_sizeof(Interface_Data)"); + +struct AssociatedEndpointHandle_Data { + AssociatedEndpointHandle_Data() = default; + explicit AssociatedEndpointHandle_Data(uint32_t value) : value(value) {} + + bool is_valid() const { return value != kEncodedInvalidHandleValue; } + + uint32_t value; +}; +static_assert(sizeof(AssociatedEndpointHandle_Data) == 4, + "Bad_sizeof(AssociatedEndpointHandle_Data)"); + +struct AssociatedInterface_Data { + AssociatedEndpointHandle_Data handle; + uint32_t version; +}; +static_assert(sizeof(AssociatedInterface_Data) == 8, + "Bad_sizeof(AssociatedInterface_Data)"); + +#pragma pack(pop) + +template <typename T> +T FetchAndReset(T* ptr) { + T temp = *ptr; + *ptr = T(); + return temp; +} + +template <typename T> +struct IsUnionDataType { + private: + template <typename U> + static YesType Test(const typename U::MojomUnionDataType*); + + template <typename U> + static NoType Test(...); + + EnsureTypeIsComplete<T> check_t_; + + public: + static const bool value = + sizeof(Test<T>(0)) == sizeof(YesType) && !IsConst<T>::value; +}; + +enum class MojomTypeCategory : uint32_t { + ARRAY = 1 << 0, + ASSOCIATED_INTERFACE = 1 << 1, + ASSOCIATED_INTERFACE_REQUEST = 1 << 2, + BOOLEAN = 1 << 3, + ENUM = 1 << 4, + HANDLE = 1 << 5, + INTERFACE = 1 << 6, + INTERFACE_REQUEST = 1 << 7, + MAP = 1 << 8, + // POD except boolean and enum. + POD = 1 << 9, + STRING = 1 << 10, + STRUCT = 1 << 11, + UNION = 1 << 12 +}; + +inline constexpr MojomTypeCategory operator&(MojomTypeCategory x, + MojomTypeCategory y) { + return static_cast<MojomTypeCategory>(static_cast<uint32_t>(x) & + static_cast<uint32_t>(y)); +} + +inline constexpr MojomTypeCategory operator|(MojomTypeCategory x, + MojomTypeCategory y) { + return static_cast<MojomTypeCategory>(static_cast<uint32_t>(x) | + static_cast<uint32_t>(y)); +} + +template <typename T, bool is_enum = std::is_enum<T>::value> +struct MojomTypeTraits { + using Data = T; + using DataAsArrayElement = Data; + + static const MojomTypeCategory category = MojomTypeCategory::POD; +}; + +template <typename T> +struct MojomTypeTraits<ArrayDataView<T>, false> { + using Data = Array_Data<typename MojomTypeTraits<T>::DataAsArrayElement>; + using DataAsArrayElement = Pointer<Data>; + + static const MojomTypeCategory category = MojomTypeCategory::ARRAY; +}; + +template <typename T> +struct MojomTypeTraits<AssociatedInterfacePtrInfoDataView<T>, false> { + using Data = AssociatedInterface_Data; + using DataAsArrayElement = Data; + + static const MojomTypeCategory category = + MojomTypeCategory::ASSOCIATED_INTERFACE; +}; + +template <typename T> +struct MojomTypeTraits<AssociatedInterfaceRequestDataView<T>, false> { + using Data = AssociatedEndpointHandle_Data; + using DataAsArrayElement = Data; + + static const MojomTypeCategory category = + MojomTypeCategory::ASSOCIATED_INTERFACE_REQUEST; +}; + +template <> +struct MojomTypeTraits<bool, false> { + using Data = bool; + using DataAsArrayElement = Data; + + static const MojomTypeCategory category = MojomTypeCategory::BOOLEAN; +}; + +template <typename T> +struct MojomTypeTraits<T, true> { + using Data = int32_t; + using DataAsArrayElement = Data; + + static const MojomTypeCategory category = MojomTypeCategory::ENUM; +}; + +template <typename T> +struct MojomTypeTraits<ScopedHandleBase<T>, false> { + using Data = Handle_Data; + using DataAsArrayElement = Data; + + static const MojomTypeCategory category = MojomTypeCategory::HANDLE; +}; + +template <typename T> +struct MojomTypeTraits<InterfacePtrDataView<T>, false> { + using Data = Interface_Data; + using DataAsArrayElement = Data; + + static const MojomTypeCategory category = MojomTypeCategory::INTERFACE; +}; + +template <typename T> +struct MojomTypeTraits<InterfaceRequestDataView<T>, false> { + using Data = Handle_Data; + using DataAsArrayElement = Data; + + static const MojomTypeCategory category = + MojomTypeCategory::INTERFACE_REQUEST; +}; + +template <typename K, typename V> +struct MojomTypeTraits<MapDataView<K, V>, false> { + using Data = Map_Data<typename MojomTypeTraits<K>::DataAsArrayElement, + typename MojomTypeTraits<V>::DataAsArrayElement>; + using DataAsArrayElement = Pointer<Data>; + + static const MojomTypeCategory category = MojomTypeCategory::MAP; +}; + +template <> +struct MojomTypeTraits<NativeStructDataView, false> { + using Data = internal::NativeStruct_Data; + using DataAsArrayElement = Pointer<Data>; + + static const MojomTypeCategory category = MojomTypeCategory::STRUCT; +}; + +template <> +struct MojomTypeTraits<StringDataView, false> { + using Data = String_Data; + using DataAsArrayElement = Pointer<Data>; + + static const MojomTypeCategory category = MojomTypeCategory::STRING; +}; + +template <typename T, MojomTypeCategory categories> +struct BelongsTo { + static const bool value = + static_cast<uint32_t>(MojomTypeTraits<T>::category & categories) != 0; +}; + +template <typename T> +struct EnumHashImpl { + static_assert(std::is_enum<T>::value, "Incorrect hash function."); + + size_t operator()(T input) const { + using UnderlyingType = typename base::underlying_type<T>::type; + return std::hash<UnderlyingType>()(static_cast<UnderlyingType>(input)); + } +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_BINDINGS_INTERNAL_H_ diff --git a/mojo/public/cpp/bindings/lib/buffer.h b/mojo/public/cpp/bindings/lib/buffer.h new file mode 100644 index 0000000000..213a44590f --- /dev/null +++ b/mojo/public/cpp/bindings/lib/buffer.h @@ -0,0 +1,70 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_BUFFER_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_BUFFER_H_ + +#include <stddef.h> + +#include "base/logging.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" + +namespace mojo { +namespace internal { + +// Buffer provides an interface to allocate memory blocks which are 8-byte +// aligned and zero-initialized. It doesn't own the underlying memory. Users +// must ensure that the memory stays valid while using the allocated blocks from +// Buffer. +class Buffer { + public: + Buffer() {} + + // The memory must have been zero-initialized. |data| must be 8-byte + // aligned. + void Initialize(void* data, size_t size) { + DCHECK(IsAligned(data)); + + data_ = data; + size_ = size; + cursor_ = reinterpret_cast<uintptr_t>(data); + data_end_ = cursor_ + size; + } + + size_t size() const { return size_; } + + void* data() const { return data_; } + + // Allocates |num_bytes| from the buffer and returns a pointer to the start of + // the allocated block. + // The resulting address is 8-byte aligned, and the content of the memory is + // zero-filled. + void* Allocate(size_t num_bytes) { + num_bytes = Align(num_bytes); + uintptr_t result = cursor_; + cursor_ += num_bytes; + if (cursor_ > data_end_ || cursor_ < result) { + NOTREACHED(); + cursor_ -= num_bytes; + return nullptr; + } + + return reinterpret_cast<void*>(result); + } + + private: + void* data_ = nullptr; + size_t size_ = 0; + + uintptr_t cursor_ = 0; + uintptr_t data_end_ = 0; + + DISALLOW_COPY_AND_ASSIGN(Buffer); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_BUFFER_H_ diff --git a/mojo/public/cpp/bindings/lib/connector.cc b/mojo/public/cpp/bindings/lib/connector.cc new file mode 100644 index 0000000000..d93e45ed93 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/connector.cc @@ -0,0 +1,493 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/connector.h" + +#include <stdint.h> +#include <utility> + +#include "base/bind.h" +#include "base/lazy_instance.h" +#include "base/location.h" +#include "base/logging.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" +#include "base/message_loop/message_loop.h" +#include "base/synchronization/lock.h" +#include "base/threading/thread_local.h" +#include "mojo/public/cpp/bindings/lib/may_auto_lock.h" +#include "mojo/public/cpp/bindings/sync_handle_watcher.h" +#include "mojo/public/cpp/system/wait.h" + +namespace mojo { + +namespace { + +// The NestingObserver for each thread. Note that this is always a +// Connector::MessageLoopNestingObserver; we use the base type here because that +// subclass is private to Connector. +base::LazyInstance< + base::ThreadLocalPointer<base::MessageLoop::NestingObserver>>::Leaky + g_tls_nesting_observer = LAZY_INSTANCE_INITIALIZER; + +} // namespace + +// Used to efficiently maintain a doubly-linked list of all Connectors +// currently dispatching on any given thread. +class Connector::ActiveDispatchTracker { + public: + explicit ActiveDispatchTracker(const base::WeakPtr<Connector>& connector); + ~ActiveDispatchTracker(); + + void NotifyBeginNesting(); + + private: + const base::WeakPtr<Connector> connector_; + MessageLoopNestingObserver* const nesting_observer_; + ActiveDispatchTracker* outer_tracker_ = nullptr; + ActiveDispatchTracker* inner_tracker_ = nullptr; + + DISALLOW_COPY_AND_ASSIGN(ActiveDispatchTracker); +}; + +// Watches the MessageLoop on the current thread. Notifies the current chain of +// ActiveDispatchTrackers when a nested message loop is started. +class Connector::MessageLoopNestingObserver + : public base::MessageLoop::NestingObserver, + public base::MessageLoop::DestructionObserver { + public: + MessageLoopNestingObserver() { + base::MessageLoop::current()->AddNestingObserver(this); + base::MessageLoop::current()->AddDestructionObserver(this); + } + + ~MessageLoopNestingObserver() override {} + + // base::MessageLoop::NestingObserver: + void OnBeginNestedMessageLoop() override { + if (top_tracker_) + top_tracker_->NotifyBeginNesting(); + } + + // base::MessageLoop::DestructionObserver: + void WillDestroyCurrentMessageLoop() override { + base::MessageLoop::current()->RemoveNestingObserver(this); + base::MessageLoop::current()->RemoveDestructionObserver(this); + DCHECK_EQ(this, g_tls_nesting_observer.Get().Get()); + g_tls_nesting_observer.Get().Set(nullptr); + delete this; + } + + static MessageLoopNestingObserver* GetForThread() { + if (!base::MessageLoop::current() || + !base::MessageLoop::current()->nesting_allowed()) + return nullptr; + auto* observer = static_cast<MessageLoopNestingObserver*>( + g_tls_nesting_observer.Get().Get()); + if (!observer) { + observer = new MessageLoopNestingObserver; + g_tls_nesting_observer.Get().Set(observer); + } + return observer; + } + + private: + friend class ActiveDispatchTracker; + + ActiveDispatchTracker* top_tracker_ = nullptr; + + DISALLOW_COPY_AND_ASSIGN(MessageLoopNestingObserver); +}; + +Connector::ActiveDispatchTracker::ActiveDispatchTracker( + const base::WeakPtr<Connector>& connector) + : connector_(connector), nesting_observer_(connector_->nesting_observer_) { + DCHECK(nesting_observer_); + if (nesting_observer_->top_tracker_) { + outer_tracker_ = nesting_observer_->top_tracker_; + outer_tracker_->inner_tracker_ = this; + } + nesting_observer_->top_tracker_ = this; +} + +Connector::ActiveDispatchTracker::~ActiveDispatchTracker() { + if (nesting_observer_->top_tracker_ == this) + nesting_observer_->top_tracker_ = outer_tracker_; + else if (inner_tracker_) + inner_tracker_->outer_tracker_ = outer_tracker_; + if (outer_tracker_) + outer_tracker_->inner_tracker_ = inner_tracker_; +} + +void Connector::ActiveDispatchTracker::NotifyBeginNesting() { + if (connector_ && connector_->handle_watcher_) + connector_->handle_watcher_->ArmOrNotify(); + if (outer_tracker_) + outer_tracker_->NotifyBeginNesting(); +} + +Connector::Connector(ScopedMessagePipeHandle message_pipe, + ConnectorConfig config, + scoped_refptr<base::SingleThreadTaskRunner> runner) + : message_pipe_(std::move(message_pipe)), + task_runner_(std::move(runner)), + nesting_observer_(MessageLoopNestingObserver::GetForThread()), + weak_factory_(this) { + if (config == MULTI_THREADED_SEND) + lock_.emplace(); + + weak_self_ = weak_factory_.GetWeakPtr(); + // Even though we don't have an incoming receiver, we still want to monitor + // the message pipe to know if is closed or encounters an error. + WaitToReadMore(); +} + +Connector::~Connector() { + { + // Allow for quick destruction on any thread if the pipe is already closed. + base::AutoLock lock(connected_lock_); + if (!connected_) + return; + } + + DCHECK(thread_checker_.CalledOnValidThread()); + CancelWait(); +} + +void Connector::CloseMessagePipe() { + // Throw away the returned message pipe. + PassMessagePipe(); +} + +ScopedMessagePipeHandle Connector::PassMessagePipe() { + DCHECK(thread_checker_.CalledOnValidThread()); + + CancelWait(); + internal::MayAutoLock locker(&lock_); + ScopedMessagePipeHandle message_pipe = std::move(message_pipe_); + weak_factory_.InvalidateWeakPtrs(); + sync_handle_watcher_callback_count_ = 0; + + base::AutoLock lock(connected_lock_); + connected_ = false; + return message_pipe; +} + +void Connector::RaiseError() { + DCHECK(thread_checker_.CalledOnValidThread()); + + HandleError(true, true); +} + +bool Connector::WaitForIncomingMessage(MojoDeadline deadline) { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (error_) + return false; + + ResumeIncomingMethodCallProcessing(); + + // TODO(rockot): Use a timed Wait here. Nobody uses anything but 0 or + // INDEFINITE deadlines at present, so we only support those. + DCHECK(deadline == 0 || deadline == MOJO_DEADLINE_INDEFINITE); + + MojoResult rv = MOJO_RESULT_UNKNOWN; + if (deadline == 0 && !message_pipe_->QuerySignalsState().readable()) + return false; + + if (deadline == MOJO_DEADLINE_INDEFINITE) { + rv = Wait(message_pipe_.get(), MOJO_HANDLE_SIGNAL_READABLE); + if (rv != MOJO_RESULT_OK) { + // Users that call WaitForIncomingMessage() should expect their code to be + // re-entered, so we call the error handler synchronously. + HandleError(rv != MOJO_RESULT_FAILED_PRECONDITION, false); + return false; + } + } + + ignore_result(ReadSingleMessage(&rv)); + return (rv == MOJO_RESULT_OK); +} + +void Connector::PauseIncomingMethodCallProcessing() { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (paused_) + return; + + paused_ = true; + CancelWait(); +} + +void Connector::ResumeIncomingMethodCallProcessing() { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (!paused_) + return; + + paused_ = false; + WaitToReadMore(); +} + +bool Connector::Accept(Message* message) { + DCHECK(lock_ || thread_checker_.CalledOnValidThread()); + + // It shouldn't hurt even if |error_| may be changed by a different thread at + // the same time. The outcome is that we may write into |message_pipe_| after + // encountering an error, which should be fine. + if (error_) + return false; + + internal::MayAutoLock locker(&lock_); + + if (!message_pipe_.is_valid() || drop_writes_) + return true; + + MojoResult rv = + WriteMessageNew(message_pipe_.get(), message->TakeMojoMessage(), + MOJO_WRITE_MESSAGE_FLAG_NONE); + + switch (rv) { + case MOJO_RESULT_OK: + break; + case MOJO_RESULT_FAILED_PRECONDITION: + // There's no point in continuing to write to this pipe since the other + // end is gone. Avoid writing any future messages. Hide write failures + // from the caller since we'd like them to continue consuming any backlog + // of incoming messages before regarding the message pipe as closed. + drop_writes_ = true; + break; + case MOJO_RESULT_BUSY: + // We'd get a "busy" result if one of the message's handles is: + // - |message_pipe_|'s own handle; + // - simultaneously being used on another thread; or + // - in a "busy" state that prohibits it from being transferred (e.g., + // a data pipe handle in the middle of a two-phase read/write, + // regardless of which thread that two-phase read/write is happening + // on). + // TODO(vtl): I wonder if this should be a |DCHECK()|. (But, until + // crbug.com/389666, etc. are resolved, this will make tests fail quickly + // rather than hanging.) + CHECK(false) << "Race condition or other bug detected"; + return false; + default: + // This particular write was rejected, presumably because of bad input. + // The pipe is not necessarily in a bad state. + return false; + } + return true; +} + +void Connector::AllowWokenUpBySyncWatchOnSameThread() { + DCHECK(thread_checker_.CalledOnValidThread()); + + allow_woken_up_by_others_ = true; + + EnsureSyncWatcherExists(); + sync_watcher_->AllowWokenUpBySyncWatchOnSameThread(); +} + +bool Connector::SyncWatch(const bool* should_stop) { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (error_) + return false; + + ResumeIncomingMethodCallProcessing(); + + EnsureSyncWatcherExists(); + return sync_watcher_->SyncWatch(should_stop); +} + +void Connector::SetWatcherHeapProfilerTag(const char* tag) { + heap_profiler_tag_ = tag; + if (handle_watcher_) { + handle_watcher_->set_heap_profiler_tag(tag); + } +} + +void Connector::OnWatcherHandleReady(MojoResult result) { + OnHandleReadyInternal(result); +} + +void Connector::OnSyncHandleWatcherHandleReady(MojoResult result) { + base::WeakPtr<Connector> weak_self(weak_self_); + + sync_handle_watcher_callback_count_++; + OnHandleReadyInternal(result); + // At this point, this object might have been deleted. + if (weak_self) { + DCHECK_LT(0u, sync_handle_watcher_callback_count_); + sync_handle_watcher_callback_count_--; + } +} + +void Connector::OnHandleReadyInternal(MojoResult result) { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (result != MOJO_RESULT_OK) { + HandleError(result != MOJO_RESULT_FAILED_PRECONDITION, false); + return; + } + + ReadAllAvailableMessages(); + // At this point, this object might have been deleted. Return. +} + +void Connector::WaitToReadMore() { + CHECK(!paused_); + DCHECK(!handle_watcher_); + + handle_watcher_.reset(new SimpleWatcher( + FROM_HERE, SimpleWatcher::ArmingPolicy::MANUAL, task_runner_)); + if (heap_profiler_tag_) + handle_watcher_->set_heap_profiler_tag(heap_profiler_tag_); + MojoResult rv = handle_watcher_->Watch( + message_pipe_.get(), MOJO_HANDLE_SIGNAL_READABLE, + base::Bind(&Connector::OnWatcherHandleReady, base::Unretained(this))); + + if (rv != MOJO_RESULT_OK) { + // If the watch failed because the handle is invalid or its conditions can + // no longer be met, we signal the error asynchronously to avoid reentry. + task_runner_->PostTask( + FROM_HERE, + base::Bind(&Connector::OnWatcherHandleReady, weak_self_, rv)); + } else { + handle_watcher_->ArmOrNotify(); + } + + if (allow_woken_up_by_others_) { + EnsureSyncWatcherExists(); + sync_watcher_->AllowWokenUpBySyncWatchOnSameThread(); + } +} + +bool Connector::ReadSingleMessage(MojoResult* read_result) { + CHECK(!paused_); + + bool receiver_result = false; + + // Detect if |this| was destroyed or the message pipe was closed/transferred + // during message dispatch. + base::WeakPtr<Connector> weak_self = weak_self_; + + Message message; + const MojoResult rv = ReadMessage(message_pipe_.get(), &message); + *read_result = rv; + + if (rv == MOJO_RESULT_OK) { + base::Optional<ActiveDispatchTracker> dispatch_tracker; + if (!is_dispatching_ && nesting_observer_) { + is_dispatching_ = true; + dispatch_tracker.emplace(weak_self); + } + + receiver_result = + incoming_receiver_ && incoming_receiver_->Accept(&message); + + if (!weak_self) + return false; + + if (dispatch_tracker) { + is_dispatching_ = false; + dispatch_tracker.reset(); + } + } else if (rv == MOJO_RESULT_SHOULD_WAIT) { + return true; + } else { + HandleError(rv != MOJO_RESULT_FAILED_PRECONDITION, false); + return false; + } + + if (enforce_errors_from_incoming_receiver_ && !receiver_result) { + HandleError(true, false); + return false; + } + return true; +} + +void Connector::ReadAllAvailableMessages() { + while (!error_) { + base::WeakPtr<Connector> weak_self = weak_self_; + MojoResult rv; + + // May delete |this.| + if (!ReadSingleMessage(&rv)) + return; + + if (!weak_self || paused_) + return; + + DCHECK(rv == MOJO_RESULT_OK || rv == MOJO_RESULT_SHOULD_WAIT); + + if (rv == MOJO_RESULT_SHOULD_WAIT) { + // Attempt to re-arm the Watcher. + MojoResult ready_result; + MojoResult arm_result = handle_watcher_->Arm(&ready_result); + if (arm_result == MOJO_RESULT_OK) + return; + + // The watcher is already ready to notify again. + DCHECK_EQ(MOJO_RESULT_FAILED_PRECONDITION, arm_result); + + if (ready_result == MOJO_RESULT_FAILED_PRECONDITION) { + HandleError(false, false); + return; + } + + // There's more to read now, so we'll just keep looping. + DCHECK_EQ(MOJO_RESULT_OK, ready_result); + } + } +} + +void Connector::CancelWait() { + handle_watcher_.reset(); + sync_watcher_.reset(); +} + +void Connector::HandleError(bool force_pipe_reset, bool force_async_handler) { + if (error_ || !message_pipe_.is_valid()) + return; + + if (paused_) { + // Enforce calling the error handler asynchronously if the user has paused + // receiving messages. We need to wait until the user starts receiving + // messages again. + force_async_handler = true; + } + + if (!force_pipe_reset && force_async_handler) + force_pipe_reset = true; + + if (force_pipe_reset) { + CancelWait(); + internal::MayAutoLock locker(&lock_); + message_pipe_.reset(); + MessagePipe dummy_pipe; + message_pipe_ = std::move(dummy_pipe.handle0); + } else { + CancelWait(); + } + + if (force_async_handler) { + if (!paused_) + WaitToReadMore(); + } else { + error_ = true; + if (!connection_error_handler_.is_null()) + connection_error_handler_.Run(); + } +} + +void Connector::EnsureSyncWatcherExists() { + if (sync_watcher_) + return; + sync_watcher_.reset(new SyncHandleWatcher( + message_pipe_.get(), MOJO_HANDLE_SIGNAL_READABLE, + base::Bind(&Connector::OnSyncHandleWatcherHandleReady, + base::Unretained(this)))); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/control_message_handler.cc b/mojo/public/cpp/bindings/lib/control_message_handler.cc new file mode 100644 index 0000000000..1b7bb78e5f --- /dev/null +++ b/mojo/public/cpp/bindings/lib/control_message_handler.cc @@ -0,0 +1,150 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/control_message_handler.h" + +#include <stddef.h> +#include <stdint.h> +#include <utility> + +#include "base/logging.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/lib/message_builder.h" +#include "mojo/public/cpp/bindings/lib/serialization.h" +#include "mojo/public/cpp/bindings/lib/validation_util.h" +#include "mojo/public/interfaces/bindings/interface_control_messages.mojom.h" + +namespace mojo { +namespace internal { +namespace { + +bool ValidateControlRequestWithResponse(Message* message) { + ValidationContext validation_context(message->payload(), + message->payload_num_bytes(), 0, 0, + message, "ControlRequestValidator"); + if (!ValidateMessageIsRequestExpectingResponse(message, &validation_context)) + return false; + + switch (message->header()->name) { + case interface_control::kRunMessageId: + return ValidateMessagePayload< + interface_control::internal::RunMessageParams_Data>( + message, &validation_context); + } + return false; +} + +bool ValidateControlRequestWithoutResponse(Message* message) { + ValidationContext validation_context(message->payload(), + message->payload_num_bytes(), 0, 0, + message, "ControlRequestValidator"); + if (!ValidateMessageIsRequestWithoutResponse(message, &validation_context)) + return false; + + switch (message->header()->name) { + case interface_control::kRunOrClosePipeMessageId: + return ValidateMessageIsRequestWithoutResponse(message, + &validation_context) && + ValidateMessagePayload< + interface_control::internal::RunOrClosePipeMessageParams_Data>( + message, &validation_context); + } + return false; +} + +} // namespace + +// static +bool ControlMessageHandler::IsControlMessage(const Message* message) { + return message->header()->name == interface_control::kRunMessageId || + message->header()->name == interface_control::kRunOrClosePipeMessageId; +} + +ControlMessageHandler::ControlMessageHandler(uint32_t interface_version) + : interface_version_(interface_version) { +} + +ControlMessageHandler::~ControlMessageHandler() { +} + +bool ControlMessageHandler::Accept(Message* message) { + if (!ValidateControlRequestWithoutResponse(message)) + return false; + + if (message->header()->name == interface_control::kRunOrClosePipeMessageId) + return RunOrClosePipe(message); + + NOTREACHED(); + return false; +} + +bool ControlMessageHandler::AcceptWithResponder( + Message* message, + std::unique_ptr<MessageReceiverWithStatus> responder) { + if (!ValidateControlRequestWithResponse(message)) + return false; + + if (message->header()->name == interface_control::kRunMessageId) + return Run(message, std::move(responder)); + + NOTREACHED(); + return false; +} + +bool ControlMessageHandler::Run( + Message* message, + std::unique_ptr<MessageReceiverWithStatus> responder) { + interface_control::internal::RunMessageParams_Data* params = + reinterpret_cast<interface_control::internal::RunMessageParams_Data*>( + message->mutable_payload()); + interface_control::RunMessageParamsPtr params_ptr; + Deserialize<interface_control::RunMessageParamsDataView>(params, ¶ms_ptr, + &context_); + auto& input = *params_ptr->input; + interface_control::RunOutputPtr output = interface_control::RunOutput::New(); + if (input.is_query_version()) { + output->set_query_version_result( + interface_control::QueryVersionResult::New()); + output->get_query_version_result()->version = interface_version_; + } else if (input.is_flush_for_testing()) { + output.reset(); + } else { + output.reset(); + } + + auto response_params_ptr = interface_control::RunResponseMessageParams::New(); + response_params_ptr->output = std::move(output); + size_t size = + PrepareToSerialize<interface_control::RunResponseMessageParamsDataView>( + response_params_ptr, &context_); + MessageBuilder builder(interface_control::kRunMessageId, + Message::kFlagIsResponse, size, 0); + builder.message()->set_request_id(message->request_id()); + + interface_control::internal::RunResponseMessageParams_Data* response_params = + nullptr; + Serialize<interface_control::RunResponseMessageParamsDataView>( + response_params_ptr, builder.buffer(), &response_params, &context_); + ignore_result(responder->Accept(builder.message())); + + return true; +} + +bool ControlMessageHandler::RunOrClosePipe(Message* message) { + interface_control::internal::RunOrClosePipeMessageParams_Data* params = + reinterpret_cast< + interface_control::internal::RunOrClosePipeMessageParams_Data*>( + message->mutable_payload()); + interface_control::RunOrClosePipeMessageParamsPtr params_ptr; + Deserialize<interface_control::RunOrClosePipeMessageParamsDataView>( + params, ¶ms_ptr, &context_); + auto& input = *params_ptr->input; + if (input.is_require_version()) + return interface_version_ >= input.get_require_version()->version; + + return false; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/control_message_handler.h b/mojo/public/cpp/bindings/lib/control_message_handler.h new file mode 100644 index 0000000000..5d1f716ea8 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/control_message_handler.h @@ -0,0 +1,48 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_CONTROL_MESSAGE_HANDLER_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_CONTROL_MESSAGE_HANDLER_H_ + +#include <stdint.h> + +#include "base/compiler_specific.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/serialization_context.h" +#include "mojo/public/cpp/bindings/message.h" + +namespace mojo { +namespace internal { + +// Handlers for request messages defined in interface_control_messages.mojom. +class MOJO_CPP_BINDINGS_EXPORT ControlMessageHandler + : NON_EXPORTED_BASE(public MessageReceiverWithResponderStatus) { + public: + static bool IsControlMessage(const Message* message); + + explicit ControlMessageHandler(uint32_t interface_version); + ~ControlMessageHandler() override; + + // Call the following methods only if IsControlMessage() returned true. + bool Accept(Message* message) override; + bool AcceptWithResponder( + Message* message, + std::unique_ptr<MessageReceiverWithStatus> responder) override; + + private: + bool Run(Message* message, + std::unique_ptr<MessageReceiverWithStatus> responder); + bool RunOrClosePipe(Message* message); + + uint32_t interface_version_; + SerializationContext context_; + + DISALLOW_COPY_AND_ASSIGN(ControlMessageHandler); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_CONTROL_MESSAGE_HANDLER_H_ diff --git a/mojo/public/cpp/bindings/lib/control_message_proxy.cc b/mojo/public/cpp/bindings/lib/control_message_proxy.cc new file mode 100644 index 0000000000..d082b49fb3 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/control_message_proxy.cc @@ -0,0 +1,188 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/control_message_proxy.h" + +#include <stddef.h> +#include <stdint.h> +#include <utility> + +#include "base/bind.h" +#include "base/callback_helpers.h" +#include "base/macros.h" +#include "base/run_loop.h" +#include "mojo/public/cpp/bindings/lib/message_builder.h" +#include "mojo/public/cpp/bindings/lib/serialization.h" +#include "mojo/public/cpp/bindings/lib/validation_util.h" +#include "mojo/public/cpp/bindings/message.h" +#include "mojo/public/interfaces/bindings/interface_control_messages.mojom.h" + +namespace mojo { +namespace internal { + +namespace { + +bool ValidateControlResponse(Message* message) { + ValidationContext validation_context(message->payload(), + message->payload_num_bytes(), 0, 0, + message, "ControlResponseValidator"); + if (!ValidateMessageIsResponse(message, &validation_context)) + return false; + + switch (message->header()->name) { + case interface_control::kRunMessageId: + return ValidateMessagePayload< + interface_control::internal::RunResponseMessageParams_Data>( + message, &validation_context); + } + return false; +} + +using RunCallback = + base::Callback<void(interface_control::RunResponseMessageParamsPtr)>; + +class RunResponseForwardToCallback : public MessageReceiver { + public: + explicit RunResponseForwardToCallback(const RunCallback& callback) + : callback_(callback) {} + bool Accept(Message* message) override; + + private: + RunCallback callback_; + DISALLOW_COPY_AND_ASSIGN(RunResponseForwardToCallback); +}; + +bool RunResponseForwardToCallback::Accept(Message* message) { + if (!ValidateControlResponse(message)) + return false; + + interface_control::internal::RunResponseMessageParams_Data* params = + reinterpret_cast< + interface_control::internal::RunResponseMessageParams_Data*>( + message->mutable_payload()); + interface_control::RunResponseMessageParamsPtr params_ptr; + SerializationContext context; + Deserialize<interface_control::RunResponseMessageParamsDataView>( + params, ¶ms_ptr, &context); + + callback_.Run(std::move(params_ptr)); + return true; +} + +void SendRunMessage(MessageReceiverWithResponder* receiver, + interface_control::RunInputPtr input_ptr, + const RunCallback& callback) { + SerializationContext context; + + auto params_ptr = interface_control::RunMessageParams::New(); + params_ptr->input = std::move(input_ptr); + size_t size = PrepareToSerialize<interface_control::RunMessageParamsDataView>( + params_ptr, &context); + MessageBuilder builder(interface_control::kRunMessageId, + Message::kFlagExpectsResponse, size, 0); + + interface_control::internal::RunMessageParams_Data* params = nullptr; + Serialize<interface_control::RunMessageParamsDataView>( + params_ptr, builder.buffer(), ¶ms, &context); + std::unique_ptr<MessageReceiver> responder = + base::MakeUnique<RunResponseForwardToCallback>(callback); + ignore_result( + receiver->AcceptWithResponder(builder.message(), std::move(responder))); +} + +Message ConstructRunOrClosePipeMessage( + interface_control::RunOrClosePipeInputPtr input_ptr) { + SerializationContext context; + + auto params_ptr = interface_control::RunOrClosePipeMessageParams::New(); + params_ptr->input = std::move(input_ptr); + + size_t size = PrepareToSerialize< + interface_control::RunOrClosePipeMessageParamsDataView>(params_ptr, + &context); + MessageBuilder builder(interface_control::kRunOrClosePipeMessageId, 0, size, + 0); + + interface_control::internal::RunOrClosePipeMessageParams_Data* params = + nullptr; + Serialize<interface_control::RunOrClosePipeMessageParamsDataView>( + params_ptr, builder.buffer(), ¶ms, &context); + return std::move(*builder.message()); +} + +void SendRunOrClosePipeMessage( + MessageReceiverWithResponder* receiver, + interface_control::RunOrClosePipeInputPtr input_ptr) { + Message message(ConstructRunOrClosePipeMessage(std::move(input_ptr))); + + ignore_result(receiver->Accept(&message)); +} + +void RunVersionCallback( + const base::Callback<void(uint32_t)>& callback, + interface_control::RunResponseMessageParamsPtr run_response) { + uint32_t version = 0u; + if (run_response->output && run_response->output->is_query_version_result()) + version = run_response->output->get_query_version_result()->version; + callback.Run(version); +} + +void RunClosure(const base::Closure& callback, + interface_control::RunResponseMessageParamsPtr run_response) { + callback.Run(); +} + +} // namespace + +ControlMessageProxy::ControlMessageProxy(MessageReceiverWithResponder* receiver) + : receiver_(receiver) { +} + +ControlMessageProxy::~ControlMessageProxy() = default; + +void ControlMessageProxy::QueryVersion( + const base::Callback<void(uint32_t)>& callback) { + auto input_ptr = interface_control::RunInput::New(); + input_ptr->set_query_version(interface_control::QueryVersion::New()); + SendRunMessage(receiver_, std::move(input_ptr), + base::Bind(&RunVersionCallback, callback)); +} + +void ControlMessageProxy::RequireVersion(uint32_t version) { + auto require_version = interface_control::RequireVersion::New(); + require_version->version = version; + auto input_ptr = interface_control::RunOrClosePipeInput::New(); + input_ptr->set_require_version(std::move(require_version)); + SendRunOrClosePipeMessage(receiver_, std::move(input_ptr)); +} + +void ControlMessageProxy::FlushForTesting() { + if (encountered_error_) + return; + + auto input_ptr = interface_control::RunInput::New(); + input_ptr->set_flush_for_testing(interface_control::FlushForTesting::New()); + base::RunLoop run_loop; + run_loop_quit_closure_ = run_loop.QuitClosure(); + SendRunMessage( + receiver_, std::move(input_ptr), + base::Bind(&RunClosure, + base::Bind(&ControlMessageProxy::RunFlushForTestingClosure, + base::Unretained(this)))); + run_loop.Run(); +} + +void ControlMessageProxy::RunFlushForTestingClosure() { + DCHECK(!run_loop_quit_closure_.is_null()); + base::ResetAndReturn(&run_loop_quit_closure_).Run(); +} + +void ControlMessageProxy::OnConnectionError() { + encountered_error_ = true; + if (!run_loop_quit_closure_.is_null()) + RunFlushForTestingClosure(); +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/control_message_proxy.h b/mojo/public/cpp/bindings/lib/control_message_proxy.h new file mode 100644 index 0000000000..2f9314ebf0 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/control_message_proxy.h @@ -0,0 +1,49 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_CONTROL_MESSAGE_PROXY_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_CONTROL_MESSAGE_PROXY_H_ + +#include <stdint.h> + +#include "base/callback.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/serialization_context.h" + +namespace mojo { + +class MessageReceiverWithResponder; + +namespace internal { + +// Proxy for request messages defined in interface_control_messages.mojom. +class MOJO_CPP_BINDINGS_EXPORT ControlMessageProxy { + public: + // Doesn't take ownership of |receiver|. It must outlive this object. + explicit ControlMessageProxy(MessageReceiverWithResponder* receiver); + ~ControlMessageProxy(); + + void QueryVersion(const base::Callback<void(uint32_t)>& callback); + void RequireVersion(uint32_t version); + + void FlushForTesting(); + void OnConnectionError(); + + private: + void RunFlushForTestingClosure(); + + // Not owned. + MessageReceiverWithResponder* receiver_; + bool encountered_error_ = false; + + base::Closure run_loop_quit_closure_; + + DISALLOW_COPY_AND_ASSIGN(ControlMessageProxy); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_CONTROL_MESSAGE_PROXY_H_ diff --git a/mojo/public/cpp/bindings/lib/equals_traits.h b/mojo/public/cpp/bindings/lib/equals_traits.h new file mode 100644 index 0000000000..53c7dce693 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/equals_traits.h @@ -0,0 +1,94 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_EQUALS_TRAITS_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_EQUALS_TRAITS_H_ + +#include <type_traits> +#include <unordered_map> +#include <vector> + +#include "base/optional.h" +#include "mojo/public/cpp/bindings/lib/template_util.h" + +namespace mojo { +namespace internal { + +template <typename T> +struct HasEqualsMethod { + template <typename U> + static char Test(decltype(&U::Equals)); + template <typename U> + static int Test(...); + static const bool value = sizeof(Test<T>(0)) == sizeof(char); + + private: + EnsureTypeIsComplete<T> check_t_; +}; + +template <typename T, bool has_equals_method = HasEqualsMethod<T>::value> +struct EqualsTraits; + +template <typename T> +bool Equals(const T& a, const T& b); + +template <typename T> +struct EqualsTraits<T, true> { + static bool Equals(const T& a, const T& b) { return a.Equals(b); } +}; + +template <typename T> +struct EqualsTraits<T, false> { + static bool Equals(const T& a, const T& b) { return a == b; } +}; + +template <typename T> +struct EqualsTraits<base::Optional<T>, false> { + static bool Equals(const base::Optional<T>& a, const base::Optional<T>& b) { + if (!a && !b) + return true; + if (!a || !b) + return false; + + return internal::Equals(*a, *b); + } +}; + +template <typename T> +struct EqualsTraits<std::vector<T>, false> { + static bool Equals(const std::vector<T>& a, const std::vector<T>& b) { + if (a.size() != b.size()) + return false; + for (size_t i = 0; i < a.size(); ++i) { + if (!internal::Equals(a[i], b[i])) + return false; + } + return true; + } +}; + +template <typename K, typename V> +struct EqualsTraits<std::unordered_map<K, V>, false> { + static bool Equals(const std::unordered_map<K, V>& a, + const std::unordered_map<K, V>& b) { + if (a.size() != b.size()) + return false; + for (const auto& element : a) { + auto iter = b.find(element.first); + if (iter == b.end() || !internal::Equals(element.second, iter->second)) + return false; + } + return true; + } +}; + +template <typename T> +bool Equals(const T& a, const T& b) { + return EqualsTraits<T>::Equals(a, b); +} + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_EQUALS_TRAITS_H_ diff --git a/mojo/public/cpp/bindings/lib/filter_chain.cc b/mojo/public/cpp/bindings/lib/filter_chain.cc new file mode 100644 index 0000000000..5d919fe172 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/filter_chain.cc @@ -0,0 +1,47 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/filter_chain.h" + +#include <algorithm> + +#include "base/logging.h" + +namespace mojo { + +FilterChain::FilterChain(MessageReceiver* sink) : sink_(sink) { +} + +FilterChain::FilterChain(FilterChain&& other) : sink_(other.sink_) { + other.sink_ = nullptr; + filters_.swap(other.filters_); +} + +FilterChain& FilterChain::operator=(FilterChain&& other) { + std::swap(sink_, other.sink_); + filters_.swap(other.filters_); + return *this; +} + +FilterChain::~FilterChain() { +} + +void FilterChain::SetSink(MessageReceiver* sink) { + DCHECK(!sink_); + sink_ = sink; +} + +bool FilterChain::Accept(Message* message) { + DCHECK(sink_); + for (auto& filter : filters_) + if (!filter->Accept(message)) + return false; + return sink_->Accept(message); +} + +void FilterChain::Append(std::unique_ptr<MessageReceiver> filter) { + filters_.emplace_back(std::move(filter)); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/fixed_buffer.cc b/mojo/public/cpp/bindings/lib/fixed_buffer.cc new file mode 100644 index 0000000000..725a193cd7 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/fixed_buffer.cc @@ -0,0 +1,30 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/fixed_buffer.h" + +#include <stdlib.h> + +namespace mojo { +namespace internal { + +FixedBufferForTesting::FixedBufferForTesting(size_t size) { + size = internal::Align(size); + // Use calloc here to ensure all message memory is zero'd out. + void* ptr = calloc(size, 1); + Initialize(ptr, size); +} + +FixedBufferForTesting::~FixedBufferForTesting() { + free(data()); +} + +void* FixedBufferForTesting::Leak() { + void* ptr = data(); + Initialize(nullptr, 0); + return ptr; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/fixed_buffer.h b/mojo/public/cpp/bindings/lib/fixed_buffer.h new file mode 100644 index 0000000000..070b0c8cef --- /dev/null +++ b/mojo/public/cpp/bindings/lib/fixed_buffer.h @@ -0,0 +1,39 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_FIXED_BUFFER_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_FIXED_BUFFER_H_ + +#include <stddef.h> + +#include "base/compiler_specific.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/buffer.h" + +namespace mojo { +namespace internal { + +// FixedBufferForTesting owns its buffer. The Leak method may be used to steal +// the underlying memory. +class MOJO_CPP_BINDINGS_EXPORT FixedBufferForTesting + : NON_EXPORTED_BASE(public Buffer) { + public: + explicit FixedBufferForTesting(size_t size); + ~FixedBufferForTesting(); + + // Returns the internal memory owned by the Buffer to the caller. The Buffer + // relinquishes its pointer, effectively resetting the state of the Buffer + // and leaving the caller responsible for freeing the returned memory address + // when no longer needed. + void* Leak(); + + private: + DISALLOW_COPY_AND_ASSIGN(FixedBufferForTesting); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_FIXED_BUFFER_H_ diff --git a/mojo/public/cpp/bindings/lib/handle_interface_serialization.h b/mojo/public/cpp/bindings/lib/handle_interface_serialization.h new file mode 100644 index 0000000000..14ed21f0ac --- /dev/null +++ b/mojo/public/cpp/bindings/lib/handle_interface_serialization.h @@ -0,0 +1,181 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_INTERFACE_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_INTERFACE_SERIALIZATION_H_ + +#include <type_traits> + +#include "mojo/public/cpp/bindings/associated_group_controller.h" +#include "mojo/public/cpp/bindings/associated_interface_ptr_info.h" +#include "mojo/public/cpp/bindings/associated_interface_request.h" +#include "mojo/public/cpp/bindings/interface_data_view.h" +#include "mojo/public/cpp/bindings/interface_ptr.h" +#include "mojo/public/cpp/bindings/interface_request.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_context.h" +#include "mojo/public/cpp/bindings/lib/serialization_forward.h" +#include "mojo/public/cpp/system/handle.h" + +namespace mojo { +namespace internal { + +template <typename Base, typename T> +struct Serializer<AssociatedInterfacePtrInfoDataView<Base>, + AssociatedInterfacePtrInfo<T>> { + static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); + + static size_t PrepareToSerialize(const AssociatedInterfacePtrInfo<T>& input, + SerializationContext* context) { + if (input.handle().is_valid()) + context->associated_endpoint_count++; + return 0; + } + + static void Serialize(AssociatedInterfacePtrInfo<T>& input, + AssociatedInterface_Data* output, + SerializationContext* context) { + DCHECK(!input.handle().is_valid() || input.handle().pending_association()); + if (input.handle().is_valid()) { + // Set to the index of the element pushed to the back of the vector. + output->handle.value = + static_cast<uint32_t>(context->associated_endpoint_handles.size()); + context->associated_endpoint_handles.push_back(input.PassHandle()); + } else { + output->handle.value = kEncodedInvalidHandleValue; + } + output->version = input.version(); + } + + static bool Deserialize(AssociatedInterface_Data* input, + AssociatedInterfacePtrInfo<T>* output, + SerializationContext* context) { + if (input->handle.is_valid()) { + DCHECK_LT(input->handle.value, + context->associated_endpoint_handles.size()); + output->set_handle( + std::move(context->associated_endpoint_handles[input->handle.value])); + } else { + output->set_handle(ScopedInterfaceEndpointHandle()); + } + output->set_version(input->version); + return true; + } +}; + +template <typename Base, typename T> +struct Serializer<AssociatedInterfaceRequestDataView<Base>, + AssociatedInterfaceRequest<T>> { + static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); + + static size_t PrepareToSerialize(const AssociatedInterfaceRequest<T>& input, + SerializationContext* context) { + if (input.handle().is_valid()) + context->associated_endpoint_count++; + return 0; + } + + static void Serialize(AssociatedInterfaceRequest<T>& input, + AssociatedEndpointHandle_Data* output, + SerializationContext* context) { + DCHECK(!input.handle().is_valid() || input.handle().pending_association()); + if (input.handle().is_valid()) { + // Set to the index of the element pushed to the back of the vector. + output->value = + static_cast<uint32_t>(context->associated_endpoint_handles.size()); + context->associated_endpoint_handles.push_back(input.PassHandle()); + } else { + output->value = kEncodedInvalidHandleValue; + } + } + + static bool Deserialize(AssociatedEndpointHandle_Data* input, + AssociatedInterfaceRequest<T>* output, + SerializationContext* context) { + if (input->is_valid()) { + DCHECK_LT(input->value, context->associated_endpoint_handles.size()); + output->Bind( + std::move(context->associated_endpoint_handles[input->value])); + } else { + output->Bind(ScopedInterfaceEndpointHandle()); + } + return true; + } +}; + +template <typename Base, typename T> +struct Serializer<InterfacePtrDataView<Base>, InterfacePtr<T>> { + static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); + + static size_t PrepareToSerialize(const InterfacePtr<T>& input, + SerializationContext* context) { + return 0; + } + + static void Serialize(InterfacePtr<T>& input, + Interface_Data* output, + SerializationContext* context) { + InterfacePtrInfo<T> info = input.PassInterface(); + output->handle = context->handles.AddHandle(info.PassHandle().release()); + output->version = info.version(); + } + + static bool Deserialize(Interface_Data* input, + InterfacePtr<T>* output, + SerializationContext* context) { + output->Bind(InterfacePtrInfo<T>( + context->handles.TakeHandleAs<mojo::MessagePipeHandle>(input->handle), + input->version)); + return true; + } +}; + +template <typename Base, typename T> +struct Serializer<InterfaceRequestDataView<Base>, InterfaceRequest<T>> { + static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); + + static size_t PrepareToSerialize(const InterfaceRequest<T>& input, + SerializationContext* context) { + return 0; + } + + static void Serialize(InterfaceRequest<T>& input, + Handle_Data* output, + SerializationContext* context) { + *output = context->handles.AddHandle(input.PassMessagePipe().release()); + } + + static bool Deserialize(Handle_Data* input, + InterfaceRequest<T>* output, + SerializationContext* context) { + output->Bind(context->handles.TakeHandleAs<MessagePipeHandle>(*input)); + return true; + } +}; + +template <typename T> +struct Serializer<ScopedHandleBase<T>, ScopedHandleBase<T>> { + static size_t PrepareToSerialize(const ScopedHandleBase<T>& input, + SerializationContext* context) { + return 0; + } + + static void Serialize(ScopedHandleBase<T>& input, + Handle_Data* output, + SerializationContext* context) { + *output = context->handles.AddHandle(input.release()); + } + + static bool Deserialize(Handle_Data* input, + ScopedHandleBase<T>* output, + SerializationContext* context) { + *output = context->handles.TakeHandleAs<T>(*input); + return true; + } +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_INTERFACE_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/hash_util.h b/mojo/public/cpp/bindings/lib/hash_util.h new file mode 100644 index 0000000000..93280d69da --- /dev/null +++ b/mojo/public/cpp/bindings/lib/hash_util.h @@ -0,0 +1,84 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_HASH_UTIL_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_HASH_UTIL_H_ + +#include <cstring> +#include <functional> +#include <type_traits> +#include <vector> + +#include "base/optional.h" +#include "mojo/public/cpp/bindings/lib/template_util.h" + +namespace mojo { +namespace internal { + +template <typename T> +size_t HashCombine(size_t seed, const T& value) { + // Based on proposal in: + // http://www.open-std.org/JTC1/SC22/WG21/docs/papers/2005/n1756.pdf + return seed ^ (std::hash<T>()(value) + (seed << 6) + (seed >> 2)); +} + +template <typename T> +struct HasHashMethod { + template <typename U> + static char Test(decltype(&U::Hash)); + template <typename U> + static int Test(...); + static const bool value = sizeof(Test<T>(0)) == sizeof(char); + + private: + EnsureTypeIsComplete<T> check_t_; +}; + +template <typename T, bool has_hash_method = HasHashMethod<T>::value> +struct HashTraits; + +template <typename T> +size_t Hash(size_t seed, const T& value); + +template <typename T> +struct HashTraits<T, true> { + static size_t Hash(size_t seed, const T& value) { return value.Hash(seed); } +}; + +template <typename T> +struct HashTraits<T, false> { + static size_t Hash(size_t seed, const T& value) { + return HashCombine(seed, value); + } +}; + +template <typename T> +struct HashTraits<std::vector<T>, false> { + static size_t Hash(size_t seed, const std::vector<T>& value) { + for (const auto& element : value) { + seed = HashCombine(seed, element); + } + return seed; + } +}; + +template <typename T> +struct HashTraits<base::Optional<std::vector<T>>, false> { + static size_t Hash(size_t seed, const base::Optional<std::vector<T>>& value) { + if (!value) + return HashCombine(seed, 0); + + return Hash(seed, *value); + } +}; + +template <typename T> +size_t Hash(size_t seed, const T& value) { + return HashTraits<T>::Hash(seed, value); +} + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_HASH_UTIL_H_ diff --git a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc new file mode 100644 index 0000000000..4682e72fad --- /dev/null +++ b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc @@ -0,0 +1,412 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/interface_endpoint_client.h" + +#include <stdint.h> + +#include <utility> + +#include "base/bind.h" +#include "base/location.h" +#include "base/logging.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" +#include "base/single_thread_task_runner.h" +#include "base/stl_util.h" +#include "mojo/public/cpp/bindings/associated_group.h" +#include "mojo/public/cpp/bindings/associated_group_controller.h" +#include "mojo/public/cpp/bindings/interface_endpoint_controller.h" +#include "mojo/public/cpp/bindings/lib/validation_util.h" +#include "mojo/public/cpp/bindings/sync_call_restrictions.h" + +namespace mojo { + +// ---------------------------------------------------------------------------- + +namespace { + +void DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient>& client, + const std::string& message) { + bool is_valid = client && !client->encountered_error(); + DCHECK(!is_valid) << message; +} + +// When receiving an incoming message which expects a repsonse, +// InterfaceEndpointClient creates a ResponderThunk object and passes it to the +// incoming message receiver. When the receiver finishes processing the message, +// it can provide a response using this object. +class ResponderThunk : public MessageReceiverWithStatus { + public: + explicit ResponderThunk( + const base::WeakPtr<InterfaceEndpointClient>& endpoint_client, + scoped_refptr<base::SingleThreadTaskRunner> runner) + : endpoint_client_(endpoint_client), + accept_was_invoked_(false), + task_runner_(std::move(runner)) {} + ~ResponderThunk() override { + if (!accept_was_invoked_) { + // The Service handled a message that was expecting a response + // but did not send a response. + // We raise an error to signal the calling application that an error + // condition occurred. Without this the calling application would have no + // way of knowing it should stop waiting for a response. + if (task_runner_->RunsTasksOnCurrentThread()) { + // Please note that even if this code is run from a different task + // runner on the same thread as |task_runner_|, it is okay to directly + // call InterfaceEndpointClient::RaiseError(), because it will raise + // error from the correct task runner asynchronously. + if (endpoint_client_) { + endpoint_client_->RaiseError(); + } + } else { + task_runner_->PostTask( + FROM_HERE, + base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_)); + } + } + } + + // MessageReceiver implementation: + bool Accept(Message* message) override { + DCHECK(task_runner_->RunsTasksOnCurrentThread()); + accept_was_invoked_ = true; + DCHECK(message->has_flag(Message::kFlagIsResponse)); + + bool result = false; + + if (endpoint_client_) + result = endpoint_client_->Accept(message); + + return result; + } + + // MessageReceiverWithStatus implementation: + bool IsValid() override { + DCHECK(task_runner_->RunsTasksOnCurrentThread()); + return endpoint_client_ && !endpoint_client_->encountered_error(); + } + + void DCheckInvalid(const std::string& message) override { + if (task_runner_->RunsTasksOnCurrentThread()) { + DCheckIfInvalid(endpoint_client_, message); + } else { + task_runner_->PostTask( + FROM_HERE, base::Bind(&DCheckIfInvalid, endpoint_client_, message)); + } + } + + private: + base::WeakPtr<InterfaceEndpointClient> endpoint_client_; + bool accept_was_invoked_; + scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + + DISALLOW_COPY_AND_ASSIGN(ResponderThunk); +}; + +} // namespace + +// ---------------------------------------------------------------------------- + +InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo( + bool* in_response_received) + : response_received(in_response_received) {} + +InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {} + +// ---------------------------------------------------------------------------- + +InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk( + InterfaceEndpointClient* owner) + : owner_(owner) {} + +InterfaceEndpointClient::HandleIncomingMessageThunk:: + ~HandleIncomingMessageThunk() {} + +bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept( + Message* message) { + return owner_->HandleValidatedMessage(message); +} + +// ---------------------------------------------------------------------------- + +InterfaceEndpointClient::InterfaceEndpointClient( + ScopedInterfaceEndpointHandle handle, + MessageReceiverWithResponderStatus* receiver, + std::unique_ptr<MessageReceiver> payload_validator, + bool expect_sync_requests, + scoped_refptr<base::SingleThreadTaskRunner> runner, + uint32_t interface_version) + : expect_sync_requests_(expect_sync_requests), + handle_(std::move(handle)), + incoming_receiver_(receiver), + thunk_(this), + filters_(&thunk_), + task_runner_(std::move(runner)), + control_message_proxy_(this), + control_message_handler_(interface_version), + weak_ptr_factory_(this) { + DCHECK(handle_.is_valid()); + + // TODO(yzshen): the way to use validator (or message filter in general) + // directly is a little awkward. + if (payload_validator) + filters_.Append(std::move(payload_validator)); + + if (handle_.pending_association()) { + handle_.SetAssociationEventHandler(base::Bind( + &InterfaceEndpointClient::OnAssociationEvent, base::Unretained(this))); + } else { + InitControllerIfNecessary(); + } +} + +InterfaceEndpointClient::~InterfaceEndpointClient() { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (controller_) + handle_.group_controller()->DetachEndpointClient(handle_); +} + +AssociatedGroup* InterfaceEndpointClient::associated_group() { + if (!associated_group_) + associated_group_ = base::MakeUnique<AssociatedGroup>(handle_); + return associated_group_.get(); +} + +ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK(!has_pending_responders()); + + if (!handle_.is_valid()) + return ScopedInterfaceEndpointHandle(); + + handle_.SetAssociationEventHandler( + ScopedInterfaceEndpointHandle::AssociationEventCallback()); + + if (controller_) { + controller_ = nullptr; + handle_.group_controller()->DetachEndpointClient(handle_); + } + + return std::move(handle_); +} + +void InterfaceEndpointClient::AddFilter( + std::unique_ptr<MessageReceiver> filter) { + filters_.Append(std::move(filter)); +} + +void InterfaceEndpointClient::RaiseError() { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (!handle_.pending_association()) + handle_.group_controller()->RaiseError(); +} + +void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason, + const std::string& description) { + DCHECK(thread_checker_.CalledOnValidThread()); + + auto handle = PassHandle(); + handle.ResetWithReason(custom_reason, description); +} + +bool InterfaceEndpointClient::Accept(Message* message) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK(!message->has_flag(Message::kFlagExpectsResponse)); + DCHECK(!handle_.pending_association()); + + // This has to been done even if connection error has occurred. For example, + // the message contains a pending associated request. The user may try to use + // the corresponding associated interface pointer after sending this message. + // That associated interface pointer has to join an associated group in order + // to work properly. + if (!message->associated_endpoint_handles()->empty()) + message->SerializeAssociatedEndpointHandles(handle_.group_controller()); + + if (encountered_error_) + return false; + + InitControllerIfNecessary(); + + return controller_->SendMessage(message); +} + +bool InterfaceEndpointClient::AcceptWithResponder( + Message* message, + std::unique_ptr<MessageReceiver> responder) { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK(message->has_flag(Message::kFlagExpectsResponse)); + DCHECK(!handle_.pending_association()); + + // Please see comments in Accept(). + if (!message->associated_endpoint_handles()->empty()) + message->SerializeAssociatedEndpointHandles(handle_.group_controller()); + + if (encountered_error_) + return false; + + InitControllerIfNecessary(); + + // Reserve 0 in case we want it to convey special meaning in the future. + uint64_t request_id = next_request_id_++; + if (request_id == 0) + request_id = next_request_id_++; + + message->set_request_id(request_id); + + bool is_sync = message->has_flag(Message::kFlagIsSync); + if (!controller_->SendMessage(message)) + return false; + + if (!is_sync) { + async_responders_[request_id] = std::move(responder); + return true; + } + + SyncCallRestrictions::AssertSyncCallAllowed(); + + bool response_received = false; + sync_responses_.insert(std::make_pair( + request_id, base::MakeUnique<SyncResponseInfo>(&response_received))); + + base::WeakPtr<InterfaceEndpointClient> weak_self = + weak_ptr_factory_.GetWeakPtr(); + controller_->SyncWatch(&response_received); + // Make sure that this instance hasn't been destroyed. + if (weak_self) { + DCHECK(base::ContainsKey(sync_responses_, request_id)); + auto iter = sync_responses_.find(request_id); + DCHECK_EQ(&response_received, iter->second->response_received); + if (response_received) + ignore_result(responder->Accept(&iter->second->response)); + sync_responses_.erase(iter); + } + + return true; +} + +bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) { + DCHECK(thread_checker_.CalledOnValidThread()); + return filters_.Accept(message); +} + +void InterfaceEndpointClient::NotifyError( + const base::Optional<DisconnectReason>& reason) { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (encountered_error_) + return; + encountered_error_ = true; + + // Response callbacks may hold on to resource, and there's no need to keep + // them alive any longer. Note that it's allowed that a pending response + // callback may own this endpoint, so we simply move the responders onto the + // stack here and let them be destroyed when the stack unwinds. + AsyncResponderMap responders = std::move(async_responders_); + + control_message_proxy_.OnConnectionError(); + + if (!error_handler_.is_null()) { + base::Closure error_handler = std::move(error_handler_); + error_handler.Run(); + } else if (!error_with_reason_handler_.is_null()) { + ConnectionErrorWithReasonCallback error_with_reason_handler = + std::move(error_with_reason_handler_); + if (reason) { + error_with_reason_handler.Run(reason->custom_reason, reason->description); + } else { + error_with_reason_handler.Run(0, std::string()); + } + } +} + +void InterfaceEndpointClient::QueryVersion( + const base::Callback<void(uint32_t)>& callback) { + control_message_proxy_.QueryVersion(callback); +} + +void InterfaceEndpointClient::RequireVersion(uint32_t version) { + control_message_proxy_.RequireVersion(version); +} + +void InterfaceEndpointClient::FlushForTesting() { + control_message_proxy_.FlushForTesting(); +} + +void InterfaceEndpointClient::InitControllerIfNecessary() { + if (controller_ || handle_.pending_association()) + return; + + controller_ = handle_.group_controller()->AttachEndpointClient(handle_, this, + task_runner_); + if (expect_sync_requests_) + controller_->AllowWokenUpBySyncWatchOnSameThread(); +} + +void InterfaceEndpointClient::OnAssociationEvent( + ScopedInterfaceEndpointHandle::AssociationEvent event) { + if (event == ScopedInterfaceEndpointHandle::ASSOCIATED) { + InitControllerIfNecessary(); + } else if (event == + ScopedInterfaceEndpointHandle::PEER_CLOSED_BEFORE_ASSOCIATION) { + task_runner_->PostTask(FROM_HERE, + base::Bind(&InterfaceEndpointClient::NotifyError, + weak_ptr_factory_.GetWeakPtr(), + handle_.disconnect_reason())); + } +} + +bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) { + DCHECK_EQ(handle_.id(), message->interface_id()); + + if (encountered_error_) { + // This message is received after error has been encountered. For associated + // interfaces, this means the remote side sends a + // PeerAssociatedEndpointClosed event but continues to send more messages + // for the same interface. Close the pipe because this shouldn't happen. + DVLOG(1) << "A message is received for an interface after it has been " + << "disconnected. Closing the pipe."; + return false; + } + + if (message->has_flag(Message::kFlagExpectsResponse)) { + std::unique_ptr<MessageReceiverWithStatus> responder = + base::MakeUnique<ResponderThunk>(weak_ptr_factory_.GetWeakPtr(), + task_runner_); + if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) { + return control_message_handler_.AcceptWithResponder(message, + std::move(responder)); + } else { + return incoming_receiver_->AcceptWithResponder(message, + std::move(responder)); + } + } else if (message->has_flag(Message::kFlagIsResponse)) { + uint64_t request_id = message->request_id(); + + if (message->has_flag(Message::kFlagIsSync)) { + auto it = sync_responses_.find(request_id); + if (it == sync_responses_.end()) + return false; + it->second->response = std::move(*message); + *it->second->response_received = true; + return true; + } + + auto it = async_responders_.find(request_id); + if (it == async_responders_.end()) + return false; + std::unique_ptr<MessageReceiver> responder = std::move(it->second); + async_responders_.erase(it); + return responder->Accept(message); + } else { + if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) + return control_message_handler_.Accept(message); + + return incoming_receiver_->Accept(message); + } +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/interface_ptr_state.h b/mojo/public/cpp/bindings/lib/interface_ptr_state.h new file mode 100644 index 0000000000..fa54979795 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/interface_ptr_state.h @@ -0,0 +1,226 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_INTERFACE_PTR_STATE_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_INTERFACE_PTR_STATE_H_ + +#include <stdint.h> + +#include <algorithm> // For |std::swap()|. +#include <memory> +#include <string> +#include <utility> + +#include "base/bind.h" +#include "base/callback_forward.h" +#include "base/logging.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" +#include "base/memory/ref_counted.h" +#include "base/single_thread_task_runner.h" +#include "mojo/public/cpp/bindings/associated_group.h" +#include "mojo/public/cpp/bindings/connection_error_callback.h" +#include "mojo/public/cpp/bindings/filter_chain.h" +#include "mojo/public/cpp/bindings/interface_endpoint_client.h" +#include "mojo/public/cpp/bindings/interface_id.h" +#include "mojo/public/cpp/bindings/interface_ptr_info.h" +#include "mojo/public/cpp/bindings/lib/multiplex_router.h" +#include "mojo/public/cpp/bindings/message_header_validator.h" +#include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" + +namespace mojo { +namespace internal { + +template <typename Interface> +class InterfacePtrState { + public: + InterfacePtrState() : version_(0u) {} + + ~InterfacePtrState() { + endpoint_client_.reset(); + proxy_.reset(); + if (router_) + router_->CloseMessagePipe(); + } + + Interface* instance() { + ConfigureProxyIfNecessary(); + + // This will be null if the object is not bound. + return proxy_.get(); + } + + uint32_t version() const { return version_; } + + void QueryVersion(const base::Callback<void(uint32_t)>& callback) { + ConfigureProxyIfNecessary(); + + // It is safe to capture |this| because the callback won't be run after this + // object goes away. + endpoint_client_->QueryVersion(base::Bind( + &InterfacePtrState::OnQueryVersion, base::Unretained(this), callback)); + } + + void RequireVersion(uint32_t version) { + ConfigureProxyIfNecessary(); + + if (version <= version_) + return; + + version_ = version; + endpoint_client_->RequireVersion(version); + } + + void FlushForTesting() { + ConfigureProxyIfNecessary(); + endpoint_client_->FlushForTesting(); + } + + void CloseWithReason(uint32_t custom_reason, const std::string& description) { + ConfigureProxyIfNecessary(); + endpoint_client_->CloseWithReason(custom_reason, description); + } + + void Swap(InterfacePtrState* other) { + using std::swap; + swap(other->router_, router_); + swap(other->endpoint_client_, endpoint_client_); + swap(other->proxy_, proxy_); + handle_.swap(other->handle_); + runner_.swap(other->runner_); + swap(other->version_, version_); + } + + void Bind(InterfacePtrInfo<Interface> info, + scoped_refptr<base::SingleThreadTaskRunner> runner) { + DCHECK(!router_); + DCHECK(!endpoint_client_); + DCHECK(!proxy_); + DCHECK(!handle_.is_valid()); + DCHECK_EQ(0u, version_); + DCHECK(info.is_valid()); + + handle_ = info.PassHandle(); + version_ = info.version(); + runner_ = std::move(runner); + } + + bool HasAssociatedInterfaces() const { + return router_ ? router_->HasAssociatedEndpoints() : false; + } + + // After this method is called, the object is in an invalid state and + // shouldn't be reused. + InterfacePtrInfo<Interface> PassInterface() { + endpoint_client_.reset(); + proxy_.reset(); + return InterfacePtrInfo<Interface>( + router_ ? router_->PassMessagePipe() : std::move(handle_), version_); + } + + bool is_bound() const { return handle_.is_valid() || endpoint_client_; } + + bool encountered_error() const { + return endpoint_client_ ? endpoint_client_->encountered_error() : false; + } + + void set_connection_error_handler(const base::Closure& error_handler) { + ConfigureProxyIfNecessary(); + + DCHECK(endpoint_client_); + endpoint_client_->set_connection_error_handler(error_handler); + } + + void set_connection_error_with_reason_handler( + const ConnectionErrorWithReasonCallback& error_handler) { + ConfigureProxyIfNecessary(); + + DCHECK(endpoint_client_); + endpoint_client_->set_connection_error_with_reason_handler(error_handler); + } + + // Returns true if bound and awaiting a response to a message. + bool has_pending_callbacks() const { + return endpoint_client_ && endpoint_client_->has_pending_responders(); + } + + AssociatedGroup* associated_group() { + ConfigureProxyIfNecessary(); + return endpoint_client_->associated_group(); + } + + void EnableTestingMode() { + ConfigureProxyIfNecessary(); + router_->EnableTestingMode(); + } + + void ForwardMessage(Message message) { + ConfigureProxyIfNecessary(); + endpoint_client_->Accept(&message); + } + + void ForwardMessageWithResponder(Message message, + std::unique_ptr<MessageReceiver> responder) { + ConfigureProxyIfNecessary(); + endpoint_client_->AcceptWithResponder(&message, std::move(responder)); + } + + private: + using Proxy = typename Interface::Proxy_; + + void ConfigureProxyIfNecessary() { + // The proxy has been configured. + if (proxy_) { + DCHECK(router_); + DCHECK(endpoint_client_); + return; + } + // The object hasn't been bound. + if (!handle_.is_valid()) + return; + + MultiplexRouter::Config config = + Interface::PassesAssociatedKinds_ + ? MultiplexRouter::MULTI_INTERFACE + : (Interface::HasSyncMethods_ + ? MultiplexRouter::SINGLE_INTERFACE_WITH_SYNC_METHODS + : MultiplexRouter::SINGLE_INTERFACE); + router_ = new MultiplexRouter(std::move(handle_), config, true, runner_); + router_->SetMasterInterfaceName(Interface::Name_); + endpoint_client_.reset(new InterfaceEndpointClient( + router_->CreateLocalEndpointHandle(kMasterInterfaceId), nullptr, + base::WrapUnique(new typename Interface::ResponseValidator_()), false, + std::move(runner_), + // The version is only queried from the client so the value passed here + // will not be used. + 0u)); + proxy_.reset(new Proxy(endpoint_client_.get())); + } + + void OnQueryVersion(const base::Callback<void(uint32_t)>& callback, + uint32_t version) { + version_ = version; + callback.Run(version); + } + + scoped_refptr<MultiplexRouter> router_; + + std::unique_ptr<InterfaceEndpointClient> endpoint_client_; + std::unique_ptr<Proxy> proxy_; + + // |router_| (as well as other members above) is not initialized until + // read/write with the message pipe handle is needed. |handle_| is valid + // between the Bind() call and the initialization of |router_|. + ScopedMessagePipeHandle handle_; + scoped_refptr<base::SingleThreadTaskRunner> runner_; + + uint32_t version_; + + DISALLOW_COPY_AND_ASSIGN(InterfacePtrState); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_INTERFACE_PTR_STATE_H_ diff --git a/mojo/public/cpp/bindings/lib/map_data_internal.h b/mojo/public/cpp/bindings/lib/map_data_internal.h new file mode 100644 index 0000000000..f8e3d2918f --- /dev/null +++ b/mojo/public/cpp/bindings/lib/map_data_internal.h @@ -0,0 +1,85 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MAP_DATA_INTERNAL_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_MAP_DATA_INTERNAL_H_ + +#include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/bindings/lib/validate_params.h" +#include "mojo/public/cpp/bindings/lib/validation_errors.h" +#include "mojo/public/cpp/bindings/lib/validation_util.h" + +namespace mojo { +namespace internal { + +// Map serializes into a struct which has two arrays as struct fields, the keys +// and the values. +template <typename Key, typename Value> +class Map_Data { + public: + static Map_Data* New(Buffer* buf) { + return new (buf->Allocate(sizeof(Map_Data))) Map_Data(); + } + + // |validate_params| must have non-null |key_validate_params| and + // |element_validate_params| members. + static bool Validate(const void* data, + ValidationContext* validation_context, + const ContainerValidateParams* validate_params) { + if (!data) + return true; + + if (!ValidateStructHeaderAndClaimMemory(data, validation_context)) + return false; + + const Map_Data* object = static_cast<const Map_Data*>(data); + if (object->header_.num_bytes != sizeof(Map_Data) || + object->header_.version != 0) { + ReportValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_STRUCT_HEADER); + return false; + } + + if (!ValidatePointerNonNullable( + object->keys, "null key array in map struct", validation_context) || + !ValidateContainer(object->keys, validation_context, + validate_params->key_validate_params)) { + return false; + } + + if (!ValidatePointerNonNullable(object->values, + "null value array in map struct", + validation_context) || + !ValidateContainer(object->values, validation_context, + validate_params->element_validate_params)) { + return false; + } + + if (object->keys.Get()->size() != object->values.Get()->size()) { + ReportValidationError(validation_context, + VALIDATION_ERROR_DIFFERENT_SIZED_ARRAYS_IN_MAP); + return false; + } + + return true; + } + + StructHeader header_; + + Pointer<Array_Data<Key>> keys; + Pointer<Array_Data<Value>> values; + + private: + Map_Data() { + header_.num_bytes = sizeof(*this); + header_.version = 0; + } + ~Map_Data() = delete; +}; +static_assert(sizeof(Map_Data<char, char>) == 24, "Bad sizeof(Map_Data)"); + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_MAP_DATA_INTERNAL_H_ diff --git a/mojo/public/cpp/bindings/lib/map_serialization.h b/mojo/public/cpp/bindings/lib/map_serialization.h new file mode 100644 index 0000000000..718a76307d --- /dev/null +++ b/mojo/public/cpp/bindings/lib/map_serialization.h @@ -0,0 +1,182 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MAP_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_MAP_SERIALIZATION_H_ + +#include <type_traits> +#include <vector> + +#include "mojo/public/cpp/bindings/array_data_view.h" +#include "mojo/public/cpp/bindings/lib/array_serialization.h" +#include "mojo/public/cpp/bindings/lib/map_data_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_forward.h" +#include "mojo/public/cpp/bindings/map_data_view.h" + +namespace mojo { +namespace internal { + +template <typename MaybeConstUserType> +class MapReaderBase { + public: + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Traits = MapTraits<UserType>; + using MaybeConstIterator = + decltype(Traits::GetBegin(std::declval<MaybeConstUserType&>())); + + explicit MapReaderBase(MaybeConstUserType& input) + : input_(input), iter_(Traits::GetBegin(input_)) {} + ~MapReaderBase() {} + + size_t GetSize() const { return Traits::GetSize(input_); } + + // Return null because key or value elements are not stored continuously in + // memory. + void* GetDataIfExists() { return nullptr; } + + protected: + MaybeConstUserType& input_; + MaybeConstIterator iter_; +}; + +// Used as the UserTypeReader template parameter of ArraySerializer. +template <typename MaybeConstUserType> +class MapKeyReader : public MapReaderBase<MaybeConstUserType> { + public: + using Base = MapReaderBase<MaybeConstUserType>; + using Traits = typename Base::Traits; + using MaybeConstIterator = typename Base::MaybeConstIterator; + + explicit MapKeyReader(MaybeConstUserType& input) : Base(input) {} + ~MapKeyReader() {} + + using GetNextResult = + decltype(Traits::GetKey(std::declval<MaybeConstIterator&>())); + GetNextResult GetNext() { + GetNextResult key = Traits::GetKey(this->iter_); + Traits::AdvanceIterator(this->iter_); + return key; + } +}; + +// Used as the UserTypeReader template parameter of ArraySerializer. +template <typename MaybeConstUserType> +class MapValueReader : public MapReaderBase<MaybeConstUserType> { + public: + using Base = MapReaderBase<MaybeConstUserType>; + using Traits = typename Base::Traits; + using MaybeConstIterator = typename Base::MaybeConstIterator; + + explicit MapValueReader(MaybeConstUserType& input) : Base(input) {} + ~MapValueReader() {} + + using GetNextResult = + decltype(Traits::GetValue(std::declval<MaybeConstIterator&>())); + GetNextResult GetNext() { + GetNextResult value = Traits::GetValue(this->iter_); + Traits::AdvanceIterator(this->iter_); + return value; + } +}; + +template <typename Key, typename Value, typename MaybeConstUserType> +struct Serializer<MapDataView<Key, Value>, MaybeConstUserType> { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Traits = MapTraits<UserType>; + using UserKey = typename Traits::Key; + using UserValue = typename Traits::Value; + using Data = typename MojomTypeTraits<MapDataView<Key, Value>>::Data; + using KeyArraySerializer = ArraySerializer<ArrayDataView<Key>, + std::vector<UserKey>, + MapKeyReader<MaybeConstUserType>>; + using ValueArraySerializer = + ArraySerializer<ArrayDataView<Value>, + std::vector<UserValue>, + MapValueReader<MaybeConstUserType>>; + + static size_t PrepareToSerialize(MaybeConstUserType& input, + SerializationContext* context) { + if (CallIsNullIfExists<Traits>(input)) + return 0; + + size_t struct_overhead = sizeof(Data); + MapKeyReader<MaybeConstUserType> key_reader(input); + size_t keys_size = + KeyArraySerializer::GetSerializedSize(&key_reader, context); + MapValueReader<MaybeConstUserType> value_reader(input); + size_t values_size = + ValueArraySerializer::GetSerializedSize(&value_reader, context); + + return struct_overhead + keys_size + values_size; + } + + static void Serialize(MaybeConstUserType& input, + Buffer* buf, + Data** output, + const ContainerValidateParams* validate_params, + SerializationContext* context) { + DCHECK(validate_params->key_validate_params); + DCHECK(validate_params->element_validate_params); + if (CallIsNullIfExists<Traits>(input)) { + *output = nullptr; + return; + } + + auto result = Data::New(buf); + if (result) { + auto keys_ptr = MojomTypeTraits<ArrayDataView<Key>>::Data::New( + Traits::GetSize(input), buf); + if (keys_ptr) { + MapKeyReader<MaybeConstUserType> key_reader(input); + KeyArraySerializer::SerializeElements( + &key_reader, buf, keys_ptr, validate_params->key_validate_params, + context); + result->keys.Set(keys_ptr); + } + + auto values_ptr = MojomTypeTraits<ArrayDataView<Value>>::Data::New( + Traits::GetSize(input), buf); + if (values_ptr) { + MapValueReader<MaybeConstUserType> value_reader(input); + ValueArraySerializer::SerializeElements( + &value_reader, buf, values_ptr, + validate_params->element_validate_params, context); + result->values.Set(values_ptr); + } + } + *output = result; + } + + static bool Deserialize(Data* input, + UserType* output, + SerializationContext* context) { + if (!input) + return CallSetToNullIfExists<Traits>(output); + + std::vector<UserKey> keys; + std::vector<UserValue> values; + + if (!KeyArraySerializer::DeserializeElements(input->keys.Get(), &keys, + context) || + !ValueArraySerializer::DeserializeElements(input->values.Get(), &values, + context)) { + return false; + } + + DCHECK_EQ(keys.size(), values.size()); + size_t size = keys.size(); + Traits::SetToEmpty(output); + + for (size_t i = 0; i < size; ++i) { + if (!Traits::Insert(*output, std::move(keys[i]), std::move(values[i]))) + return false; + } + return true; + } +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_MAP_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/may_auto_lock.h b/mojo/public/cpp/bindings/lib/may_auto_lock.h new file mode 100644 index 0000000000..06091fee90 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/may_auto_lock.h @@ -0,0 +1,62 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MAY_AUTO_LOCK_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_MAY_AUTO_LOCK_H_ + +#include "base/macros.h" +#include "base/optional.h" +#include "base/synchronization/lock.h" + +namespace mojo { +namespace internal { + +// Similar to base::AutoLock, except that it does nothing if |lock| passed into +// the constructor is null. +class MayAutoLock { + public: + explicit MayAutoLock(base::Optional<base::Lock>* lock) + : lock_(lock->has_value() ? &lock->value() : nullptr) { + if (lock_) + lock_->Acquire(); + } + + ~MayAutoLock() { + if (lock_) { + lock_->AssertAcquired(); + lock_->Release(); + } + } + + private: + base::Lock* lock_; + DISALLOW_COPY_AND_ASSIGN(MayAutoLock); +}; + +// Similar to base::AutoUnlock, except that it does nothing if |lock| passed +// into the constructor is null. +class MayAutoUnlock { + public: + explicit MayAutoUnlock(base::Optional<base::Lock>* lock) + : lock_(lock->has_value() ? &lock->value() : nullptr) { + if (lock_) { + lock_->AssertAcquired(); + lock_->Release(); + } + } + + ~MayAutoUnlock() { + if (lock_) + lock_->Acquire(); + } + + private: + base::Lock* lock_; + DISALLOW_COPY_AND_ASSIGN(MayAutoUnlock); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_MAY_AUTO_LOCK_H_ diff --git a/mojo/public/cpp/bindings/lib/message.cc b/mojo/public/cpp/bindings/lib/message.cc new file mode 100644 index 0000000000..e5f3808117 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message.cc @@ -0,0 +1,332 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/message.h" + +#include <stddef.h> +#include <stdint.h> +#include <stdlib.h> + +#include <algorithm> +#include <utility> + +#include "base/bind.h" +#include "base/lazy_instance.h" +#include "base/logging.h" +#include "base/strings/stringprintf.h" +#include "base/threading/thread_local.h" +#include "mojo/public/cpp/bindings/associated_group_controller.h" +#include "mojo/public/cpp/bindings/lib/array_internal.h" + +namespace mojo { + +namespace { + +base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>:: + DestructorAtExit g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER; + +base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>:: + DestructorAtExit g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER; + +void DoNotifyBadMessage(Message message, const std::string& error) { + message.NotifyBadMessage(error); +} + +} // namespace + +Message::Message() { +} + +Message::Message(Message&& other) + : buffer_(std::move(other.buffer_)), + handles_(std::move(other.handles_)), + associated_endpoint_handles_( + std::move(other.associated_endpoint_handles_)) {} + +Message::~Message() { + CloseHandles(); +} + +Message& Message::operator=(Message&& other) { + Reset(); + std::swap(other.buffer_, buffer_); + std::swap(other.handles_, handles_); + std::swap(other.associated_endpoint_handles_, associated_endpoint_handles_); + return *this; +} + +void Message::Reset() { + CloseHandles(); + handles_.clear(); + associated_endpoint_handles_.clear(); + buffer_.reset(); +} + +void Message::Initialize(size_t capacity, bool zero_initialized) { + DCHECK(!buffer_); + buffer_.reset(new internal::MessageBuffer(capacity, zero_initialized)); +} + +void Message::InitializeFromMojoMessage(ScopedMessageHandle message, + uint32_t num_bytes, + std::vector<Handle>* handles) { + DCHECK(!buffer_); + buffer_.reset(new internal::MessageBuffer(std::move(message), num_bytes)); + handles_.swap(*handles); +} + +const uint8_t* Message::payload() const { + if (version() < 2) + return data() + header()->num_bytes; + + return static_cast<const uint8_t*>(header_v2()->payload.Get()); +} + +uint32_t Message::payload_num_bytes() const { + DCHECK_GE(data_num_bytes(), header()->num_bytes); + size_t num_bytes; + if (version() < 2) { + num_bytes = data_num_bytes() - header()->num_bytes; + } else { + auto payload = reinterpret_cast<uintptr_t>(header_v2()->payload.Get()); + if (!payload) { + num_bytes = 0; + } else { + auto payload_end = + reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get()); + if (!payload_end) + payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes()); + DCHECK_GE(payload_end, payload); + num_bytes = payload_end - payload; + } + } + DCHECK_LE(num_bytes, std::numeric_limits<uint32_t>::max()); + return static_cast<uint32_t>(num_bytes); +} + +uint32_t Message::payload_num_interface_ids() const { + auto* array_pointer = + version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get(); + return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0; +} + +const uint32_t* Message::payload_interface_ids() const { + auto* array_pointer = + version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get(); + return array_pointer ? array_pointer->storage() : nullptr; +} + +ScopedMessageHandle Message::TakeMojoMessage() { + // If there are associated endpoints transferred, + // SerializeAssociatedEndpointHandles() must be called before this method. + DCHECK(associated_endpoint_handles_.empty()); + + if (handles_.empty()) // Fast path for the common case: No handles. + return buffer_->TakeMessage(); + + // Allocate a new message with space for the handles, then copy the buffer + // contents into it. + // + // TODO(rockot): We could avoid this copy by extending GetSerializedSize() + // behavior to collect handles. It's unoptimized for now because it's much + // more common to have messages with no handles. + ScopedMessageHandle new_message; + MojoResult rv = AllocMessage( + data_num_bytes(), + handles_.empty() ? nullptr + : reinterpret_cast<const MojoHandle*>(handles_.data()), + handles_.size(), + MOJO_ALLOC_MESSAGE_FLAG_NONE, + &new_message); + CHECK_EQ(rv, MOJO_RESULT_OK); + handles_.clear(); + + void* new_buffer = nullptr; + rv = GetMessageBuffer(new_message.get(), &new_buffer); + CHECK_EQ(rv, MOJO_RESULT_OK); + + memcpy(new_buffer, data(), data_num_bytes()); + buffer_.reset(); + + return new_message; +} + +void Message::NotifyBadMessage(const std::string& error) { + DCHECK(buffer_); + buffer_->NotifyBadMessage(error); +} + +void Message::CloseHandles() { + for (std::vector<Handle>::iterator it = handles_.begin(); + it != handles_.end(); ++it) { + if (it->is_valid()) + CloseRaw(*it); + } +} + +void Message::SerializeAssociatedEndpointHandles( + AssociatedGroupController* group_controller) { + if (associated_endpoint_handles_.empty()) + return; + + DCHECK_GE(version(), 2u); + DCHECK(header_v2()->payload_interface_ids.is_null()); + + size_t size = associated_endpoint_handles_.size(); + auto* data = internal::Array_Data<uint32_t>::New(size, buffer()); + header_v2()->payload_interface_ids.Set(data); + + for (size_t i = 0; i < size; ++i) { + ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i]; + + DCHECK(handle.pending_association()); + data->storage()[i] = + group_controller->AssociateInterface(std::move(handle)); + } + associated_endpoint_handles_.clear(); +} + +bool Message::DeserializeAssociatedEndpointHandles( + AssociatedGroupController* group_controller) { + associated_endpoint_handles_.clear(); + + uint32_t num_ids = payload_num_interface_ids(); + if (num_ids == 0) + return true; + + associated_endpoint_handles_.reserve(num_ids); + uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage(); + bool result = true; + for (uint32_t i = 0; i < num_ids; ++i) { + auto handle = group_controller->CreateLocalEndpointHandle(ids[i]); + if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) { + // |ids[i]| itself is valid but handle creation failed. In that case, mark + // deserialization as failed but continue to deserialize the rest of + // handles. + result = false; + } + + associated_endpoint_handles_.push_back(std::move(handle)); + ids[i] = kInvalidInterfaceId; + } + return result; +} + +PassThroughFilter::PassThroughFilter() {} + +PassThroughFilter::~PassThroughFilter() {} + +bool PassThroughFilter::Accept(Message* message) { return true; } + +SyncMessageResponseContext::SyncMessageResponseContext() + : outer_context_(current()) { + g_tls_sync_response_context.Get().Set(this); +} + +SyncMessageResponseContext::~SyncMessageResponseContext() { + DCHECK_EQ(current(), this); + g_tls_sync_response_context.Get().Set(outer_context_); +} + +// static +SyncMessageResponseContext* SyncMessageResponseContext::current() { + return g_tls_sync_response_context.Get().Get(); +} + +void SyncMessageResponseContext::ReportBadMessage(const std::string& error) { + GetBadMessageCallback().Run(error); +} + +const ReportBadMessageCallback& +SyncMessageResponseContext::GetBadMessageCallback() { + if (bad_message_callback_.is_null()) { + bad_message_callback_ = + base::Bind(&DoNotifyBadMessage, base::Passed(&response_)); + } + return bad_message_callback_; +} + +MojoResult ReadMessage(MessagePipeHandle handle, Message* message) { + MojoResult rv; + + std::vector<Handle> handles; + ScopedMessageHandle mojo_message; + uint32_t num_bytes = 0, num_handles = 0; + rv = ReadMessageNew(handle, + &mojo_message, + &num_bytes, + nullptr, + &num_handles, + MOJO_READ_MESSAGE_FLAG_NONE); + if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) { + DCHECK_GT(num_handles, 0u); + handles.resize(num_handles); + rv = ReadMessageNew(handle, + &mojo_message, + &num_bytes, + reinterpret_cast<MojoHandle*>(handles.data()), + &num_handles, + MOJO_READ_MESSAGE_FLAG_NONE); + } + + if (rv != MOJO_RESULT_OK) + return rv; + + message->InitializeFromMojoMessage( + std::move(mojo_message), num_bytes, &handles); + return MOJO_RESULT_OK; +} + +void ReportBadMessage(const std::string& error) { + internal::MessageDispatchContext* context = + internal::MessageDispatchContext::current(); + DCHECK(context); + context->GetBadMessageCallback().Run(error); +} + +ReportBadMessageCallback GetBadMessageCallback() { + internal::MessageDispatchContext* context = + internal::MessageDispatchContext::current(); + DCHECK(context); + return context->GetBadMessageCallback(); +} + +namespace internal { + +MessageHeaderV2::MessageHeaderV2() = default; + +MessageDispatchContext::MessageDispatchContext(Message* message) + : outer_context_(current()), message_(message) { + g_tls_message_dispatch_context.Get().Set(this); +} + +MessageDispatchContext::~MessageDispatchContext() { + DCHECK_EQ(current(), this); + g_tls_message_dispatch_context.Get().Set(outer_context_); +} + +// static +MessageDispatchContext* MessageDispatchContext::current() { + return g_tls_message_dispatch_context.Get().Get(); +} + +const ReportBadMessageCallback& +MessageDispatchContext::GetBadMessageCallback() { + if (bad_message_callback_.is_null()) { + bad_message_callback_ = + base::Bind(&DoNotifyBadMessage, base::Passed(message_)); + } + return bad_message_callback_; +} + +// static +void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) { + SyncMessageResponseContext* context = SyncMessageResponseContext::current(); + if (context) + context->response_ = std::move(*message); +} + +} // namespace internal + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_buffer.cc b/mojo/public/cpp/bindings/lib/message_buffer.cc new file mode 100644 index 0000000000..cc12ef6e31 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message_buffer.cc @@ -0,0 +1,52 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/message_buffer.h" + +#include <limits> + +#include "mojo/public/cpp/bindings/lib/serialization_util.h" + +namespace mojo { +namespace internal { + +MessageBuffer::MessageBuffer(size_t capacity, bool zero_initialized) { + DCHECK_LE(capacity, std::numeric_limits<uint32_t>::max()); + + MojoResult rv = AllocMessage(capacity, nullptr, 0, + MOJO_ALLOC_MESSAGE_FLAG_NONE, &message_); + CHECK_EQ(rv, MOJO_RESULT_OK); + + void* buffer = nullptr; + if (capacity != 0) { + rv = GetMessageBuffer(message_.get(), &buffer); + CHECK_EQ(rv, MOJO_RESULT_OK); + + if (zero_initialized) + memset(buffer, 0, capacity); + } + Initialize(buffer, capacity); +} + +MessageBuffer::MessageBuffer(ScopedMessageHandle message, uint32_t num_bytes) { + message_ = std::move(message); + + void* buffer = nullptr; + if (num_bytes != 0) { + MojoResult rv = GetMessageBuffer(message_.get(), &buffer); + CHECK_EQ(rv, MOJO_RESULT_OK); + } + Initialize(buffer, num_bytes); +} + +MessageBuffer::~MessageBuffer() {} + +void MessageBuffer::NotifyBadMessage(const std::string& error) { + DCHECK(message_.is_valid()); + MojoResult result = mojo::NotifyBadMessage(message_.get(), error); + DCHECK_EQ(result, MOJO_RESULT_OK); +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_buffer.h b/mojo/public/cpp/bindings/lib/message_buffer.h new file mode 100644 index 0000000000..96d5140f77 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message_buffer.h @@ -0,0 +1,43 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_LIB_MESSAGE_BUFFER_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_LIB_MESSAGE_BUFFER_H_ + +#include <stdint.h> + +#include <utility> + +#include "base/macros.h" +#include "mojo/public/cpp/bindings/lib/buffer.h" +#include "mojo/public/cpp/system/message.h" + +namespace mojo { +namespace internal { + +// A fixed-size Buffer using a Mojo message object for storage. +class MessageBuffer : public Buffer { + public: + // Initializes this buffer to carry a fixed byte capacity and no handles. + MessageBuffer(size_t capacity, bool zero_initialized); + + // Initializes this buffer from an existing Mojo MessageHandle. + MessageBuffer(ScopedMessageHandle message, uint32_t num_bytes); + + ~MessageBuffer(); + + ScopedMessageHandle TakeMessage() { return std::move(message_); } + + void NotifyBadMessage(const std::string& error); + + private: + ScopedMessageHandle message_; + + DISALLOW_COPY_AND_ASSIGN(MessageBuffer); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_LIB_MESSAGE_BUFFER_H_ diff --git a/mojo/public/cpp/bindings/lib/message_builder.cc b/mojo/public/cpp/bindings/lib/message_builder.cc new file mode 100644 index 0000000000..6806a73213 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message_builder.cc @@ -0,0 +1,69 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/message_builder.h" + +#include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/bindings/lib/buffer.h" +#include "mojo/public/cpp/bindings/lib/message_internal.h" + +namespace mojo { +namespace internal { + +template <typename Header> +void Allocate(Buffer* buf, Header** header) { + *header = static_cast<Header*>(buf->Allocate(sizeof(Header))); + (*header)->num_bytes = sizeof(Header); +} + +MessageBuilder::MessageBuilder(uint32_t name, + uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count) { + if (payload_interface_id_count > 0) { + // Version 2 + InitializeMessage( + sizeof(MessageHeaderV2) + Align(payload_size) + + ArrayDataTraits<uint32_t>::GetStorageSize( + static_cast<uint32_t>(payload_interface_id_count))); + + MessageHeaderV2* header; + Allocate(message_.buffer(), &header); + header->version = 2; + header->name = name; + header->flags = flags; + // The payload immediately follows the header. + header->payload.Set(header + 1); + } else if (flags & + (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { + // Version 1 + InitializeMessage(sizeof(MessageHeaderV1) + payload_size); + + MessageHeaderV1* header; + Allocate(message_.buffer(), &header); + header->version = 1; + header->name = name; + header->flags = flags; + } else { + InitializeMessage(sizeof(MessageHeader) + payload_size); + + MessageHeader* header; + Allocate(message_.buffer(), &header); + header->version = 0; + header->name = name; + header->flags = flags; + } +} + +MessageBuilder::~MessageBuilder() { +} + +void MessageBuilder::InitializeMessage(size_t size) { + message_.Initialize(static_cast<uint32_t>(Align(size)), + true /* zero_initialized */); +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_builder.h b/mojo/public/cpp/bindings/lib/message_builder.h new file mode 100644 index 0000000000..8a4d5c4690 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message_builder.h @@ -0,0 +1,45 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_BUILDER_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_BUILDER_H_ + +#include <stddef.h> +#include <stdint.h> + +#include "base/macros.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/message.h" + +namespace mojo { + +class Message; + +namespace internal { + +class Buffer; + +class MOJO_CPP_BINDINGS_EXPORT MessageBuilder { + public: + MessageBuilder(uint32_t name, + uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count); + ~MessageBuilder(); + + Buffer* buffer() { return message_.buffer(); } + Message* message() { return &message_; } + + private: + void InitializeMessage(size_t size); + + Message message_; + + DISALLOW_COPY_AND_ASSIGN(MessageBuilder); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_BUILDER_H_ diff --git a/mojo/public/cpp/bindings/lib/message_header_validator.cc b/mojo/public/cpp/bindings/lib/message_header_validator.cc new file mode 100644 index 0000000000..9f8c6278c0 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message_header_validator.cc @@ -0,0 +1,133 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/message_header_validator.h" + +#include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/bindings/lib/validate_params.h" +#include "mojo/public/cpp/bindings/lib/validation_context.h" +#include "mojo/public/cpp/bindings/lib/validation_errors.h" +#include "mojo/public/cpp/bindings/lib/validation_util.h" + +namespace mojo { +namespace { + +// TODO(yzshen): Define a mojom struct for message header and use the generated +// validation and data view code. +bool IsValidMessageHeader(const internal::MessageHeader* header, + internal::ValidationContext* validation_context) { + // NOTE: Our goal is to preserve support for future extension of the message + // header. If we encounter fields we do not understand, we must ignore them. + + // Extra validation of the struct header: + do { + if (header->version == 0) { + if (header->num_bytes == sizeof(internal::MessageHeader)) + break; + } else if (header->version == 1) { + if (header->num_bytes == sizeof(internal::MessageHeaderV1)) + break; + } else if (header->version == 2) { + if (header->num_bytes == sizeof(internal::MessageHeaderV2)) + break; + } else if (header->version > 2) { + if (header->num_bytes >= sizeof(internal::MessageHeaderV2)) + break; + } + internal::ReportValidationError( + validation_context, + internal::VALIDATION_ERROR_UNEXPECTED_STRUCT_HEADER); + return false; + } while (false); + + // Validate flags (allow unknown bits): + + // These flags require a RequestID. + constexpr uint32_t kRequestIdFlags = + Message::kFlagExpectsResponse | Message::kFlagIsResponse; + if (header->version == 0 && (header->flags & kRequestIdFlags)) { + internal::ReportValidationError( + validation_context, + internal::VALIDATION_ERROR_MESSAGE_HEADER_MISSING_REQUEST_ID); + return false; + } + + // These flags are mutually exclusive. + if ((header->flags & kRequestIdFlags) == kRequestIdFlags) { + internal::ReportValidationError( + validation_context, + internal::VALIDATION_ERROR_MESSAGE_HEADER_INVALID_FLAGS); + return false; + } + + if (header->version < 2) + return true; + + auto* header_v2 = static_cast<const internal::MessageHeaderV2*>(header); + // For the payload pointer: + // - Check that the pointer can be safely decoded. + // - Claim one byte that the pointer points to. It makes sure not only the + // address is within the message, but also the address precedes the array + // storing interface IDs (which is important for safely calculating the + // payload size). + // - Validation of the payload contents will be done separately based on the + // payload type. + if (!header_v2->payload.is_null() && + (!internal::ValidatePointer(header_v2->payload, validation_context) || + !validation_context->ClaimMemory(header_v2->payload.Get(), 1))) { + return false; + } + + const internal::ContainerValidateParams validate_params(0, false, nullptr); + if (!internal::ValidateContainer(header_v2->payload_interface_ids, + validation_context, &validate_params)) { + return false; + } + + if (!header_v2->payload_interface_ids.is_null()) { + size_t num_ids = header_v2->payload_interface_ids.Get()->size(); + const uint32_t* ids = header_v2->payload_interface_ids.Get()->storage(); + for (size_t i = 0; i < num_ids; ++i) { + if (!IsValidInterfaceId(ids[i]) || IsMasterInterfaceId(ids[i])) { + internal::ReportValidationError( + validation_context, + internal::VALIDATION_ERROR_ILLEGAL_INTERFACE_ID); + return false; + } + } + } + + return true; +} + +} // namespace + +MessageHeaderValidator::MessageHeaderValidator() + : MessageHeaderValidator("MessageHeaderValidator") {} + +MessageHeaderValidator::MessageHeaderValidator(const std::string& description) + : description_(description) { +} + +void MessageHeaderValidator::SetDescription(const std::string& description) { + description_ = description; +} + +bool MessageHeaderValidator::Accept(Message* message) { + // Pass 0 as number of handles and associated endpoint handles because we + // don't expect any in the header, even if |message| contains handles. + internal::ValidationContext validation_context( + message->data(), message->data_num_bytes(), 0, 0, message, description_); + + if (!internal::ValidateStructHeaderAndClaimMemory(message->data(), + &validation_context)) + return false; + + if (!IsValidMessageHeader(message->header(), &validation_context)) + return false; + + return true; +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_internal.h b/mojo/public/cpp/bindings/lib/message_internal.h new file mode 100644 index 0000000000..6693198f81 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message_internal.h @@ -0,0 +1,82 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_INTERNAL_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_INTERNAL_H_ + +#include <stdint.h> + +#include <string> + +#include "base/callback.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" + +namespace mojo { + +class Message; + +namespace internal { + +template <typename T> +class Array_Data; + +#pragma pack(push, 1) + +struct MessageHeader : internal::StructHeader { + // Interface ID for identifying multiple interfaces running on the same + // message pipe. + uint32_t interface_id; + // Message name, which is scoped to the interface that the message belongs to. + uint32_t name; + // 0 or either of the enum values defined above. + uint32_t flags; + // Unused padding to make the struct size a multiple of 8 bytes. + uint32_t padding; +}; +static_assert(sizeof(MessageHeader) == 24, "Bad sizeof(MessageHeader)"); + +struct MessageHeaderV1 : MessageHeader { + // Only used if either kFlagExpectsResponse or kFlagIsResponse is set in + // order to match responses with corresponding requests. + uint64_t request_id; +}; +static_assert(sizeof(MessageHeaderV1) == 32, "Bad sizeof(MessageHeaderV1)"); + +struct MessageHeaderV2 : MessageHeaderV1 { + MessageHeaderV2(); + GenericPointer payload; + Pointer<Array_Data<uint32_t>> payload_interface_ids; +}; +static_assert(sizeof(MessageHeaderV2) == 48, "Bad sizeof(MessageHeaderV2)"); + +#pragma pack(pop) + +class MOJO_CPP_BINDINGS_EXPORT MessageDispatchContext { + public: + explicit MessageDispatchContext(Message* message); + ~MessageDispatchContext(); + + static MessageDispatchContext* current(); + + const base::Callback<void(const std::string&)>& GetBadMessageCallback(); + + private: + MessageDispatchContext* outer_context_; + Message* message_; + base::Callback<void(const std::string&)> bad_message_callback_; + + DISALLOW_COPY_AND_ASSIGN(MessageDispatchContext); +}; + +class MOJO_CPP_BINDINGS_EXPORT SyncMessageResponseSetup { + public: + static void SetCurrentSyncResponseMessage(Message* message); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_INTERNAL_H_ diff --git a/mojo/public/cpp/bindings/lib/multiplex_router.cc b/mojo/public/cpp/bindings/lib/multiplex_router.cc new file mode 100644 index 0000000000..ff7c678289 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/multiplex_router.cc @@ -0,0 +1,960 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/multiplex_router.h" + +#include <stdint.h> + +#include <utility> + +#include "base/bind.h" +#include "base/location.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" +#include "base/single_thread_task_runner.h" +#include "base/stl_util.h" +#include "base/synchronization/waitable_event.h" +#include "base/threading/thread_task_runner_handle.h" +#include "mojo/public/cpp/bindings/interface_endpoint_client.h" +#include "mojo/public/cpp/bindings/interface_endpoint_controller.h" +#include "mojo/public/cpp/bindings/lib/may_auto_lock.h" +#include "mojo/public/cpp/bindings/sync_event_watcher.h" + +namespace mojo { +namespace internal { + +// InterfaceEndpoint stores the information of an interface endpoint registered +// with the router. +// No one other than the router's |endpoints_| and |tasks_| should hold refs to +// this object. +class MultiplexRouter::InterfaceEndpoint + : public base::RefCountedThreadSafe<InterfaceEndpoint>, + public InterfaceEndpointController { + public: + InterfaceEndpoint(MultiplexRouter* router, InterfaceId id) + : router_(router), + id_(id), + closed_(false), + peer_closed_(false), + handle_created_(false), + client_(nullptr) {} + + // --------------------------------------------------------------------------- + // The following public methods are safe to call from any threads without + // locking. + + InterfaceId id() const { return id_; } + + // --------------------------------------------------------------------------- + // The following public methods are called under the router's lock. + + bool closed() const { return closed_; } + void set_closed() { + router_->AssertLockAcquired(); + closed_ = true; + } + + bool peer_closed() const { return peer_closed_; } + void set_peer_closed() { + router_->AssertLockAcquired(); + peer_closed_ = true; + } + + bool handle_created() const { return handle_created_; } + void set_handle_created() { + router_->AssertLockAcquired(); + handle_created_ = true; + } + + const base::Optional<DisconnectReason>& disconnect_reason() const { + return disconnect_reason_; + } + void set_disconnect_reason( + const base::Optional<DisconnectReason>& disconnect_reason) { + router_->AssertLockAcquired(); + disconnect_reason_ = disconnect_reason; + } + + base::SingleThreadTaskRunner* task_runner() const { + return task_runner_.get(); + } + + InterfaceEndpointClient* client() const { return client_; } + + void AttachClient(InterfaceEndpointClient* client, + scoped_refptr<base::SingleThreadTaskRunner> runner) { + router_->AssertLockAcquired(); + DCHECK(!client_); + DCHECK(!closed_); + DCHECK(runner->BelongsToCurrentThread()); + + task_runner_ = std::move(runner); + client_ = client; + } + + // This method must be called on the same thread as the corresponding + // AttachClient() call. + void DetachClient() { + router_->AssertLockAcquired(); + DCHECK(client_); + DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(!closed_); + + task_runner_ = nullptr; + client_ = nullptr; + sync_watcher_.reset(); + } + + void SignalSyncMessageEvent() { + router_->AssertLockAcquired(); + if (sync_message_event_signaled_) + return; + sync_message_event_signaled_ = true; + if (sync_message_event_) + sync_message_event_->Signal(); + } + + void ResetSyncMessageSignal() { + router_->AssertLockAcquired(); + if (!sync_message_event_signaled_) + return; + sync_message_event_signaled_ = false; + if (sync_message_event_) + sync_message_event_->Reset(); + } + + // --------------------------------------------------------------------------- + // The following public methods (i.e., InterfaceEndpointController + // implementation) are called by the client on the same thread as the + // AttachClient() call. They are called outside of the router's lock. + + bool SendMessage(Message* message) override { + DCHECK(task_runner_->BelongsToCurrentThread()); + message->set_interface_id(id_); + return router_->connector_.Accept(message); + } + + void AllowWokenUpBySyncWatchOnSameThread() override { + DCHECK(task_runner_->BelongsToCurrentThread()); + + EnsureSyncWatcherExists(); + sync_watcher_->AllowWokenUpBySyncWatchOnSameThread(); + } + + bool SyncWatch(const bool* should_stop) override { + DCHECK(task_runner_->BelongsToCurrentThread()); + + EnsureSyncWatcherExists(); + return sync_watcher_->SyncWatch(should_stop); + } + + private: + friend class base::RefCountedThreadSafe<InterfaceEndpoint>; + + ~InterfaceEndpoint() override { + router_->AssertLockAcquired(); + + DCHECK(!client_); + DCHECK(closed_); + DCHECK(peer_closed_); + DCHECK(!sync_watcher_); + } + + void OnSyncEventSignaled() { + DCHECK(task_runner_->BelongsToCurrentThread()); + scoped_refptr<MultiplexRouter> router_protector(router_); + + MayAutoLock locker(&router_->lock_); + scoped_refptr<InterfaceEndpoint> self_protector(this); + + bool more_to_process = router_->ProcessFirstSyncMessageForEndpoint(id_); + + if (!more_to_process) + ResetSyncMessageSignal(); + + // Currently there are no queued sync messages and the peer has closed so + // there won't be incoming sync messages in the future. + if (!more_to_process && peer_closed_) { + // If a SyncWatch() call (or multiple ones) of this interface endpoint is + // on the call stack, resetting the sync watcher will allow it to exit + // when the call stack unwinds to that frame. + sync_watcher_.reset(); + } + } + + void EnsureSyncWatcherExists() { + DCHECK(task_runner_->BelongsToCurrentThread()); + if (sync_watcher_) + return; + + { + MayAutoLock locker(&router_->lock_); + if (!sync_message_event_) { + sync_message_event_.emplace( + base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::NOT_SIGNALED); + if (sync_message_event_signaled_) + sync_message_event_->Signal(); + } + } + sync_watcher_.reset( + new SyncEventWatcher(&sync_message_event_.value(), + base::Bind(&InterfaceEndpoint::OnSyncEventSignaled, + base::Unretained(this)))); + } + + // --------------------------------------------------------------------------- + // The following members are safe to access from any threads. + + MultiplexRouter* const router_; + const InterfaceId id_; + + // --------------------------------------------------------------------------- + // The following members are accessed under the router's lock. + + // Whether the endpoint has been closed. + bool closed_; + // Whether the peer endpoint has been closed. + bool peer_closed_; + + // Whether there is already a ScopedInterfaceEndpointHandle created for this + // endpoint. + bool handle_created_; + + base::Optional<DisconnectReason> disconnect_reason_; + + // The task runner on which |client_|'s methods can be called. + scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + // Not owned. It is null if no client is attached to this endpoint. + InterfaceEndpointClient* client_; + + // An event used to signal that sync messages are available. The event is + // initialized under the router's lock and remains unchanged afterwards. It + // may be accessed outside of the router's lock later. + base::Optional<base::WaitableEvent> sync_message_event_; + bool sync_message_event_signaled_ = false; + + // --------------------------------------------------------------------------- + // The following members are only valid while a client is attached. They are + // used exclusively on the client's thread. They may be accessed outside of + // the router's lock. + + std::unique_ptr<SyncEventWatcher> sync_watcher_; + + DISALLOW_COPY_AND_ASSIGN(InterfaceEndpoint); +}; + +// MessageWrapper objects are always destroyed under the router's lock. On +// destruction, if the message it wrappers contains +// ScopedInterfaceEndpointHandles (which cannot be destructed under the +// router's lock), the wrapper unlocks to clean them up. +class MultiplexRouter::MessageWrapper { + public: + MessageWrapper() = default; + + MessageWrapper(MultiplexRouter* router, Message message) + : router_(router), value_(std::move(message)) {} + + MessageWrapper(MessageWrapper&& other) + : router_(other.router_), value_(std::move(other.value_)) {} + + ~MessageWrapper() { + if (value_.associated_endpoint_handles()->empty()) + return; + + router_->AssertLockAcquired(); + { + MayAutoUnlock unlocker(&router_->lock_); + value_.mutable_associated_endpoint_handles()->clear(); + } + } + + MessageWrapper& operator=(MessageWrapper&& other) { + router_ = other.router_; + value_ = std::move(other.value_); + return *this; + } + + Message& value() { return value_; } + + private: + MultiplexRouter* router_ = nullptr; + Message value_; + + DISALLOW_COPY_AND_ASSIGN(MessageWrapper); +}; + +struct MultiplexRouter::Task { + public: + // Doesn't take ownership of |message| but takes its contents. + static std::unique_ptr<Task> CreateMessageTask( + MessageWrapper message_wrapper) { + Task* task = new Task(MESSAGE); + task->message_wrapper = std::move(message_wrapper); + return base::WrapUnique(task); + } + static std::unique_ptr<Task> CreateNotifyErrorTask( + InterfaceEndpoint* endpoint) { + Task* task = new Task(NOTIFY_ERROR); + task->endpoint_to_notify = endpoint; + return base::WrapUnique(task); + } + + ~Task() {} + + bool IsMessageTask() const { return type == MESSAGE; } + bool IsNotifyErrorTask() const { return type == NOTIFY_ERROR; } + + MessageWrapper message_wrapper; + scoped_refptr<InterfaceEndpoint> endpoint_to_notify; + + enum Type { MESSAGE, NOTIFY_ERROR }; + Type type; + + private: + explicit Task(Type in_type) : type(in_type) {} + + DISALLOW_COPY_AND_ASSIGN(Task); +}; + +MultiplexRouter::MultiplexRouter( + ScopedMessagePipeHandle message_pipe, + Config config, + bool set_interface_id_namesapce_bit, + scoped_refptr<base::SingleThreadTaskRunner> runner) + : set_interface_id_namespace_bit_(set_interface_id_namesapce_bit), + task_runner_(runner), + header_validator_(nullptr), + filters_(this), + connector_(std::move(message_pipe), + config == MULTI_INTERFACE ? Connector::MULTI_THREADED_SEND + : Connector::SINGLE_THREADED_SEND, + std::move(runner)), + control_message_handler_(this), + control_message_proxy_(&connector_), + next_interface_id_value_(1), + posted_to_process_tasks_(false), + encountered_error_(false), + paused_(false), + testing_mode_(false) { + DCHECK(task_runner_->BelongsToCurrentThread()); + + if (config == MULTI_INTERFACE) + lock_.emplace(); + + if (config == SINGLE_INTERFACE_WITH_SYNC_METHODS || + config == MULTI_INTERFACE) { + // Always participate in sync handle watching in multi-interface mode, + // because even if it doesn't expect sync requests during sync handle + // watching, it may still need to dispatch messages to associated endpoints + // on a different thread. + connector_.AllowWokenUpBySyncWatchOnSameThread(); + } + connector_.set_incoming_receiver(&filters_); + connector_.set_connection_error_handler( + base::Bind(&MultiplexRouter::OnPipeConnectionError, + base::Unretained(this))); + + std::unique_ptr<MessageHeaderValidator> header_validator = + base::MakeUnique<MessageHeaderValidator>(); + header_validator_ = header_validator.get(); + filters_.Append(std::move(header_validator)); +} + +MultiplexRouter::~MultiplexRouter() { + MayAutoLock locker(&lock_); + + sync_message_tasks_.clear(); + tasks_.clear(); + + for (auto iter = endpoints_.begin(); iter != endpoints_.end();) { + InterfaceEndpoint* endpoint = iter->second.get(); + // Increment the iterator before calling UpdateEndpointStateMayRemove() + // because it may remove the corresponding value from the map. + ++iter; + + if (!endpoint->closed()) { + // This happens when a NotifyPeerEndpointClosed message been received, but + // the interface ID hasn't been used to create local endpoint handle. + DCHECK(!endpoint->client()); + DCHECK(endpoint->peer_closed()); + UpdateEndpointStateMayRemove(endpoint, ENDPOINT_CLOSED); + } else { + UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); + } + } + + DCHECK(endpoints_.empty()); +} + +void MultiplexRouter::SetMasterInterfaceName(const char* name) { + DCHECK(thread_checker_.CalledOnValidThread()); + header_validator_->SetDescription( + std::string(name) + " [master] MessageHeaderValidator"); + control_message_handler_.SetDescription( + std::string(name) + " [master] PipeControlMessageHandler"); + connector_.SetWatcherHeapProfilerTag(name); +} + +InterfaceId MultiplexRouter::AssociateInterface( + ScopedInterfaceEndpointHandle handle_to_send) { + if (!handle_to_send.pending_association()) + return kInvalidInterfaceId; + + uint32_t id = 0; + { + MayAutoLock locker(&lock_); + do { + if (next_interface_id_value_ >= kInterfaceIdNamespaceMask) + next_interface_id_value_ = 1; + id = next_interface_id_value_++; + if (set_interface_id_namespace_bit_) + id |= kInterfaceIdNamespaceMask; + } while (base::ContainsKey(endpoints_, id)); + + InterfaceEndpoint* endpoint = new InterfaceEndpoint(this, id); + endpoints_[id] = endpoint; + if (encountered_error_) + UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); + endpoint->set_handle_created(); + } + + if (!NotifyAssociation(&handle_to_send, id)) { + // The peer handle of |handle_to_send|, which is supposed to join this + // associated group, has been closed. + { + MayAutoLock locker(&lock_); + InterfaceEndpoint* endpoint = FindEndpoint(id); + if (endpoint) + UpdateEndpointStateMayRemove(endpoint, ENDPOINT_CLOSED); + } + + control_message_proxy_.NotifyPeerEndpointClosed( + id, handle_to_send.disconnect_reason()); + } + return id; +} + +ScopedInterfaceEndpointHandle MultiplexRouter::CreateLocalEndpointHandle( + InterfaceId id) { + if (!IsValidInterfaceId(id)) + return ScopedInterfaceEndpointHandle(); + + MayAutoLock locker(&lock_); + bool inserted = false; + InterfaceEndpoint* endpoint = FindOrInsertEndpoint(id, &inserted); + if (inserted) { + DCHECK(!endpoint->handle_created()); + + if (encountered_error_) + UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); + } else { + // If the endpoint already exist, it is because we have received a + // notification that the peer endpoint has closed. + CHECK(!endpoint->closed()); + CHECK(endpoint->peer_closed()); + + if (endpoint->handle_created()) + return ScopedInterfaceEndpointHandle(); + } + + endpoint->set_handle_created(); + return CreateScopedInterfaceEndpointHandle(id); +} + +void MultiplexRouter::CloseEndpointHandle( + InterfaceId id, + const base::Optional<DisconnectReason>& reason) { + if (!IsValidInterfaceId(id)) + return; + + MayAutoLock locker(&lock_); + DCHECK(base::ContainsKey(endpoints_, id)); + InterfaceEndpoint* endpoint = endpoints_[id].get(); + DCHECK(!endpoint->client()); + DCHECK(!endpoint->closed()); + UpdateEndpointStateMayRemove(endpoint, ENDPOINT_CLOSED); + + if (!IsMasterInterfaceId(id) || reason) { + MayAutoUnlock unlocker(&lock_); + control_message_proxy_.NotifyPeerEndpointClosed(id, reason); + } + + ProcessTasks(NO_DIRECT_CLIENT_CALLS, nullptr); +} + +InterfaceEndpointController* MultiplexRouter::AttachEndpointClient( + const ScopedInterfaceEndpointHandle& handle, + InterfaceEndpointClient* client, + scoped_refptr<base::SingleThreadTaskRunner> runner) { + const InterfaceId id = handle.id(); + + DCHECK(IsValidInterfaceId(id)); + DCHECK(client); + + MayAutoLock locker(&lock_); + DCHECK(base::ContainsKey(endpoints_, id)); + + InterfaceEndpoint* endpoint = endpoints_[id].get(); + endpoint->AttachClient(client, std::move(runner)); + + if (endpoint->peer_closed()) + tasks_.push_back(Task::CreateNotifyErrorTask(endpoint)); + ProcessTasks(NO_DIRECT_CLIENT_CALLS, nullptr); + + return endpoint; +} + +void MultiplexRouter::DetachEndpointClient( + const ScopedInterfaceEndpointHandle& handle) { + const InterfaceId id = handle.id(); + + DCHECK(IsValidInterfaceId(id)); + + MayAutoLock locker(&lock_); + DCHECK(base::ContainsKey(endpoints_, id)); + + InterfaceEndpoint* endpoint = endpoints_[id].get(); + endpoint->DetachClient(); +} + +void MultiplexRouter::RaiseError() { + if (task_runner_->BelongsToCurrentThread()) { + connector_.RaiseError(); + } else { + task_runner_->PostTask(FROM_HERE, + base::Bind(&MultiplexRouter::RaiseError, this)); + } +} + +void MultiplexRouter::CloseMessagePipe() { + DCHECK(thread_checker_.CalledOnValidThread()); + connector_.CloseMessagePipe(); + // CloseMessagePipe() above won't trigger connection error handler. + // Explicitly call OnPipeConnectionError() so that associated endpoints will + // get notified. + OnPipeConnectionError(); +} + +void MultiplexRouter::PauseIncomingMethodCallProcessing() { + DCHECK(thread_checker_.CalledOnValidThread()); + connector_.PauseIncomingMethodCallProcessing(); + + MayAutoLock locker(&lock_); + paused_ = true; + + for (auto iter = endpoints_.begin(); iter != endpoints_.end(); ++iter) + iter->second->ResetSyncMessageSignal(); +} + +void MultiplexRouter::ResumeIncomingMethodCallProcessing() { + DCHECK(thread_checker_.CalledOnValidThread()); + connector_.ResumeIncomingMethodCallProcessing(); + + MayAutoLock locker(&lock_); + paused_ = false; + + for (auto iter = endpoints_.begin(); iter != endpoints_.end(); ++iter) { + auto sync_iter = sync_message_tasks_.find(iter->first); + if (iter->second->peer_closed() || + (sync_iter != sync_message_tasks_.end() && + !sync_iter->second.empty())) { + iter->second->SignalSyncMessageEvent(); + } + } + + ProcessTasks(NO_DIRECT_CLIENT_CALLS, nullptr); +} + +bool MultiplexRouter::HasAssociatedEndpoints() const { + DCHECK(thread_checker_.CalledOnValidThread()); + MayAutoLock locker(&lock_); + + if (endpoints_.size() > 1) + return true; + if (endpoints_.size() == 0) + return false; + + return !base::ContainsKey(endpoints_, kMasterInterfaceId); +} + +void MultiplexRouter::EnableTestingMode() { + DCHECK(thread_checker_.CalledOnValidThread()); + MayAutoLock locker(&lock_); + + testing_mode_ = true; + connector_.set_enforce_errors_from_incoming_receiver(false); +} + +bool MultiplexRouter::Accept(Message* message) { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (!message->DeserializeAssociatedEndpointHandles(this)) + return false; + + scoped_refptr<MultiplexRouter> protector(this); + MayAutoLock locker(&lock_); + + DCHECK(!paused_); + + ClientCallBehavior client_call_behavior = + connector_.during_sync_handle_watcher_callback() + ? ALLOW_DIRECT_CLIENT_CALLS_FOR_SYNC_MESSAGES + : ALLOW_DIRECT_CLIENT_CALLS; + + bool processed = + tasks_.empty() && ProcessIncomingMessage(message, client_call_behavior, + connector_.task_runner()); + + if (!processed) { + // Either the task queue is not empty or we cannot process the message + // directly. In both cases, there is no need to call ProcessTasks(). + tasks_.push_back( + Task::CreateMessageTask(MessageWrapper(this, std::move(*message)))); + Task* task = tasks_.back().get(); + + if (task->message_wrapper.value().has_flag(Message::kFlagIsSync)) { + InterfaceId id = task->message_wrapper.value().interface_id(); + sync_message_tasks_[id].push_back(task); + InterfaceEndpoint* endpoint = FindEndpoint(id); + if (endpoint) + endpoint->SignalSyncMessageEvent(); + } + } else if (!tasks_.empty()) { + // Processing the message may result in new tasks (for error notification) + // being added to the queue. In this case, we have to attempt to process the + // tasks. + ProcessTasks(client_call_behavior, connector_.task_runner()); + } + + // Always return true. If we see errors during message processing, we will + // explicitly call Connector::RaiseError() to disconnect the message pipe. + return true; +} + +bool MultiplexRouter::OnPeerAssociatedEndpointClosed( + InterfaceId id, + const base::Optional<DisconnectReason>& reason) { + DCHECK(!IsMasterInterfaceId(id) || reason); + + MayAutoLock locker(&lock_); + InterfaceEndpoint* endpoint = FindOrInsertEndpoint(id, nullptr); + + if (reason) + endpoint->set_disconnect_reason(reason); + + // It is possible that this endpoint has been set as peer closed. That is + // because when the message pipe is closed, all the endpoints are updated with + // PEER_ENDPOINT_CLOSED. We continue to process remaining tasks in the queue, + // as long as there are refs keeping the router alive. If there is a + // PeerAssociatedEndpointClosedEvent control message in the queue, we will get + // here and see that the endpoint has been marked as peer closed. + if (!endpoint->peer_closed()) { + if (endpoint->client()) + tasks_.push_back(Task::CreateNotifyErrorTask(endpoint)); + UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); + } + + // No need to trigger a ProcessTasks() because it is already on the stack. + + return true; +} + +void MultiplexRouter::OnPipeConnectionError() { + DCHECK(thread_checker_.CalledOnValidThread()); + + scoped_refptr<MultiplexRouter> protector(this); + MayAutoLock locker(&lock_); + + encountered_error_ = true; + + for (auto iter = endpoints_.begin(); iter != endpoints_.end();) { + InterfaceEndpoint* endpoint = iter->second.get(); + // Increment the iterator before calling UpdateEndpointStateMayRemove() + // because it may remove the corresponding value from the map. + ++iter; + + if (endpoint->client()) + tasks_.push_back(Task::CreateNotifyErrorTask(endpoint)); + + UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); + } + + ProcessTasks(connector_.during_sync_handle_watcher_callback() + ? ALLOW_DIRECT_CLIENT_CALLS_FOR_SYNC_MESSAGES + : ALLOW_DIRECT_CLIENT_CALLS, + connector_.task_runner()); +} + +void MultiplexRouter::ProcessTasks( + ClientCallBehavior client_call_behavior, + base::SingleThreadTaskRunner* current_task_runner) { + AssertLockAcquired(); + + if (posted_to_process_tasks_) + return; + + while (!tasks_.empty() && !paused_) { + std::unique_ptr<Task> task(std::move(tasks_.front())); + tasks_.pop_front(); + + InterfaceId id = kInvalidInterfaceId; + bool sync_message = + task->IsMessageTask() && !task->message_wrapper.value().IsNull() && + task->message_wrapper.value().has_flag(Message::kFlagIsSync); + if (sync_message) { + id = task->message_wrapper.value().interface_id(); + auto& sync_message_queue = sync_message_tasks_[id]; + DCHECK_EQ(task.get(), sync_message_queue.front()); + sync_message_queue.pop_front(); + } + + bool processed = + task->IsNotifyErrorTask() + ? ProcessNotifyErrorTask(task.get(), client_call_behavior, + current_task_runner) + : ProcessIncomingMessage(&task->message_wrapper.value(), + client_call_behavior, current_task_runner); + + if (!processed) { + if (sync_message) { + auto& sync_message_queue = sync_message_tasks_[id]; + sync_message_queue.push_front(task.get()); + } + tasks_.push_front(std::move(task)); + break; + } else { + if (sync_message) { + auto iter = sync_message_tasks_.find(id); + if (iter != sync_message_tasks_.end() && iter->second.empty()) + sync_message_tasks_.erase(iter); + } + } + } +} + +bool MultiplexRouter::ProcessFirstSyncMessageForEndpoint(InterfaceId id) { + AssertLockAcquired(); + + auto iter = sync_message_tasks_.find(id); + if (iter == sync_message_tasks_.end()) + return false; + + if (paused_) + return true; + + MultiplexRouter::Task* task = iter->second.front(); + iter->second.pop_front(); + + DCHECK(task->IsMessageTask()); + MessageWrapper message_wrapper = std::move(task->message_wrapper); + + // Note: after this call, |task| and |iter| may be invalidated. + bool processed = ProcessIncomingMessage( + &message_wrapper.value(), ALLOW_DIRECT_CLIENT_CALLS_FOR_SYNC_MESSAGES, + nullptr); + DCHECK(processed); + + iter = sync_message_tasks_.find(id); + if (iter == sync_message_tasks_.end()) + return false; + + if (iter->second.empty()) { + sync_message_tasks_.erase(iter); + return false; + } + + return true; +} + +bool MultiplexRouter::ProcessNotifyErrorTask( + Task* task, + ClientCallBehavior client_call_behavior, + base::SingleThreadTaskRunner* current_task_runner) { + DCHECK(!current_task_runner || current_task_runner->BelongsToCurrentThread()); + DCHECK(!paused_); + + AssertLockAcquired(); + InterfaceEndpoint* endpoint = task->endpoint_to_notify.get(); + if (!endpoint->client()) + return true; + + if (client_call_behavior != ALLOW_DIRECT_CLIENT_CALLS || + endpoint->task_runner() != current_task_runner) { + MaybePostToProcessTasks(endpoint->task_runner()); + return false; + } + + DCHECK(endpoint->task_runner()->BelongsToCurrentThread()); + + InterfaceEndpointClient* client = endpoint->client(); + base::Optional<DisconnectReason> disconnect_reason( + endpoint->disconnect_reason()); + + { + // We must unlock before calling into |client| because it may call this + // object within NotifyError(). Holding the lock will lead to deadlock. + // + // It is safe to call into |client| without the lock. Because |client| is + // always accessed on the same thread, including DetachEndpointClient(). + MayAutoUnlock unlocker(&lock_); + client->NotifyError(disconnect_reason); + } + return true; +} + +bool MultiplexRouter::ProcessIncomingMessage( + Message* message, + ClientCallBehavior client_call_behavior, + base::SingleThreadTaskRunner* current_task_runner) { + DCHECK(!current_task_runner || current_task_runner->BelongsToCurrentThread()); + DCHECK(!paused_); + DCHECK(message); + AssertLockAcquired(); + + if (message->IsNull()) { + // This is a sync message and has been processed during sync handle + // watching. + return true; + } + + if (PipeControlMessageHandler::IsPipeControlMessage(message)) { + bool result = false; + + { + MayAutoUnlock unlocker(&lock_); + result = control_message_handler_.Accept(message); + } + + if (!result) + RaiseErrorInNonTestingMode(); + + return true; + } + + InterfaceId id = message->interface_id(); + DCHECK(IsValidInterfaceId(id)); + + InterfaceEndpoint* endpoint = FindEndpoint(id); + if (!endpoint || endpoint->closed()) + return true; + + if (!endpoint->client()) { + // We need to wait until a client is attached in order to dispatch further + // messages. + return false; + } + + bool can_direct_call; + if (message->has_flag(Message::kFlagIsSync)) { + can_direct_call = client_call_behavior != NO_DIRECT_CLIENT_CALLS && + endpoint->task_runner()->BelongsToCurrentThread(); + } else { + can_direct_call = client_call_behavior == ALLOW_DIRECT_CLIENT_CALLS && + endpoint->task_runner() == current_task_runner; + } + + if (!can_direct_call) { + MaybePostToProcessTasks(endpoint->task_runner()); + return false; + } + + DCHECK(endpoint->task_runner()->BelongsToCurrentThread()); + + InterfaceEndpointClient* client = endpoint->client(); + bool result = false; + { + // We must unlock before calling into |client| because it may call this + // object within HandleIncomingMessage(). Holding the lock will lead to + // deadlock. + // + // It is safe to call into |client| without the lock. Because |client| is + // always accessed on the same thread, including DetachEndpointClient(). + MayAutoUnlock unlocker(&lock_); + result = client->HandleIncomingMessage(message); + } + if (!result) + RaiseErrorInNonTestingMode(); + + return true; +} + +void MultiplexRouter::MaybePostToProcessTasks( + base::SingleThreadTaskRunner* task_runner) { + AssertLockAcquired(); + if (posted_to_process_tasks_) + return; + + posted_to_process_tasks_ = true; + posted_to_task_runner_ = task_runner; + task_runner->PostTask( + FROM_HERE, base::Bind(&MultiplexRouter::LockAndCallProcessTasks, this)); +} + +void MultiplexRouter::LockAndCallProcessTasks() { + // There is no need to hold a ref to this class in this case because this is + // always called using base::Bind(), which holds a ref. + MayAutoLock locker(&lock_); + posted_to_process_tasks_ = false; + scoped_refptr<base::SingleThreadTaskRunner> runner( + std::move(posted_to_task_runner_)); + ProcessTasks(ALLOW_DIRECT_CLIENT_CALLS, runner.get()); +} + +void MultiplexRouter::UpdateEndpointStateMayRemove( + InterfaceEndpoint* endpoint, + EndpointStateUpdateType type) { + if (type == ENDPOINT_CLOSED) { + endpoint->set_closed(); + } else { + endpoint->set_peer_closed(); + // If the interface endpoint is performing a sync watch, this makes sure + // it is notified and eventually exits the sync watch. + endpoint->SignalSyncMessageEvent(); + } + if (endpoint->closed() && endpoint->peer_closed()) + endpoints_.erase(endpoint->id()); +} + +void MultiplexRouter::RaiseErrorInNonTestingMode() { + AssertLockAcquired(); + if (!testing_mode_) + RaiseError(); +} + +MultiplexRouter::InterfaceEndpoint* MultiplexRouter::FindOrInsertEndpoint( + InterfaceId id, + bool* inserted) { + AssertLockAcquired(); + // Either |inserted| is nullptr or it points to a boolean initialized as + // false. + DCHECK(!inserted || !*inserted); + + InterfaceEndpoint* endpoint = FindEndpoint(id); + if (!endpoint) { + endpoint = new InterfaceEndpoint(this, id); + endpoints_[id] = endpoint; + if (inserted) + *inserted = true; + } + + return endpoint; +} + +MultiplexRouter::InterfaceEndpoint* MultiplexRouter::FindEndpoint( + InterfaceId id) { + AssertLockAcquired(); + auto iter = endpoints_.find(id); + return iter != endpoints_.end() ? iter->second.get() : nullptr; +} + +void MultiplexRouter::AssertLockAcquired() { +#if DCHECK_IS_ON() + if (lock_) + lock_->AssertAcquired(); +#endif +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/multiplex_router.h b/mojo/public/cpp/bindings/lib/multiplex_router.h new file mode 100644 index 0000000000..cac138bcb7 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/multiplex_router.h @@ -0,0 +1,275 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MULTIPLEX_ROUTER_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_MULTIPLEX_ROUTER_H_ + +#include <stdint.h> + +#include <deque> +#include <map> +#include <memory> +#include <string> + +#include "base/compiler_specific.h" +#include "base/logging.h" +#include "base/macros.h" +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "base/optional.h" +#include "base/single_thread_task_runner.h" +#include "base/synchronization/lock.h" +#include "base/threading/thread_checker.h" +#include "mojo/public/cpp/bindings/associated_group_controller.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/connector.h" +#include "mojo/public/cpp/bindings/filter_chain.h" +#include "mojo/public/cpp/bindings/interface_id.h" +#include "mojo/public/cpp/bindings/message_header_validator.h" +#include "mojo/public/cpp/bindings/pipe_control_message_handler.h" +#include "mojo/public/cpp/bindings/pipe_control_message_handler_delegate.h" +#include "mojo/public/cpp/bindings/pipe_control_message_proxy.h" +#include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" + +namespace base { +class SingleThreadTaskRunner; +} + +namespace mojo { + +namespace internal { + +// MultiplexRouter supports routing messages for multiple interfaces over a +// single message pipe. +// +// It is created on the thread where the master interface of the message pipe +// lives. Although it is ref-counted, it is guarateed to be destructed on the +// same thread. +// Some public methods are only allowed to be called on the creating thread; +// while the others are safe to call from any threads. Please see the method +// comments for more details. +// +// NOTE: CloseMessagePipe() or PassMessagePipe() MUST be called on |runner|'s +// thread before this object is destroyed. +class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter + : NON_EXPORTED_BASE(public MessageReceiver), + public AssociatedGroupController, + NON_EXPORTED_BASE(public PipeControlMessageHandlerDelegate) { + public: + enum Config { + // There is only the master interface running on this router. Please note + // that because of interface versioning, the other side of the message pipe + // may use a newer master interface definition which passes associated + // interfaces. In that case, this router may still receive pipe control + // messages or messages targetting associated interfaces. + SINGLE_INTERFACE, + // Similar to the mode above, there is only the master interface running on + // this router. Besides, the master interface has sync methods. + SINGLE_INTERFACE_WITH_SYNC_METHODS, + // There may be associated interfaces running on this router. + MULTI_INTERFACE + }; + + // If |set_interface_id_namespace_bit| is true, the interface IDs generated by + // this router will have the highest bit set. + MultiplexRouter(ScopedMessagePipeHandle message_pipe, + Config config, + bool set_interface_id_namespace_bit, + scoped_refptr<base::SingleThreadTaskRunner> runner); + + // Sets the master interface name for this router. Only used when reporting + // message header or control message validation errors. + // |name| must be a string literal. + void SetMasterInterfaceName(const char* name); + + // --------------------------------------------------------------------------- + // The following public methods are safe to call from any threads. + + // AssociatedGroupController implementation: + InterfaceId AssociateInterface( + ScopedInterfaceEndpointHandle handle_to_send) override; + ScopedInterfaceEndpointHandle CreateLocalEndpointHandle( + InterfaceId id) override; + void CloseEndpointHandle( + InterfaceId id, + const base::Optional<DisconnectReason>& reason) override; + InterfaceEndpointController* AttachEndpointClient( + const ScopedInterfaceEndpointHandle& handle, + InterfaceEndpointClient* endpoint_client, + scoped_refptr<base::SingleThreadTaskRunner> runner) override; + void DetachEndpointClient( + const ScopedInterfaceEndpointHandle& handle) override; + void RaiseError() override; + + // --------------------------------------------------------------------------- + // The following public methods are called on the creating thread. + + // Please note that this method shouldn't be called unless it results from an + // explicit request of the user of bindings (e.g., the user sets an + // InterfacePtr to null or closes a Binding). + void CloseMessagePipe(); + + // Extracts the underlying message pipe. + ScopedMessagePipeHandle PassMessagePipe() { + DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK(!HasAssociatedEndpoints()); + return connector_.PassMessagePipe(); + } + + // Blocks the current thread until the first incoming message, or |deadline|. + bool WaitForIncomingMessage(MojoDeadline deadline) { + DCHECK(thread_checker_.CalledOnValidThread()); + return connector_.WaitForIncomingMessage(deadline); + } + + // See Binding for details of pause/resume. + void PauseIncomingMethodCallProcessing(); + void ResumeIncomingMethodCallProcessing(); + + // Whether there are any associated interfaces running currently. + bool HasAssociatedEndpoints() const; + + // Sets this object to testing mode. + // In testing mode, the object doesn't disconnect the underlying message pipe + // when it receives unexpected or invalid messages. + void EnableTestingMode(); + + // Is the router bound to a message pipe handle? + bool is_valid() const { + DCHECK(thread_checker_.CalledOnValidThread()); + return connector_.is_valid(); + } + + // TODO(yzshen): consider removing this getter. + MessagePipeHandle handle() const { + DCHECK(thread_checker_.CalledOnValidThread()); + return connector_.handle(); + } + + bool SimulateReceivingMessageForTesting(Message* message) { + return filters_.Accept(message); + } + + private: + class InterfaceEndpoint; + class MessageWrapper; + struct Task; + + ~MultiplexRouter() override; + + // MessageReceiver implementation: + bool Accept(Message* message) override; + + // PipeControlMessageHandlerDelegate implementation: + bool OnPeerAssociatedEndpointClosed( + InterfaceId id, + const base::Optional<DisconnectReason>& reason) override; + + void OnPipeConnectionError(); + + // Specifies whether we are allowed to directly call into + // InterfaceEndpointClient (given that we are already on the same thread as + // the client). + enum ClientCallBehavior { + // Don't call any InterfaceEndpointClient methods directly. + NO_DIRECT_CLIENT_CALLS, + // Only call InterfaceEndpointClient::HandleIncomingMessage directly to + // handle sync messages. + ALLOW_DIRECT_CLIENT_CALLS_FOR_SYNC_MESSAGES, + // Allow to call any InterfaceEndpointClient methods directly. + ALLOW_DIRECT_CLIENT_CALLS + }; + + // Processes enqueued tasks (incoming messages and error notifications). + // |current_task_runner| is only used when |client_call_behavior| is + // ALLOW_DIRECT_CLIENT_CALLS to determine whether we are on the right task + // runner to make client calls for async messages or connection error + // notifications. + // + // Note: Because calling into InterfaceEndpointClient may lead to destruction + // of this object, if direct calls are allowed, the caller needs to hold on to + // a ref outside of |lock_| before calling this method. + void ProcessTasks(ClientCallBehavior client_call_behavior, + base::SingleThreadTaskRunner* current_task_runner); + + // Processes the first queued sync message for the endpoint corresponding to + // |id|; returns whether there are more sync messages for that endpoint in the + // queue. + // + // This method is only used by enpoints during sync watching. Therefore, not + // all sync messages are handled by it. + bool ProcessFirstSyncMessageForEndpoint(InterfaceId id); + + // Returns true to indicate that |task|/|message| has been processed. + bool ProcessNotifyErrorTask( + Task* task, + ClientCallBehavior client_call_behavior, + base::SingleThreadTaskRunner* current_task_runner); + bool ProcessIncomingMessage( + Message* message, + ClientCallBehavior client_call_behavior, + base::SingleThreadTaskRunner* current_task_runner); + + void MaybePostToProcessTasks(base::SingleThreadTaskRunner* task_runner); + void LockAndCallProcessTasks(); + + // Updates the state of |endpoint|. If both the endpoint and its peer have + // been closed, removes it from |endpoints_|. + // NOTE: The method may invalidate |endpoint|. + enum EndpointStateUpdateType { ENDPOINT_CLOSED, PEER_ENDPOINT_CLOSED }; + void UpdateEndpointStateMayRemove(InterfaceEndpoint* endpoint, + EndpointStateUpdateType type); + + void RaiseErrorInNonTestingMode(); + + InterfaceEndpoint* FindOrInsertEndpoint(InterfaceId id, bool* inserted); + InterfaceEndpoint* FindEndpoint(InterfaceId id); + + void AssertLockAcquired(); + + // Whether to set the namespace bit when generating interface IDs. Please see + // comments of kInterfaceIdNamespaceMask. + const bool set_interface_id_namespace_bit_; + + scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + + // Owned by |filters_| below. + MessageHeaderValidator* header_validator_; + + FilterChain filters_; + Connector connector_; + + base::ThreadChecker thread_checker_; + + // Protects the following members. + // Not set in Config::SINGLE_INTERFACE* mode. + mutable base::Optional<base::Lock> lock_; + PipeControlMessageHandler control_message_handler_; + + // NOTE: It is unsafe to call into this object while holding |lock_|. + PipeControlMessageProxy control_message_proxy_; + + std::map<InterfaceId, scoped_refptr<InterfaceEndpoint>> endpoints_; + uint32_t next_interface_id_value_; + + std::deque<std::unique_ptr<Task>> tasks_; + // It refers to tasks in |tasks_| and doesn't own any of them. + std::map<InterfaceId, std::deque<Task*>> sync_message_tasks_; + + bool posted_to_process_tasks_; + scoped_refptr<base::SingleThreadTaskRunner> posted_to_task_runner_; + + bool encountered_error_; + + bool paused_; + + bool testing_mode_; + + DISALLOW_COPY_AND_ASSIGN(MultiplexRouter); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_MULTIPLEX_ROUTER_H_ diff --git a/mojo/public/cpp/bindings/lib/native_enum_data.h b/mojo/public/cpp/bindings/lib/native_enum_data.h new file mode 100644 index 0000000000..dcafce2815 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/native_enum_data.h @@ -0,0 +1,26 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_ENUM_DATA_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_ENUM_DATA_H_ + +namespace mojo { +namespace internal { + +class ValidationContext; + +class NativeEnum_Data { + public: + static bool const kIsExtensible = true; + + static bool IsKnownValue(int32_t value) { return false; } + + static bool Validate(int32_t value, + ValidationContext* validation_context) { return true; } +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_ENUM_DATA_H_ diff --git a/mojo/public/cpp/bindings/lib/native_enum_serialization.h b/mojo/public/cpp/bindings/lib/native_enum_serialization.h new file mode 100644 index 0000000000..4faf957c58 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/native_enum_serialization.h @@ -0,0 +1,82 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_ENUM_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_ENUM_SERIALIZATION_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <type_traits> + +#include "base/logging.h" +#include "base/pickle.h" +#include "ipc/ipc_param_traits.h" +#include "mojo/public/cpp/bindings/lib/serialization_forward.h" +#include "mojo/public/cpp/bindings/native_enum.h" + +namespace mojo { +namespace internal { + +template <typename MaybeConstUserType> +struct NativeEnumSerializerImpl { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Traits = IPC::ParamTraits<UserType>; + + // IPC_ENUM_TRAITS* macros serialize enum as int, make sure that fits into + // mojo native-only enum. + static_assert(sizeof(NativeEnum) >= sizeof(int), + "Cannot store the serialization result in NativeEnum."); + + static void Serialize(UserType input, int32_t* output) { + base::Pickle pickle; + Traits::Write(&pickle, input); + + CHECK_GE(sizeof(int32_t), pickle.payload_size()); + *output = 0; + memcpy(reinterpret_cast<char*>(output), pickle.payload(), + pickle.payload_size()); + } + + struct PickleData { + uint32_t payload_size; + int32_t value; + }; + static_assert(sizeof(PickleData) == 8, "PickleData size mismatch."); + + static bool Deserialize(int32_t input, UserType* output) { + PickleData data = {sizeof(int32_t), input}; + base::Pickle pickle_view(reinterpret_cast<const char*>(&data), + sizeof(PickleData)); + base::PickleIterator iter(pickle_view); + return Traits::Read(&pickle_view, &iter, output); + } +}; + +struct UnmappedNativeEnumSerializerImpl { + static void Serialize(NativeEnum input, int32_t* output) { + *output = static_cast<int32_t>(input); + } + static bool Deserialize(int32_t input, NativeEnum* output) { + *output = static_cast<NativeEnum>(input); + return true; + } +}; + +template <> +struct NativeEnumSerializerImpl<NativeEnum> + : public UnmappedNativeEnumSerializerImpl {}; + +template <> +struct NativeEnumSerializerImpl<const NativeEnum> + : public UnmappedNativeEnumSerializerImpl {}; + +template <typename MaybeConstUserType> +struct Serializer<NativeEnum, MaybeConstUserType> + : public NativeEnumSerializerImpl<MaybeConstUserType> {}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_ENUM_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/native_struct.cc b/mojo/public/cpp/bindings/lib/native_struct.cc new file mode 100644 index 0000000000..7b1a1a6c59 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/native_struct.cc @@ -0,0 +1,34 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/native_struct.h" + +#include "mojo/public/cpp/bindings/lib/hash_util.h" + +namespace mojo { + +// static +NativeStructPtr NativeStruct::New() { + return NativeStructPtr(base::in_place); +} + +NativeStruct::NativeStruct() {} + +NativeStruct::~NativeStruct() {} + +NativeStructPtr NativeStruct::Clone() const { + NativeStructPtr rv(New()); + rv->data = data; + return rv; +} + +bool NativeStruct::Equals(const NativeStruct& other) const { + return data == other.data; +} + +size_t NativeStruct::Hash(size_t seed) const { + return internal::Hash(seed, data); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/native_struct_data.cc b/mojo/public/cpp/bindings/lib/native_struct_data.cc new file mode 100644 index 0000000000..0e5d245692 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/native_struct_data.cc @@ -0,0 +1,22 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/native_struct_data.h" + +#include "mojo/public/cpp/bindings/lib/buffer.h" +#include "mojo/public/cpp/bindings/lib/validation_context.h" + +namespace mojo { +namespace internal { + +// static +bool NativeStruct_Data::Validate(const void* data, + ValidationContext* validation_context) { + const ContainerValidateParams data_validate_params(0, false, nullptr); + return Array_Data<uint8_t>::Validate(data, validation_context, + &data_validate_params); +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/native_struct_data.h b/mojo/public/cpp/bindings/lib/native_struct_data.h new file mode 100644 index 0000000000..1c7cd81c77 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/native_struct_data.h @@ -0,0 +1,38 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_DATA_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_DATA_H_ + +#include <vector> + +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/system/handle.h" + +namespace mojo { +namespace internal { + +class ValidationContext; + +class MOJO_CPP_BINDINGS_EXPORT NativeStruct_Data { + public: + static bool Validate(const void* data, ValidationContext* validation_context); + + // Unlike normal structs, the memory layout is exactly the same as an array + // of uint8_t. + Array_Data<uint8_t> data; + + private: + NativeStruct_Data() = delete; + ~NativeStruct_Data() = delete; +}; + +static_assert(sizeof(Array_Data<uint8_t>) == sizeof(NativeStruct_Data), + "Mismatched NativeStruct_Data and Array_Data<uint8_t> size"); + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_DATA_H_ diff --git a/mojo/public/cpp/bindings/lib/native_struct_serialization.cc b/mojo/public/cpp/bindings/lib/native_struct_serialization.cc new file mode 100644 index 0000000000..fa0dbf3803 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/native_struct_serialization.cc @@ -0,0 +1,61 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/native_struct_serialization.h" + +#include "mojo/public/cpp/bindings/lib/serialization.h" + +namespace mojo { +namespace internal { + +// static +size_t UnmappedNativeStructSerializerImpl::PrepareToSerialize( + const NativeStructPtr& input, + SerializationContext* context) { + if (!input) + return 0; + return internal::PrepareToSerialize<ArrayDataView<uint8_t>>(input->data, + context); +} + +// static +void UnmappedNativeStructSerializerImpl::Serialize( + const NativeStructPtr& input, + Buffer* buffer, + NativeStruct_Data** output, + SerializationContext* context) { + if (!input) { + *output = nullptr; + return; + } + + Array_Data<uint8_t>* data = nullptr; + const ContainerValidateParams params(0, false, nullptr); + internal::Serialize<ArrayDataView<uint8_t>>(input->data, buffer, &data, + ¶ms, context); + *output = reinterpret_cast<NativeStruct_Data*>(data); +} + +// static +bool UnmappedNativeStructSerializerImpl::Deserialize( + NativeStruct_Data* input, + NativeStructPtr* output, + SerializationContext* context) { + Array_Data<uint8_t>* data = reinterpret_cast<Array_Data<uint8_t>*>(input); + + NativeStructPtr result(NativeStruct::New()); + if (!internal::Deserialize<ArrayDataView<uint8_t>>(data, &result->data, + context)) { + output = nullptr; + return false; + } + if (!result->data) + *output = nullptr; + else + result.Swap(output); + return true; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/native_struct_serialization.h b/mojo/public/cpp/bindings/lib/native_struct_serialization.h new file mode 100644 index 0000000000..457435b955 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/native_struct_serialization.h @@ -0,0 +1,134 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_SERIALIZATION_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <limits> + +#include "base/logging.h" +#include "base/pickle.h" +#include "ipc/ipc_param_traits.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/bindings/lib/native_struct_data.h" +#include "mojo/public/cpp/bindings/lib/serialization_forward.h" +#include "mojo/public/cpp/bindings/lib/serialization_util.h" +#include "mojo/public/cpp/bindings/native_struct.h" +#include "mojo/public/cpp/bindings/native_struct_data_view.h" + +namespace mojo { +namespace internal { + +template <typename MaybeConstUserType> +struct NativeStructSerializerImpl { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Traits = IPC::ParamTraits<UserType>; + + static size_t PrepareToSerialize(MaybeConstUserType& value, + SerializationContext* context) { + base::PickleSizer sizer; + Traits::GetSize(&sizer, value); + return Align(sizer.payload_size() + sizeof(ArrayHeader)); + } + + static void Serialize(MaybeConstUserType& value, + Buffer* buffer, + NativeStruct_Data** out, + SerializationContext* context) { + base::Pickle pickle; + Traits::Write(&pickle, value); + +#if DCHECK_IS_ON() + base::PickleSizer sizer; + Traits::GetSize(&sizer, value); + DCHECK_EQ(sizer.payload_size(), pickle.payload_size()); +#endif + + size_t total_size = pickle.payload_size() + sizeof(ArrayHeader); + DCHECK_LT(total_size, std::numeric_limits<uint32_t>::max()); + + // Allocate a uint8 array, initialize its header, and copy the Pickle in. + ArrayHeader* header = + reinterpret_cast<ArrayHeader*>(buffer->Allocate(total_size)); + header->num_bytes = static_cast<uint32_t>(total_size); + header->num_elements = static_cast<uint32_t>(pickle.payload_size()); + memcpy(reinterpret_cast<char*>(header) + sizeof(ArrayHeader), + pickle.payload(), pickle.payload_size()); + + *out = reinterpret_cast<NativeStruct_Data*>(header); + } + + static bool Deserialize(NativeStruct_Data* data, + UserType* out, + SerializationContext* context) { + if (!data) + return false; + + // Construct a temporary base::Pickle view over the array data. Note that + // the Array_Data is laid out like this: + // + // [num_bytes (4 bytes)] [num_elements (4 bytes)] [elements...] + // + // and base::Pickle expects to view data like this: + // + // [payload_size (4 bytes)] [header bytes ...] [payload...] + // + // Because ArrayHeader's num_bytes includes the length of the header and + // Pickle's payload_size does not, we need to adjust the stored value + // momentarily so Pickle can view the data. + ArrayHeader* header = reinterpret_cast<ArrayHeader*>(data); + DCHECK_GE(header->num_bytes, sizeof(ArrayHeader)); + header->num_bytes -= sizeof(ArrayHeader); + + { + // Construct a view over the full Array_Data, including our hacked up + // header. Pickle will infer from this that the header is 8 bytes long, + // and the payload will contain all of the pickled bytes. + base::Pickle pickle_view(reinterpret_cast<const char*>(header), + header->num_bytes + sizeof(ArrayHeader)); + base::PickleIterator iter(pickle_view); + if (!Traits::Read(&pickle_view, &iter, out)) + return false; + } + + // Return the header to its original state. + header->num_bytes += sizeof(ArrayHeader); + + return true; + } +}; + +struct MOJO_CPP_BINDINGS_EXPORT UnmappedNativeStructSerializerImpl { + static size_t PrepareToSerialize(const NativeStructPtr& input, + SerializationContext* context); + static void Serialize(const NativeStructPtr& input, + Buffer* buffer, + NativeStruct_Data** output, + SerializationContext* context); + static bool Deserialize(NativeStruct_Data* input, + NativeStructPtr* output, + SerializationContext* context); +}; + +template <> +struct NativeStructSerializerImpl<NativeStructPtr> + : public UnmappedNativeStructSerializerImpl {}; + +template <> +struct NativeStructSerializerImpl<const NativeStructPtr> + : public UnmappedNativeStructSerializerImpl {}; + +template <typename MaybeConstUserType> +struct Serializer<NativeStructDataView, MaybeConstUserType> + : public NativeStructSerializerImpl<MaybeConstUserType> {}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc b/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc new file mode 100644 index 0000000000..d451c05a5f --- /dev/null +++ b/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc @@ -0,0 +1,90 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/pipe_control_message_handler.h" + +#include "base/logging.h" +#include "mojo/public/cpp/bindings/interface_id.h" +#include "mojo/public/cpp/bindings/lib/message_builder.h" +#include "mojo/public/cpp/bindings/lib/serialization.h" +#include "mojo/public/cpp/bindings/lib/serialization_context.h" +#include "mojo/public/cpp/bindings/lib/validation_context.h" +#include "mojo/public/cpp/bindings/lib/validation_util.h" +#include "mojo/public/cpp/bindings/pipe_control_message_handler_delegate.h" +#include "mojo/public/interfaces/bindings/pipe_control_messages.mojom.h" + +namespace mojo { + +PipeControlMessageHandler::PipeControlMessageHandler( + PipeControlMessageHandlerDelegate* delegate) + : delegate_(delegate) {} + +PipeControlMessageHandler::~PipeControlMessageHandler() {} + +void PipeControlMessageHandler::SetDescription(const std::string& description) { + description_ = description; +} + +// static +bool PipeControlMessageHandler::IsPipeControlMessage(const Message* message) { + return !IsValidInterfaceId(message->interface_id()); +} + +bool PipeControlMessageHandler::Accept(Message* message) { + if (!Validate(message)) + return false; + + if (message->name() == pipe_control::kRunOrClosePipeMessageId) + return RunOrClosePipe(message); + + NOTREACHED(); + return false; +} + +bool PipeControlMessageHandler::Validate(Message* message) { + internal::ValidationContext validation_context(message->payload(), + message->payload_num_bytes(), + 0, 0, message, description_); + + if (message->name() == pipe_control::kRunOrClosePipeMessageId) { + if (!internal::ValidateMessageIsRequestWithoutResponse( + message, &validation_context)) { + return false; + } + return internal::ValidateMessagePayload< + pipe_control::internal::RunOrClosePipeMessageParams_Data>( + message, &validation_context); + } + + return false; +} + +bool PipeControlMessageHandler::RunOrClosePipe(Message* message) { + internal::SerializationContext context; + pipe_control::internal::RunOrClosePipeMessageParams_Data* params = + reinterpret_cast< + pipe_control::internal::RunOrClosePipeMessageParams_Data*>( + message->mutable_payload()); + pipe_control::RunOrClosePipeMessageParamsPtr params_ptr; + internal::Deserialize<pipe_control::RunOrClosePipeMessageParamsDataView>( + params, ¶ms_ptr, &context); + + if (params_ptr->input->is_peer_associated_endpoint_closed_event()) { + const auto& event = + params_ptr->input->get_peer_associated_endpoint_closed_event(); + + base::Optional<DisconnectReason> reason; + if (event->disconnect_reason) { + reason.emplace(event->disconnect_reason->custom_reason, + event->disconnect_reason->description); + } + return delegate_->OnPeerAssociatedEndpointClosed(event->id, reason); + } + + DVLOG(1) << "Unsupported command in a RunOrClosePipe message pipe control " + << "message. Closing the pipe."; + return false; +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc b/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc new file mode 100644 index 0000000000..1029c2c491 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc @@ -0,0 +1,68 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/pipe_control_message_proxy.h" + +#include <stddef.h> +#include <utility> + +#include "base/logging.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/lib/message_builder.h" +#include "mojo/public/cpp/bindings/lib/serialization.h" +#include "mojo/public/interfaces/bindings/pipe_control_messages.mojom.h" + +namespace mojo { +namespace { + +Message ConstructRunOrClosePipeMessage( + pipe_control::RunOrClosePipeInputPtr input_ptr) { + internal::SerializationContext context; + + auto params_ptr = pipe_control::RunOrClosePipeMessageParams::New(); + params_ptr->input = std::move(input_ptr); + + size_t size = internal::PrepareToSerialize< + pipe_control::RunOrClosePipeMessageParamsDataView>(params_ptr, &context); + internal::MessageBuilder builder(pipe_control::kRunOrClosePipeMessageId, 0, + size, 0); + + pipe_control::internal::RunOrClosePipeMessageParams_Data* params = nullptr; + internal::Serialize<pipe_control::RunOrClosePipeMessageParamsDataView>( + params_ptr, builder.buffer(), ¶ms, &context); + builder.message()->set_interface_id(kInvalidInterfaceId); + return std::move(*builder.message()); +} + +} // namespace + +PipeControlMessageProxy::PipeControlMessageProxy(MessageReceiver* receiver) + : receiver_(receiver) {} + +void PipeControlMessageProxy::NotifyPeerEndpointClosed( + InterfaceId id, + const base::Optional<DisconnectReason>& reason) { + Message message(ConstructPeerEndpointClosedMessage(id, reason)); + ignore_result(receiver_->Accept(&message)); +} + +// static +Message PipeControlMessageProxy::ConstructPeerEndpointClosedMessage( + InterfaceId id, + const base::Optional<DisconnectReason>& reason) { + auto event = pipe_control::PeerAssociatedEndpointClosedEvent::New(); + event->id = id; + if (reason) { + event->disconnect_reason = pipe_control::DisconnectReason::New(); + event->disconnect_reason->custom_reason = reason->custom_reason; + event->disconnect_reason->description = reason->description; + } + + auto input = pipe_control::RunOrClosePipeInput::New(); + input->set_peer_associated_endpoint_closed_event(std::move(event)); + + return ConstructRunOrClosePipeMessage(std::move(input)); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc b/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc new file mode 100644 index 0000000000..c1345079a5 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc @@ -0,0 +1,382 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" + +#include "base/bind.h" +#include "base/logging.h" +#include "base/synchronization/lock.h" +#include "mojo/public/cpp/bindings/associated_group_controller.h" +#include "mojo/public/cpp/bindings/lib/may_auto_lock.h" + +namespace mojo { + +// ScopedInterfaceEndpointHandle::State ---------------------------------------- + +// State could be called from multiple threads. +class ScopedInterfaceEndpointHandle::State + : public base::RefCountedThreadSafe<State> { + public: + State() = default; + + State(InterfaceId id, + scoped_refptr<AssociatedGroupController> group_controller) + : id_(id), group_controller_(group_controller) {} + + void InitPendingState(scoped_refptr<State> peer) { + DCHECK(!lock_); + DCHECK(!pending_association_); + + lock_.emplace(); + pending_association_ = true; + peer_state_ = std::move(peer); + } + + void Close(const base::Optional<DisconnectReason>& reason) { + scoped_refptr<AssociatedGroupController> cached_group_controller; + InterfaceId cached_id = kInvalidInterfaceId; + scoped_refptr<State> cached_peer_state; + + { + internal::MayAutoLock locker(&lock_); + + if (!association_event_handler_.is_null()) { + association_event_handler_.Reset(); + runner_ = nullptr; + } + + if (!pending_association_) { + if (IsValidInterfaceId(id_)) { + // Intentionally keep |group_controller_| unchanged. + // That is because the callback created by + // CreateGroupControllerGetter() could still be used after this point, + // potentially from another thread. We would like it to continue + // returning the same group controller. + // + // Imagine there is a ThreadSafeForwarder A: + // (1) On the IO thread, A's underlying associated interface pointer + // is closed. + // (2) On the proxy thread, the user makes a call on A to pass an + // associated request B_asso_req. The callback returned by + // CreateGroupControllerGetter() is used to associate B_asso_req. + // (3) On the proxy thread, the user immediately binds B_asso_ptr_info + // to B_asso_ptr and makes calls on it. + // + // If we reset |group_controller_| in step (1), step (2) won't be able + // to associate B_asso_req. Therefore, in step (3) B_asso_ptr won't be + // able to serialize associated endpoints or send message because it + // is still in "pending_association" state and doesn't have a group + // controller. + // + // We could "address" this issue by ignoring messages if there isn't a + // group controller. But the side effect is that we cannot detect + // programming errors of "using associated interface pointer before + // sending associated request". + + cached_group_controller = group_controller_; + cached_id = id_; + id_ = kInvalidInterfaceId; + } + } else { + pending_association_ = false; + cached_peer_state = std::move(peer_state_); + } + } + + if (cached_group_controller) { + cached_group_controller->CloseEndpointHandle(cached_id, reason); + } else if (cached_peer_state) { + cached_peer_state->OnPeerClosedBeforeAssociation(reason); + } + } + + void SetAssociationEventHandler(AssociationEventCallback handler) { + internal::MayAutoLock locker(&lock_); + + if (!pending_association_ && !IsValidInterfaceId(id_)) + return; + + association_event_handler_ = std::move(handler); + if (association_event_handler_.is_null()) { + runner_ = nullptr; + return; + } + + runner_ = base::ThreadTaskRunnerHandle::Get(); + if (!pending_association_) { + runner_->PostTask( + FROM_HERE, + base::Bind( + &ScopedInterfaceEndpointHandle::State::RunAssociationEventHandler, + this, runner_, ASSOCIATED)); + } else if (!peer_state_) { + runner_->PostTask( + FROM_HERE, + base::Bind( + &ScopedInterfaceEndpointHandle::State::RunAssociationEventHandler, + this, runner_, PEER_CLOSED_BEFORE_ASSOCIATION)); + } + } + + bool NotifyAssociation( + InterfaceId id, + scoped_refptr<AssociatedGroupController> peer_group_controller) { + scoped_refptr<State> cached_peer_state; + { + internal::MayAutoLock locker(&lock_); + + DCHECK(pending_association_); + pending_association_ = false; + cached_peer_state = std::move(peer_state_); + } + + if (cached_peer_state) { + cached_peer_state->OnAssociated(id, std::move(peer_group_controller)); + return true; + } + return false; + } + + bool is_valid() const { + internal::MayAutoLock locker(&lock_); + return pending_association_ || IsValidInterfaceId(id_); + } + + bool pending_association() const { + internal::MayAutoLock locker(&lock_); + return pending_association_; + } + + InterfaceId id() const { + internal::MayAutoLock locker(&lock_); + return id_; + } + + AssociatedGroupController* group_controller() const { + internal::MayAutoLock locker(&lock_); + return group_controller_.get(); + } + + const base::Optional<DisconnectReason>& disconnect_reason() const { + internal::MayAutoLock locker(&lock_); + return disconnect_reason_; + } + + private: + friend class base::RefCountedThreadSafe<State>; + + ~State() { + DCHECK(!pending_association_); + DCHECK(!IsValidInterfaceId(id_)); + } + + // Called by the peer, maybe from a different thread. + void OnAssociated(InterfaceId id, + scoped_refptr<AssociatedGroupController> group_controller) { + AssociationEventCallback handler; + { + internal::MayAutoLock locker(&lock_); + + // There may be race between Close() of endpoint A and + // NotifyPeerAssociation() of endpoint A_peer on different threads. + // Therefore, it is possible that endpoint A has been closed but it + // still gets OnAssociated() call from its peer. + if (!pending_association_) + return; + + pending_association_ = false; + peer_state_ = nullptr; + id_ = id; + group_controller_ = std::move(group_controller); + + if (!association_event_handler_.is_null()) { + if (runner_->BelongsToCurrentThread()) { + handler = std::move(association_event_handler_); + runner_ = nullptr; + } else { + runner_->PostTask(FROM_HERE, + base::Bind(&ScopedInterfaceEndpointHandle::State:: + RunAssociationEventHandler, + this, runner_, ASSOCIATED)); + } + } + } + + if (!handler.is_null()) + std::move(handler).Run(ASSOCIATED); + } + + // Called by the peer, maybe from a different thread. + void OnPeerClosedBeforeAssociation( + const base::Optional<DisconnectReason>& reason) { + AssociationEventCallback handler; + { + internal::MayAutoLock locker(&lock_); + + // There may be race between Close()/NotifyPeerAssociation() of endpoint + // A and Close() of endpoint A_peer on different threads. + // Therefore, it is possible that endpoint A is not in pending association + // state but still gets OnPeerClosedBeforeAssociation() call from its + // peer. + if (!pending_association_) + return; + + disconnect_reason_ = reason; + // NOTE: This handle itself is still pending. + peer_state_ = nullptr; + + if (!association_event_handler_.is_null()) { + if (runner_->BelongsToCurrentThread()) { + handler = std::move(association_event_handler_); + runner_ = nullptr; + } else { + runner_->PostTask( + FROM_HERE, + base::Bind(&ScopedInterfaceEndpointHandle::State:: + RunAssociationEventHandler, + this, runner_, PEER_CLOSED_BEFORE_ASSOCIATION)); + } + } + } + + if (!handler.is_null()) + std::move(handler).Run(PEER_CLOSED_BEFORE_ASSOCIATION); + } + + void RunAssociationEventHandler( + scoped_refptr<base::SingleThreadTaskRunner> posted_to_runner, + AssociationEvent event) { + AssociationEventCallback handler; + + { + internal::MayAutoLock locker(&lock_); + if (posted_to_runner == runner_) { + runner_ = nullptr; + handler = std::move(association_event_handler_); + } + } + + if (!handler.is_null()) + std::move(handler).Run(event); + } + + // Protects the following members if the handle is initially set to pending + // association. + mutable base::Optional<base::Lock> lock_; + + bool pending_association_ = false; + base::Optional<DisconnectReason> disconnect_reason_; + + scoped_refptr<State> peer_state_; + + AssociationEventCallback association_event_handler_; + scoped_refptr<base::SingleThreadTaskRunner> runner_; + + InterfaceId id_ = kInvalidInterfaceId; + scoped_refptr<AssociatedGroupController> group_controller_; + + DISALLOW_COPY_AND_ASSIGN(State); +}; + +// ScopedInterfaceEndpointHandle ----------------------------------------------- + +// static +void ScopedInterfaceEndpointHandle::CreatePairPendingAssociation( + ScopedInterfaceEndpointHandle* handle0, + ScopedInterfaceEndpointHandle* handle1) { + ScopedInterfaceEndpointHandle result0; + ScopedInterfaceEndpointHandle result1; + result0.state_->InitPendingState(result1.state_); + result1.state_->InitPendingState(result0.state_); + + *handle0 = std::move(result0); + *handle1 = std::move(result1); +} + +ScopedInterfaceEndpointHandle::ScopedInterfaceEndpointHandle() + : state_(new State) {} + +ScopedInterfaceEndpointHandle::ScopedInterfaceEndpointHandle( + ScopedInterfaceEndpointHandle&& other) + : state_(new State) { + state_.swap(other.state_); +} + +ScopedInterfaceEndpointHandle::~ScopedInterfaceEndpointHandle() { + state_->Close(base::nullopt); +} + +ScopedInterfaceEndpointHandle& ScopedInterfaceEndpointHandle::operator=( + ScopedInterfaceEndpointHandle&& other) { + reset(); + state_.swap(other.state_); + return *this; +} + +bool ScopedInterfaceEndpointHandle::is_valid() const { + return state_->is_valid(); +} + +bool ScopedInterfaceEndpointHandle::pending_association() const { + return state_->pending_association(); +} + +InterfaceId ScopedInterfaceEndpointHandle::id() const { + return state_->id(); +} + +AssociatedGroupController* ScopedInterfaceEndpointHandle::group_controller() + const { + return state_->group_controller(); +} + +const base::Optional<DisconnectReason>& +ScopedInterfaceEndpointHandle::disconnect_reason() const { + return state_->disconnect_reason(); +} + +void ScopedInterfaceEndpointHandle::SetAssociationEventHandler( + AssociationEventCallback handler) { + state_->SetAssociationEventHandler(std::move(handler)); +} + +void ScopedInterfaceEndpointHandle::reset() { + ResetInternal(base::nullopt); +} + +void ScopedInterfaceEndpointHandle::ResetWithReason( + uint32_t custom_reason, + const std::string& description) { + ResetInternal(DisconnectReason(custom_reason, description)); +} + +ScopedInterfaceEndpointHandle::ScopedInterfaceEndpointHandle( + InterfaceId id, + scoped_refptr<AssociatedGroupController> group_controller) + : state_(new State(id, std::move(group_controller))) { + DCHECK(!IsValidInterfaceId(state_->id()) || state_->group_controller()); +} + +bool ScopedInterfaceEndpointHandle::NotifyAssociation( + InterfaceId id, + scoped_refptr<AssociatedGroupController> peer_group_controller) { + return state_->NotifyAssociation(id, peer_group_controller); +} + +void ScopedInterfaceEndpointHandle::ResetInternal( + const base::Optional<DisconnectReason>& reason) { + scoped_refptr<State> new_state(new State); + state_->Close(reason); + state_.swap(new_state); +} + +base::Callback<AssociatedGroupController*()> +ScopedInterfaceEndpointHandle::CreateGroupControllerGetter() const { + // We allow this callback to be run on any thread. If this handle is created + // in non-pending state, we don't have a lock but it should still be safe + // because the group controller never changes. + return base::Bind(&State::group_controller, state_); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/serialization.h b/mojo/public/cpp/bindings/lib/serialization.h new file mode 100644 index 0000000000..2a7d288d55 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/serialization.h @@ -0,0 +1,107 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_SERIALIZATION_H_ + +#include <string.h> + +#include "mojo/public/cpp/bindings/array_traits_carray.h" +#include "mojo/public/cpp/bindings/array_traits_stl.h" +#include "mojo/public/cpp/bindings/lib/array_serialization.h" +#include "mojo/public/cpp/bindings/lib/buffer.h" +#include "mojo/public/cpp/bindings/lib/handle_interface_serialization.h" +#include "mojo/public/cpp/bindings/lib/map_serialization.h" +#include "mojo/public/cpp/bindings/lib/native_enum_serialization.h" +#include "mojo/public/cpp/bindings/lib/native_struct_serialization.h" +#include "mojo/public/cpp/bindings/lib/string_serialization.h" +#include "mojo/public/cpp/bindings/lib/template_util.h" +#include "mojo/public/cpp/bindings/map_traits_stl.h" +#include "mojo/public/cpp/bindings/string_traits_stl.h" +#include "mojo/public/cpp/bindings/string_traits_string16.h" +#include "mojo/public/cpp/bindings/string_traits_string_piece.h" + +namespace mojo { +namespace internal { + +template <typename MojomType, typename DataArrayType, typename UserType> +DataArrayType StructSerializeImpl(UserType* input) { + static_assert(BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value, + "Unexpected type."); + + SerializationContext context; + size_t size = PrepareToSerialize<MojomType>(*input, &context); + DCHECK_EQ(size, Align(size)); + + DataArrayType result(size); + if (size == 0) + return result; + + void* result_buffer = &result.front(); + // The serialization logic requires that the buffer is 8-byte aligned. If the + // result buffer is not properly aligned, we have to do an extra copy. In + // practice, this should never happen for std::vector. + bool need_copy = !IsAligned(result_buffer); + + if (need_copy) { + // calloc sets the memory to all zero. + result_buffer = calloc(size, 1); + DCHECK(IsAligned(result_buffer)); + } + + Buffer buffer; + buffer.Initialize(result_buffer, size); + typename MojomTypeTraits<MojomType>::Data* data = nullptr; + Serialize<MojomType>(*input, &buffer, &data, &context); + + if (need_copy) { + memcpy(&result.front(), result_buffer, size); + free(result_buffer); + } + + return result; +} + +template <typename MojomType, typename DataArrayType, typename UserType> +bool StructDeserializeImpl(const DataArrayType& input, + UserType* output, + bool (*validate_func)(const void*, + ValidationContext*)) { + static_assert(BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value, + "Unexpected type."); + using DataType = typename MojomTypeTraits<MojomType>::Data; + + // TODO(sammc): Use DataArrayType::empty() once WTF::Vector::empty() exists. + void* input_buffer = + input.size() == 0 + ? nullptr + : const_cast<void*>(reinterpret_cast<const void*>(&input.front())); + + // Please see comments in StructSerializeImpl. + bool need_copy = !IsAligned(input_buffer); + + if (need_copy) { + input_buffer = malloc(input.size()); + DCHECK(IsAligned(input_buffer)); + memcpy(input_buffer, &input.front(), input.size()); + } + + ValidationContext validation_context(input_buffer, input.size(), 0, 0); + bool result = false; + if (validate_func(input_buffer, &validation_context)) { + auto data = reinterpret_cast<DataType*>(input_buffer); + SerializationContext context; + result = Deserialize<MojomType>(data, output, &context); + } + + if (need_copy) + free(input_buffer); + + return result; +} + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/serialization_context.cc b/mojo/public/cpp/bindings/lib/serialization_context.cc new file mode 100644 index 0000000000..e2fd5c6e18 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/serialization_context.cc @@ -0,0 +1,57 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/serialization_context.h" + +#include <limits> + +#include "base/logging.h" +#include "mojo/public/cpp/system/core.h" + +namespace mojo { +namespace internal { + +SerializedHandleVector::SerializedHandleVector() {} + +SerializedHandleVector::~SerializedHandleVector() { + for (auto handle : handles_) { + if (handle.is_valid()) { + MojoResult rv = MojoClose(handle.value()); + DCHECK_EQ(rv, MOJO_RESULT_OK); + } + } +} + +Handle_Data SerializedHandleVector::AddHandle(mojo::Handle handle) { + Handle_Data data; + if (!handle.is_valid()) { + data.value = kEncodedInvalidHandleValue; + } else { + DCHECK_LT(handles_.size(), std::numeric_limits<uint32_t>::max()); + data.value = static_cast<uint32_t>(handles_.size()); + handles_.push_back(handle); + } + return data; +} + +mojo::Handle SerializedHandleVector::TakeHandle( + const Handle_Data& encoded_handle) { + if (!encoded_handle.is_valid()) + return mojo::Handle(); + DCHECK_LT(encoded_handle.value, handles_.size()); + return FetchAndReset(&handles_[encoded_handle.value]); +} + +void SerializedHandleVector::Swap(std::vector<mojo::Handle>* other) { + handles_.swap(*other); +} + +SerializationContext::SerializationContext() {} + +SerializationContext::~SerializationContext() { + DCHECK(!custom_contexts || custom_contexts->empty()); +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/serialization_context.h b/mojo/public/cpp/bindings/lib/serialization_context.h new file mode 100644 index 0000000000..a34fe3d4ed --- /dev/null +++ b/mojo/public/cpp/bindings/lib/serialization_context.h @@ -0,0 +1,77 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_BINDINGS_SERIALIZATION_CONTEXT_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_BINDINGS_SERIALIZATION_CONTEXT_H_ + +#include <stddef.h> + +#include <memory> +#include <queue> +#include <vector> + +#include "base/macros.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" +#include "mojo/public/cpp/system/handle.h" + +namespace mojo { +namespace internal { + +// A container for handles during serialization/deserialization. +class MOJO_CPP_BINDINGS_EXPORT SerializedHandleVector { + public: + SerializedHandleVector(); + ~SerializedHandleVector(); + + size_t size() const { return handles_.size(); } + + // Adds a handle to the handle list and returns its index for encoding. + Handle_Data AddHandle(mojo::Handle handle); + + // Takes a handle from the list of serialized handle data. + mojo::Handle TakeHandle(const Handle_Data& encoded_handle); + + // Takes a handle from the list of serialized handle data and returns it in + // |*out_handle| as a specific scoped handle type. + template <typename T> + ScopedHandleBase<T> TakeHandleAs(const Handle_Data& encoded_handle) { + return MakeScopedHandle(T(TakeHandle(encoded_handle).value())); + } + + // Swaps all owned handles out with another Handle vector. + void Swap(std::vector<mojo::Handle>* other); + + private: + // Handles are owned by this object. + std::vector<mojo::Handle> handles_; + + DISALLOW_COPY_AND_ASSIGN(SerializedHandleVector); +}; + +// Context information for serialization/deserialization routines. +struct MOJO_CPP_BINDINGS_EXPORT SerializationContext { + SerializationContext(); + + ~SerializationContext(); + + // Opaque context pointers returned by StringTraits::SetUpContext(). + std::unique_ptr<std::queue<void*>> custom_contexts; + + // Stashes handles encoded in a message by index. + SerializedHandleVector handles; + + // The number of ScopedInterfaceEndpointHandles that need to be serialized. + // It is calculated by PrepareToSerialize(). + uint32_t associated_endpoint_count = 0; + + // Stashes ScopedInterfaceEndpointHandles encoded in a message by index. + std::vector<ScopedInterfaceEndpointHandle> associated_endpoint_handles; +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_BINDINGS_SERIALIZATION_CONTEXT_H_ diff --git a/mojo/public/cpp/bindings/lib/serialization_forward.h b/mojo/public/cpp/bindings/lib/serialization_forward.h new file mode 100644 index 0000000000..55c9982ccc --- /dev/null +++ b/mojo/public/cpp/bindings/lib/serialization_forward.h @@ -0,0 +1,123 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_SERIALIZATION_FORWARD_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_SERIALIZATION_FORWARD_H_ + +#include "base/optional.h" +#include "mojo/public/cpp/bindings/array_traits.h" +#include "mojo/public/cpp/bindings/enum_traits.h" +#include "mojo/public/cpp/bindings/lib/template_util.h" +#include "mojo/public/cpp/bindings/map_traits.h" +#include "mojo/public/cpp/bindings/string_traits.h" +#include "mojo/public/cpp/bindings/struct_traits.h" +#include "mojo/public/cpp/bindings/union_traits.h" + +// This file is included by serialization implementation files to avoid circular +// includes. +// Users of the serialization funtions should include serialization.h (and also +// wtf_serialization.h if necessary). + +namespace mojo { +namespace internal { + +template <typename MojomType, typename MaybeConstUserType> +struct Serializer; + +template <typename T> +struct IsOptionalWrapper { + static const bool value = IsSpecializationOf< + base::Optional, + typename std::remove_const< + typename std::remove_reference<T>::type>::type>::value; +}; + +// PrepareToSerialize() must be matched by a Serialize() for the same input +// later. Moreover, within the same SerializationContext if PrepareToSerialize() +// is called for |input_1|, ..., |input_n|, Serialize() must be called for +// those objects in the exact same order. +template <typename MojomType, + typename InputUserType, + typename... Args, + typename std::enable_if< + !IsOptionalWrapper<InputUserType>::value>::type* = nullptr> +size_t PrepareToSerialize(InputUserType&& input, Args&&... args) { + return Serializer<MojomType, + typename std::remove_reference<InputUserType>::type>:: + PrepareToSerialize(std::forward<InputUserType>(input), + std::forward<Args>(args)...); +} + +template <typename MojomType, + typename InputUserType, + typename... Args, + typename std::enable_if< + !IsOptionalWrapper<InputUserType>::value>::type* = nullptr> +void Serialize(InputUserType&& input, Args&&... args) { + Serializer<MojomType, typename std::remove_reference<InputUserType>::type>:: + Serialize(std::forward<InputUserType>(input), + std::forward<Args>(args)...); +} + +template <typename MojomType, + typename DataType, + typename InputUserType, + typename... Args, + typename std::enable_if< + !IsOptionalWrapper<InputUserType>::value>::type* = nullptr> +bool Deserialize(DataType&& input, InputUserType* output, Args&&... args) { + return Serializer<MojomType, InputUserType>::Deserialize( + std::forward<DataType>(input), output, std::forward<Args>(args)...); +} + +// Specialization that unwraps base::Optional<>. +template <typename MojomType, + typename InputUserType, + typename... Args, + typename std::enable_if< + IsOptionalWrapper<InputUserType>::value>::type* = nullptr> +size_t PrepareToSerialize(InputUserType&& input, Args&&... args) { + if (!input) + return 0; + return PrepareToSerialize<MojomType>(*input, std::forward<Args>(args)...); +} + +template <typename MojomType, + typename InputUserType, + typename DataType, + typename... Args, + typename std::enable_if< + IsOptionalWrapper<InputUserType>::value>::type* = nullptr> +void Serialize(InputUserType&& input, + Buffer* buffer, + DataType** output, + Args&&... args) { + if (!input) { + *output = nullptr; + return; + } + Serialize<MojomType>(*input, buffer, output, std::forward<Args>(args)...); +} + +template <typename MojomType, + typename DataType, + typename InputUserType, + typename... Args, + typename std::enable_if< + IsOptionalWrapper<InputUserType>::value>::type* = nullptr> +bool Deserialize(DataType&& input, InputUserType* output, Args&&... args) { + if (!input) { + *output = base::nullopt; + return true; + } + if (!*output) + output->emplace(); + return Deserialize<MojomType>(std::forward<DataType>(input), &output->value(), + std::forward<Args>(args)...); +} + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_SERIALIZATION_FORWARD_H_ diff --git a/mojo/public/cpp/bindings/lib/serialization_util.h b/mojo/public/cpp/bindings/lib/serialization_util.h new file mode 100644 index 0000000000..4820a014ec --- /dev/null +++ b/mojo/public/cpp/bindings/lib/serialization_util.h @@ -0,0 +1,213 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_SERIALIZATION_UTIL_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_SERIALIZATION_UTIL_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <queue> + +#include "base/logging.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_context.h" + +namespace mojo { +namespace internal { + +template <typename T> +struct HasIsNullMethod { + template <typename U> + static char Test(decltype(U::IsNull)*); + template <typename U> + static int Test(...); + static const bool value = sizeof(Test<T>(0)) == sizeof(char); + + private: + EnsureTypeIsComplete<T> check_t_; +}; + +template < + typename Traits, + typename UserType, + typename std::enable_if<HasIsNullMethod<Traits>::value>::type* = nullptr> +bool CallIsNullIfExists(const UserType& input) { + return Traits::IsNull(input); +} + +template < + typename Traits, + typename UserType, + typename std::enable_if<!HasIsNullMethod<Traits>::value>::type* = nullptr> +bool CallIsNullIfExists(const UserType& input) { + return false; +} +template <typename T> +struct HasSetToNullMethod { + template <typename U> + static char Test(decltype(U::SetToNull)*); + template <typename U> + static int Test(...); + static const bool value = sizeof(Test<T>(0)) == sizeof(char); + + private: + EnsureTypeIsComplete<T> check_t_; +}; + +template < + typename Traits, + typename UserType, + typename std::enable_if<HasSetToNullMethod<Traits>::value>::type* = nullptr> +bool CallSetToNullIfExists(UserType* output) { + Traits::SetToNull(output); + return true; +} + +template <typename Traits, + typename UserType, + typename std::enable_if<!HasSetToNullMethod<Traits>::value>::type* = + nullptr> +bool CallSetToNullIfExists(UserType* output) { + LOG(ERROR) << "A null value is received. But the Struct/Array/StringTraits " + << "class doesn't define a SetToNull() function and therefore is " + << "unable to deserialize the value."; + return false; +} + +template <typename T> +struct HasSetUpContextMethod { + template <typename U> + static char Test(decltype(U::SetUpContext)*); + template <typename U> + static int Test(...); + static const bool value = sizeof(Test<T>(0)) == sizeof(char); + + private: + EnsureTypeIsComplete<T> check_t_; +}; + +template <typename Traits, + bool has_context = HasSetUpContextMethod<Traits>::value> +struct CustomContextHelper; + +template <typename Traits> +struct CustomContextHelper<Traits, true> { + template <typename MaybeConstUserType> + static void* SetUp(MaybeConstUserType& input, SerializationContext* context) { + void* custom_context = Traits::SetUpContext(input); + if (!context->custom_contexts) + context->custom_contexts.reset(new std::queue<void*>()); + context->custom_contexts->push(custom_context); + return custom_context; + } + + static void* GetNext(SerializationContext* context) { + void* custom_context = context->custom_contexts->front(); + context->custom_contexts->pop(); + return custom_context; + } + + template <typename MaybeConstUserType> + static void TearDown(MaybeConstUserType& input, void* custom_context) { + Traits::TearDownContext(input, custom_context); + } +}; + +template <typename Traits> +struct CustomContextHelper<Traits, false> { + template <typename MaybeConstUserType> + static void* SetUp(MaybeConstUserType& input, SerializationContext* context) { + return nullptr; + } + + static void* GetNext(SerializationContext* context) { return nullptr; } + + template <typename MaybeConstUserType> + static void TearDown(MaybeConstUserType& input, void* custom_context) { + DCHECK(!custom_context); + } +}; + +template <typename ReturnType, typename ParamType, typename InputUserType> +ReturnType CallWithContext(ReturnType (*f)(ParamType, void*), + InputUserType&& input, + void* context) { + return f(std::forward<InputUserType>(input), context); +} + +template <typename ReturnType, typename ParamType, typename InputUserType> +ReturnType CallWithContext(ReturnType (*f)(ParamType), + InputUserType&& input, + void* context) { + return f(std::forward<InputUserType>(input)); +} + +template <typename T, typename MaybeConstUserType> +struct HasGetBeginMethod { + template <typename U> + static char Test(decltype(U::GetBegin(std::declval<MaybeConstUserType&>()))*); + template <typename U> + static int Test(...); + static const bool value = sizeof(Test<T>(0)) == sizeof(char); + + private: + EnsureTypeIsComplete<T> check_t_; +}; + +template < + typename Traits, + typename MaybeConstUserType, + typename std::enable_if< + HasGetBeginMethod<Traits, MaybeConstUserType>::value>::type* = nullptr> +decltype(Traits::GetBegin(std::declval<MaybeConstUserType&>())) +CallGetBeginIfExists(MaybeConstUserType& input) { + return Traits::GetBegin(input); +} + +template < + typename Traits, + typename MaybeConstUserType, + typename std::enable_if< + !HasGetBeginMethod<Traits, MaybeConstUserType>::value>::type* = nullptr> +size_t CallGetBeginIfExists(MaybeConstUserType& input) { + return 0; +} + +template <typename T, typename MaybeConstUserType> +struct HasGetDataMethod { + template <typename U> + static char Test(decltype(U::GetData(std::declval<MaybeConstUserType&>()))*); + template <typename U> + static int Test(...); + static const bool value = sizeof(Test<T>(0)) == sizeof(char); + + private: + EnsureTypeIsComplete<T> check_t_; +}; + +template < + typename Traits, + typename MaybeConstUserType, + typename std::enable_if< + HasGetDataMethod<Traits, MaybeConstUserType>::value>::type* = nullptr> +decltype(Traits::GetData(std::declval<MaybeConstUserType&>())) +CallGetDataIfExists(MaybeConstUserType& input) { + return Traits::GetData(input); +} + +template < + typename Traits, + typename MaybeConstUserType, + typename std::enable_if< + !HasGetDataMethod<Traits, MaybeConstUserType>::value>::type* = nullptr> +void* CallGetDataIfExists(MaybeConstUserType& input) { + return nullptr; +} + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_SERIALIZATION_UTIL_H_ diff --git a/mojo/public/cpp/bindings/lib/string_serialization.h b/mojo/public/cpp/bindings/lib/string_serialization.h new file mode 100644 index 0000000000..6e0c758576 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/string_serialization.h @@ -0,0 +1,70 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_STRING_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_STRING_SERIALIZATION_H_ + +#include <stddef.h> +#include <string.h> + +#include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_forward.h" +#include "mojo/public/cpp/bindings/lib/serialization_util.h" +#include "mojo/public/cpp/bindings/string_data_view.h" +#include "mojo/public/cpp/bindings/string_traits.h" + +namespace mojo { +namespace internal { + +template <typename MaybeConstUserType> +struct Serializer<StringDataView, MaybeConstUserType> { + using UserType = typename std::remove_const<MaybeConstUserType>::type; + using Traits = StringTraits<UserType>; + + static size_t PrepareToSerialize(MaybeConstUserType& input, + SerializationContext* context) { + if (CallIsNullIfExists<Traits>(input)) + return 0; + + void* custom_context = CustomContextHelper<Traits>::SetUp(input, context); + return Align(sizeof(String_Data) + + CallWithContext(Traits::GetSize, input, custom_context)); + } + + static void Serialize(MaybeConstUserType& input, + Buffer* buffer, + String_Data** output, + SerializationContext* context) { + if (CallIsNullIfExists<Traits>(input)) { + *output = nullptr; + return; + } + + void* custom_context = CustomContextHelper<Traits>::GetNext(context); + + String_Data* result = String_Data::New( + CallWithContext(Traits::GetSize, input, custom_context), buffer); + if (result) { + memcpy(result->storage(), + CallWithContext(Traits::GetData, input, custom_context), + CallWithContext(Traits::GetSize, input, custom_context)); + } + *output = result; + + CustomContextHelper<Traits>::TearDown(input, custom_context); + } + + static bool Deserialize(String_Data* input, + UserType* output, + SerializationContext* context) { + if (!input) + return CallSetToNullIfExists<Traits>(output); + return Traits::Read(StringDataView(input, context), output); + } +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_STRING_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/string_traits_string16.cc b/mojo/public/cpp/bindings/lib/string_traits_string16.cc new file mode 100644 index 0000000000..95ff6ccf25 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/string_traits_string16.cc @@ -0,0 +1,42 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/string_traits_string16.h" + +#include <string> + +#include "base/strings/utf_string_conversions.h" + +namespace mojo { + +// static +void* StringTraits<base::string16>::SetUpContext(const base::string16& input) { + return new std::string(base::UTF16ToUTF8(input)); +} + +// static +void StringTraits<base::string16>::TearDownContext(const base::string16& input, + void* context) { + delete static_cast<std::string*>(context); +} + +// static +size_t StringTraits<base::string16>::GetSize(const base::string16& input, + void* context) { + return static_cast<std::string*>(context)->size(); +} + +// static +const char* StringTraits<base::string16>::GetData(const base::string16& input, + void* context) { + return static_cast<std::string*>(context)->data(); +} + +// static +bool StringTraits<base::string16>::Read(StringDataView input, + base::string16* output) { + return base::UTF8ToUTF16(input.storage(), input.size(), output); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/string_traits_wtf.cc b/mojo/public/cpp/bindings/lib/string_traits_wtf.cc new file mode 100644 index 0000000000..203f6f5903 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/string_traits_wtf.cc @@ -0,0 +1,84 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/string_traits_wtf.h" + +#include <string.h> + +#include "base/logging.h" +#include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "third_party/WebKit/Source/wtf/text/StringUTF8Adaptor.h" + +namespace mojo { +namespace { + +struct UTF8AdaptorInfo { + explicit UTF8AdaptorInfo(const WTF::String& input) : utf8_adaptor(input) { +#if DCHECK_IS_ON() + original_size_in_bytes = input.charactersSizeInBytes(); +#endif + } + + ~UTF8AdaptorInfo() {} + + WTF::StringUTF8Adaptor utf8_adaptor; + +#if DCHECK_IS_ON() + // For sanity check only. + size_t original_size_in_bytes; +#endif +}; + +UTF8AdaptorInfo* ToAdaptor(const WTF::String& input, void* context) { + UTF8AdaptorInfo* adaptor = static_cast<UTF8AdaptorInfo*>(context); + +#if DCHECK_IS_ON() + DCHECK_EQ(adaptor->original_size_in_bytes, input.charactersSizeInBytes()); +#endif + return adaptor; +} + +} // namespace + +// static +void StringTraits<WTF::String>::SetToNull(WTF::String* output) { + if (output->isNull()) + return; + + WTF::String result; + output->swap(result); +} + +// static +void* StringTraits<WTF::String>::SetUpContext(const WTF::String& input) { + return new UTF8AdaptorInfo(input); +} + +// static +void StringTraits<WTF::String>::TearDownContext(const WTF::String& input, + void* context) { + delete ToAdaptor(input, context); +} + +// static +size_t StringTraits<WTF::String>::GetSize(const WTF::String& input, + void* context) { + return ToAdaptor(input, context)->utf8_adaptor.length(); +} + +// static +const char* StringTraits<WTF::String>::GetData(const WTF::String& input, + void* context) { + return ToAdaptor(input, context)->utf8_adaptor.data(); +} + +// static +bool StringTraits<WTF::String>::Read(StringDataView input, + WTF::String* output) { + WTF::String result = WTF::String::fromUTF8(input.storage(), input.size()); + output->swap(result); + return true; +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc b/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc new file mode 100644 index 0000000000..585a8f094c --- /dev/null +++ b/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc @@ -0,0 +1,93 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/sync_call_restrictions.h" + +#if ENABLE_SYNC_CALL_RESTRICTIONS + +#include "base/debug/leak_annotations.h" +#include "base/lazy_instance.h" +#include "base/logging.h" +#include "base/threading/thread_local.h" +#include "mojo/public/c/system/core.h" + +namespace mojo { + +namespace { + +class SyncCallSettings { + public: + static SyncCallSettings* current(); + + bool allowed() const { + return scoped_allow_count_ > 0 || system_defined_value_; + } + + void IncreaseScopedAllowCount() { scoped_allow_count_++; } + void DecreaseScopedAllowCount() { + DCHECK_LT(0u, scoped_allow_count_); + scoped_allow_count_--; + } + + private: + SyncCallSettings(); + ~SyncCallSettings(); + + bool system_defined_value_ = true; + size_t scoped_allow_count_ = 0; +}; + +base::LazyInstance<base::ThreadLocalPointer<SyncCallSettings>>::DestructorAtExit + g_sync_call_settings = LAZY_INSTANCE_INITIALIZER; + +// static +SyncCallSettings* SyncCallSettings::current() { + SyncCallSettings* result = g_sync_call_settings.Pointer()->Get(); + if (!result) { + result = new SyncCallSettings(); + ANNOTATE_LEAKING_OBJECT_PTR(result); + DCHECK_EQ(result, g_sync_call_settings.Pointer()->Get()); + } + return result; +} + +SyncCallSettings::SyncCallSettings() { + MojoResult result = MojoGetProperty(MOJO_PROPERTY_TYPE_SYNC_CALL_ALLOWED, + &system_defined_value_); + DCHECK_EQ(MOJO_RESULT_OK, result); + + DCHECK(!g_sync_call_settings.Pointer()->Get()); + g_sync_call_settings.Pointer()->Set(this); +} + +SyncCallSettings::~SyncCallSettings() { + g_sync_call_settings.Pointer()->Set(nullptr); +} + +} // namespace + +// static +void SyncCallRestrictions::AssertSyncCallAllowed() { + if (!SyncCallSettings::current()->allowed()) { + LOG(FATAL) << "Mojo sync calls are not allowed in this process because " + << "they can lead to jank and deadlock. If you must make an " + << "exception, please see " + << "SyncCallRestrictions::ScopedAllowSyncCall and consult " + << "mojo/OWNERS."; + } +} + +// static +void SyncCallRestrictions::IncreaseScopedAllowCount() { + SyncCallSettings::current()->IncreaseScopedAllowCount(); +} + +// static +void SyncCallRestrictions::DecreaseScopedAllowCount() { + SyncCallSettings::current()->DecreaseScopedAllowCount(); +} + +} // namespace mojo + +#endif // ENABLE_SYNC_CALL_RESTRICTIONS diff --git a/mojo/public/cpp/bindings/lib/sync_event_watcher.cc b/mojo/public/cpp/bindings/lib/sync_event_watcher.cc new file mode 100644 index 0000000000..b1c97e3691 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/sync_event_watcher.cc @@ -0,0 +1,67 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/sync_event_watcher.h" + +#include "base/logging.h" + +namespace mojo { + +SyncEventWatcher::SyncEventWatcher(base::WaitableEvent* event, + const base::Closure& callback) + : event_(event), + callback_(callback), + registry_(SyncHandleRegistry::current()), + destroyed_(new base::RefCountedData<bool>(false)) {} + +SyncEventWatcher::~SyncEventWatcher() { + DCHECK(thread_checker_.CalledOnValidThread()); + if (registered_) + registry_->UnregisterEvent(event_); + destroyed_->data = true; +} + +void SyncEventWatcher::AllowWokenUpBySyncWatchOnSameThread() { + DCHECK(thread_checker_.CalledOnValidThread()); + IncrementRegisterCount(); +} + +bool SyncEventWatcher::SyncWatch(const bool* should_stop) { + DCHECK(thread_checker_.CalledOnValidThread()); + IncrementRegisterCount(); + if (!registered_) { + DecrementRegisterCount(); + return false; + } + + // This object may be destroyed during the Wait() call. So we have to preserve + // the boolean that Wait uses. + auto destroyed = destroyed_; + const bool* should_stop_array[] = {should_stop, &destroyed->data}; + bool result = registry_->Wait(should_stop_array, 2); + + // This object has been destroyed. + if (destroyed->data) + return false; + + DecrementRegisterCount(); + return result; +} + +void SyncEventWatcher::IncrementRegisterCount() { + register_request_count_++; + if (!registered_) + registered_ = registry_->RegisterEvent(event_, callback_); +} + +void SyncEventWatcher::DecrementRegisterCount() { + DCHECK_GT(register_request_count_, 0u); + register_request_count_--; + if (register_request_count_ == 0 && registered_) { + registry_->UnregisterEvent(event_); + registered_ = false; + } +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/sync_handle_registry.cc b/mojo/public/cpp/bindings/lib/sync_handle_registry.cc new file mode 100644 index 0000000000..fd3df396ec --- /dev/null +++ b/mojo/public/cpp/bindings/lib/sync_handle_registry.cc @@ -0,0 +1,135 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/sync_handle_registry.h" + +#include "base/lazy_instance.h" +#include "base/logging.h" +#include "base/stl_util.h" +#include "base/threading/thread_local.h" +#include "mojo/public/c/system/core.h" + +namespace mojo { +namespace { + +base::LazyInstance<base::ThreadLocalPointer<SyncHandleRegistry>>::Leaky + g_current_sync_handle_watcher = LAZY_INSTANCE_INITIALIZER; + +} // namespace + +// static +scoped_refptr<SyncHandleRegistry> SyncHandleRegistry::current() { + scoped_refptr<SyncHandleRegistry> result( + g_current_sync_handle_watcher.Pointer()->Get()); + if (!result) { + result = new SyncHandleRegistry(); + DCHECK_EQ(result.get(), g_current_sync_handle_watcher.Pointer()->Get()); + } + return result; +} + +bool SyncHandleRegistry::RegisterHandle(const Handle& handle, + MojoHandleSignals handle_signals, + const HandleCallback& callback) { + DCHECK(thread_checker_.CalledOnValidThread()); + + if (base::ContainsKey(handles_, handle)) + return false; + + MojoResult result = wait_set_.AddHandle(handle, handle_signals); + if (result != MOJO_RESULT_OK) + return false; + + handles_[handle] = callback; + return true; +} + +void SyncHandleRegistry::UnregisterHandle(const Handle& handle) { + DCHECK(thread_checker_.CalledOnValidThread()); + if (!base::ContainsKey(handles_, handle)) + return; + + MojoResult result = wait_set_.RemoveHandle(handle); + DCHECK_EQ(MOJO_RESULT_OK, result); + handles_.erase(handle); +} + +bool SyncHandleRegistry::RegisterEvent(base::WaitableEvent* event, + const base::Closure& callback) { + auto result = events_.insert({event, callback}); + DCHECK(result.second); + MojoResult rv = wait_set_.AddEvent(event); + if (rv == MOJO_RESULT_OK) + return true; + DCHECK_EQ(MOJO_RESULT_ALREADY_EXISTS, rv); + return false; +} + +void SyncHandleRegistry::UnregisterEvent(base::WaitableEvent* event) { + auto it = events_.find(event); + DCHECK(it != events_.end()); + events_.erase(it); + MojoResult rv = wait_set_.RemoveEvent(event); + DCHECK_EQ(MOJO_RESULT_OK, rv); +} + +bool SyncHandleRegistry::Wait(const bool* should_stop[], size_t count) { + DCHECK(thread_checker_.CalledOnValidThread()); + + size_t num_ready_handles; + Handle ready_handle; + MojoResult ready_handle_result; + + scoped_refptr<SyncHandleRegistry> preserver(this); + while (true) { + for (size_t i = 0; i < count; ++i) + if (*should_stop[i]) + return true; + + // TODO(yzshen): Theoretically it can reduce sync call re-entrancy if we + // give priority to the handle that is waiting for sync response. + base::WaitableEvent* ready_event = nullptr; + num_ready_handles = 1; + wait_set_.Wait(&ready_event, &num_ready_handles, &ready_handle, + &ready_handle_result); + if (num_ready_handles) { + DCHECK_EQ(1u, num_ready_handles); + const auto iter = handles_.find(ready_handle); + iter->second.Run(ready_handle_result); + } + + if (ready_event) { + const auto iter = events_.find(ready_event); + DCHECK(iter != events_.end()); + iter->second.Run(); + } + }; + + return false; +} + +SyncHandleRegistry::SyncHandleRegistry() { + DCHECK(!g_current_sync_handle_watcher.Pointer()->Get()); + g_current_sync_handle_watcher.Pointer()->Set(this); +} + +SyncHandleRegistry::~SyncHandleRegistry() { + DCHECK(thread_checker_.CalledOnValidThread()); + + // This object may be destructed after the thread local storage slot used by + // |g_current_sync_handle_watcher| is reset during thread shutdown. + // For example, another slot in the thread local storage holds a referrence to + // this object, and that slot is cleaned up after + // |g_current_sync_handle_watcher|. + if (!g_current_sync_handle_watcher.Pointer()->Get()) + return; + + // If this breaks, it is likely that the global variable is bulit into and + // accessed from multiple modules. + DCHECK_EQ(this, g_current_sync_handle_watcher.Pointer()->Get()); + + g_current_sync_handle_watcher.Pointer()->Set(nullptr); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc b/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc new file mode 100644 index 0000000000..f20af56b20 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc @@ -0,0 +1,76 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/sync_handle_watcher.h" + +#include "base/logging.h" + +namespace mojo { + +SyncHandleWatcher::SyncHandleWatcher( + const Handle& handle, + MojoHandleSignals handle_signals, + const SyncHandleRegistry::HandleCallback& callback) + : handle_(handle), + handle_signals_(handle_signals), + callback_(callback), + registered_(false), + register_request_count_(0), + registry_(SyncHandleRegistry::current()), + destroyed_(new base::RefCountedData<bool>(false)) {} + +SyncHandleWatcher::~SyncHandleWatcher() { + DCHECK(thread_checker_.CalledOnValidThread()); + if (registered_) + registry_->UnregisterHandle(handle_); + + destroyed_->data = true; +} + +void SyncHandleWatcher::AllowWokenUpBySyncWatchOnSameThread() { + DCHECK(thread_checker_.CalledOnValidThread()); + IncrementRegisterCount(); +} + +bool SyncHandleWatcher::SyncWatch(const bool* should_stop) { + DCHECK(thread_checker_.CalledOnValidThread()); + IncrementRegisterCount(); + if (!registered_) { + DecrementRegisterCount(); + return false; + } + + // This object may be destroyed during the Wait() call. So we have to preserve + // the boolean that Wait uses. + auto destroyed = destroyed_; + const bool* should_stop_array[] = {should_stop, &destroyed->data}; + bool result = registry_->Wait(should_stop_array, 2); + + // This object has been destroyed. + if (destroyed->data) + return false; + + DecrementRegisterCount(); + return result; +} + +void SyncHandleWatcher::IncrementRegisterCount() { + register_request_count_++; + if (!registered_) { + registered_ = + registry_->RegisterHandle(handle_, handle_signals_, callback_); + } +} + +void SyncHandleWatcher::DecrementRegisterCount() { + DCHECK_GT(register_request_count_, 0u); + + register_request_count_--; + if (register_request_count_ == 0 && registered_) { + registry_->UnregisterHandle(handle_); + registered_ = false; + } +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/template_util.h b/mojo/public/cpp/bindings/lib/template_util.h new file mode 100644 index 0000000000..5151123ac0 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/template_util.h @@ -0,0 +1,120 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_TEMPLATE_UTIL_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_TEMPLATE_UTIL_H_ + +#include <type_traits> + +namespace mojo { +namespace internal { + +template <class T, T v> +struct IntegralConstant { + static const T value = v; +}; + +template <class T, T v> +const T IntegralConstant<T, v>::value; + +typedef IntegralConstant<bool, true> TrueType; +typedef IntegralConstant<bool, false> FalseType; + +template <class T> +struct IsConst : FalseType {}; +template <class T> +struct IsConst<const T> : TrueType {}; + +template <class T> +struct IsPointer : FalseType {}; +template <class T> +struct IsPointer<T*> : TrueType {}; + +template <bool B, typename T = void> +struct EnableIf {}; + +template <typename T> +struct EnableIf<true, T> { + typedef T type; +}; + +// Types YesType and NoType are guaranteed such that sizeof(YesType) < +// sizeof(NoType). +typedef char YesType; + +struct NoType { + YesType dummy[2]; +}; + +// A helper template to determine if given type is non-const move-only-type, +// i.e. if a value of the given type should be passed via std::move() in a +// destructive way. +template <typename T> +struct IsMoveOnlyType { + static const bool value = std::is_constructible<T, T&&>::value && + !std::is_constructible<T, const T&>::value; +}; + +// This goop is a trick used to implement a template that can be used to +// determine if a given class is the base class of another given class. +template <typename, typename> +struct IsSame { + static bool const value = false; +}; +template <typename A> +struct IsSame<A, A> { + static bool const value = true; +}; + +template <typename T> +struct EnsureTypeIsComplete { + // sizeof() cannot be applied to incomplete types, this line will fail + // compilation if T is forward declaration. + using CheckSize = char (*)[sizeof(T)]; +}; + +template <typename Base, typename Derived> +struct IsBaseOf { + private: + static Derived* CreateDerived(); + static char(&Check(Base*))[1]; + static char(&Check(...))[2]; + + EnsureTypeIsComplete<Base> check_base_; + EnsureTypeIsComplete<Derived> check_derived_; + + public: + static bool const value = sizeof Check(CreateDerived()) == 1 && + !IsSame<Base const, void const>::value; +}; + +template <class T> +struct RemovePointer { + typedef T type; +}; +template <class T> +struct RemovePointer<T*> { + typedef T type; +}; + +template <template <typename...> class Template, typename T> +struct IsSpecializationOf : FalseType {}; + +template <template <typename...> class Template, typename... Args> +struct IsSpecializationOf<Template, Template<Args...>> : TrueType {}; + +template <bool B, typename T, typename F> +struct Conditional { + typedef T type; +}; + +template <typename T, typename F> +struct Conditional<false, T, F> { + typedef F type; +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_TEMPLATE_UTIL_H_ diff --git a/mojo/public/cpp/bindings/lib/union_accessor.h b/mojo/public/cpp/bindings/lib/union_accessor.h new file mode 100644 index 0000000000..821aede595 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/union_accessor.h @@ -0,0 +1,33 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_UNION_ACCESSOR_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_UNION_ACCESSOR_H_ + +namespace mojo { +namespace internal { + +// When serializing and deserializing Unions, it is necessary to access +// the private fields and methods of the Union. This allows us to do that +// without leaking those same fields and methods in the Union interface. +// All Union wrappers are friends of this class allowing such access. +template <typename U> +class UnionAccessor { + public: + explicit UnionAccessor(U* u) : u_(u) {} + + typename U::Union_* data() { return &(u_->data_); } + + typename U::Tag* tag() { return &(u_->tag_); } + + void SwitchActive(typename U::Tag new_tag) { u_->SwitchActive(new_tag); } + + private: + U* u_; +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_UNION_ACCESSOR_H_ diff --git a/mojo/public/cpp/bindings/lib/validate_params.h b/mojo/public/cpp/bindings/lib/validate_params.h new file mode 100644 index 0000000000..c0ee8e02a7 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/validate_params.h @@ -0,0 +1,88 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATE_PARAMS_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATE_PARAMS_H_ + +#include <stdint.h> + +#include "base/macros.h" + +namespace mojo { +namespace internal { + +class ValidationContext; + +using ValidateEnumFunc = bool (*)(int32_t, ValidationContext*); + +class ContainerValidateParams { + public: + // Validates a map. A map is validated as a pair of arrays, one for the keys + // and one for the values. Both arguments must be non-null. + // + // ContainerValidateParams takes ownership of |in_key_validate params| and + // |in_element_validate params|. + ContainerValidateParams(ContainerValidateParams* in_key_validate_params, + ContainerValidateParams* in_element_validate_params) + : key_validate_params(in_key_validate_params), + element_validate_params(in_element_validate_params) { + DCHECK(in_key_validate_params) + << "Map validate params require key validate params"; + DCHECK(in_element_validate_params) + << "Map validate params require element validate params"; + } + + // Validates an array. + // + // ContainerValidateParams takes ownership of |in_element_validate params|. + ContainerValidateParams(uint32_t in_expected_num_elements, + bool in_element_is_nullable, + ContainerValidateParams* in_element_validate_params) + : expected_num_elements(in_expected_num_elements), + element_is_nullable(in_element_is_nullable), + element_validate_params(in_element_validate_params) {} + + // Validates an array of enums. + ContainerValidateParams(uint32_t in_expected_num_elements, + ValidateEnumFunc in_validate_enum_func) + : expected_num_elements(in_expected_num_elements), + validate_enum_func(in_validate_enum_func) {} + + ~ContainerValidateParams() { + if (element_validate_params) + delete element_validate_params; + if (key_validate_params) + delete key_validate_params; + } + + // If |expected_num_elements| is not 0, the array is expected to have exactly + // that number of elements. + uint32_t expected_num_elements = 0; + + // Whether the elements are nullable. + bool element_is_nullable = false; + + // Validation information for the map key array. May contain other + // ArrayValidateParams e.g. if the keys are strings. + ContainerValidateParams* key_validate_params = nullptr; + + // For arrays: validation information for elements. It is either a pointer to + // another instance of ArrayValidateParams (if elements are arrays or maps), + // or nullptr. + // + // For maps: validation information for the whole value array. May contain + // other ArrayValidateParams e.g. if the values are arrays or maps. + ContainerValidateParams* element_validate_params = nullptr; + + // Validation function for enum elements. + ValidateEnumFunc validate_enum_func = nullptr; + + private: + DISALLOW_COPY_AND_ASSIGN(ContainerValidateParams); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATE_PARAMS_H_ diff --git a/mojo/public/cpp/bindings/lib/validation_context.cc b/mojo/public/cpp/bindings/lib/validation_context.cc new file mode 100644 index 0000000000..ad0a3646eb --- /dev/null +++ b/mojo/public/cpp/bindings/lib/validation_context.cc @@ -0,0 +1,50 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/validation_context.h" + +#include "base/logging.h" + +namespace mojo { +namespace internal { + +ValidationContext::ValidationContext(const void* data, + size_t data_num_bytes, + size_t num_handles, + size_t num_associated_endpoint_handles, + Message* message, + const base::StringPiece& description, + int stack_depth) + : message_(message), + description_(description), + data_begin_(reinterpret_cast<uintptr_t>(data)), + data_end_(data_begin_ + data_num_bytes), + handle_begin_(0), + handle_end_(static_cast<uint32_t>(num_handles)), + associated_endpoint_handle_begin_(0), + associated_endpoint_handle_end_( + static_cast<uint32_t>(num_associated_endpoint_handles)), + stack_depth_(stack_depth) { + // Check whether the calculation of |data_end_| or static_cast from size_t to + // uint32_t causes overflow. + // They shouldn't happen but they do, set the corresponding range to empty. + if (data_end_ < data_begin_) { + NOTREACHED(); + data_end_ = data_begin_; + } + if (handle_end_ < num_handles) { + NOTREACHED(); + handle_end_ = 0; + } + if (associated_endpoint_handle_end_ < num_associated_endpoint_handles) { + NOTREACHED(); + associated_endpoint_handle_end_ = 0; + } +} + +ValidationContext::~ValidationContext() { +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/validation_context.h b/mojo/public/cpp/bindings/lib/validation_context.h new file mode 100644 index 0000000000..ed6c6542e7 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/validation_context.h @@ -0,0 +1,169 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_CONTEXT_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_CONTEXT_H_ + +#include <stddef.h> +#include <stdint.h> + +#include "base/compiler_specific.h" +#include "base/macros.h" +#include "base/strings/string_piece.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" + +static const int kMaxRecursionDepth = 100; + +namespace mojo { + +class Message; + +namespace internal { + +// ValidationContext is used when validating object sizes, pointers and handle +// indices in the payload of incoming messages. +class MOJO_CPP_BINDINGS_EXPORT ValidationContext { + public: + // [data, data + data_num_bytes) specifies the initial valid memory range. + // [0, num_handles) specifies the initial valid range of handle indices. + // [0, num_associated_endpoint_handles) specifies the initial valid range of + // associated endpoint handle indices. + // + // If provided, |message| and |description| provide additional information + // to use when reporting validation errors. In addition if |message| is + // provided, the MojoNotifyBadMessage API will be used to notify the system of + // such errors. + ValidationContext(const void* data, + size_t data_num_bytes, + size_t num_handles, + size_t num_associated_endpoint_handles, + Message* message = nullptr, + const base::StringPiece& description = "", + int stack_depth = 0); + + ~ValidationContext(); + + // Claims the specified memory range. + // The method succeeds if the range is valid to claim. (Please see + // the comments for IsValidRange().) + // On success, the valid memory range is shrinked to begin right after the end + // of the claimed range. + bool ClaimMemory(const void* position, uint32_t num_bytes) { + uintptr_t begin = reinterpret_cast<uintptr_t>(position); + uintptr_t end = begin + num_bytes; + + if (!InternalIsValidRange(begin, end)) + return false; + + data_begin_ = end; + return true; + } + + // Claims the specified encoded handle (which is basically a handle index). + // The method succeeds if: + // - |encoded_handle|'s value is |kEncodedInvalidHandleValue|. + // - the handle is contained inside the valid range of handle indices. In this + // case, the valid range is shinked to begin right after the claimed handle. + bool ClaimHandle(const Handle_Data& encoded_handle) { + uint32_t index = encoded_handle.value; + if (index == kEncodedInvalidHandleValue) + return true; + + if (index < handle_begin_ || index >= handle_end_) + return false; + + // |index| + 1 shouldn't overflow, because |index| is not the max value of + // uint32_t (it is less than |handle_end_|). + handle_begin_ = index + 1; + return true; + } + + // Claims the specified encoded associated endpoint handle. + // The method succeeds if: + // - |encoded_handle|'s value is |kEncodedInvalidHandleValue|. + // - the handle is contained inside the valid range of associated endpoint + // handle indices. In this case, the valid range is shinked to begin right + // after the claimed handle. + bool ClaimAssociatedEndpointHandle( + const AssociatedEndpointHandle_Data& encoded_handle) { + uint32_t index = encoded_handle.value; + if (index == kEncodedInvalidHandleValue) + return true; + + if (index < associated_endpoint_handle_begin_ || + index >= associated_endpoint_handle_end_) + return false; + + // |index| + 1 shouldn't overflow, because |index| is not the max value of + // uint32_t (it is less than |associated_endpoint_handle_end_|). + associated_endpoint_handle_begin_ = index + 1; + return true; + } + + // Returns true if the specified range is not empty, and the range is + // contained inside the valid memory range. + bool IsValidRange(const void* position, uint32_t num_bytes) const { + uintptr_t begin = reinterpret_cast<uintptr_t>(position); + uintptr_t end = begin + num_bytes; + + return InternalIsValidRange(begin, end); + } + + // This object should be created on the stack once every time we recurse down + // into a subfield during validation to make sure we don't recurse too deep + // and blow the stack. + class ScopedDepthTracker { + public: + // |ctx| must outlive this object. + explicit ScopedDepthTracker(ValidationContext* ctx) : ctx_(ctx) { + ++ctx_->stack_depth_; + } + + ~ScopedDepthTracker() { --ctx_->stack_depth_; } + + private: + ValidationContext* ctx_; + + DISALLOW_COPY_AND_ASSIGN(ScopedDepthTracker); + }; + + // Returns true if the recursion depth limit has been reached. + bool ExceedsMaxDepth() WARN_UNUSED_RESULT { + return stack_depth_ > kMaxRecursionDepth; + } + + Message* message() const { return message_; } + const base::StringPiece& description() const { return description_; } + + private: + bool InternalIsValidRange(uintptr_t begin, uintptr_t end) const { + return end > begin && begin >= data_begin_ && end <= data_end_; + } + + Message* const message_; + const base::StringPiece description_; + + // [data_begin_, data_end_) is the valid memory range. + uintptr_t data_begin_; + uintptr_t data_end_; + + // [handle_begin_, handle_end_) is the valid handle index range. + uint32_t handle_begin_; + uint32_t handle_end_; + + // [associated_endpoint_handle_begin_, associated_endpoint_handle_end_) is the + // valid associated endpoint handle index range. + uint32_t associated_endpoint_handle_begin_; + uint32_t associated_endpoint_handle_end_; + + int stack_depth_; + + DISALLOW_COPY_AND_ASSIGN(ValidationContext); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_CONTEXT_H_ diff --git a/mojo/public/cpp/bindings/lib/validation_errors.cc b/mojo/public/cpp/bindings/lib/validation_errors.cc new file mode 100644 index 0000000000..904f5e4c72 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/validation_errors.cc @@ -0,0 +1,150 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/validation_errors.h" + +#include "base/strings/stringprintf.h" +#include "mojo/public/cpp/bindings/message.h" + +namespace mojo { +namespace internal { +namespace { + +ValidationErrorObserverForTesting* g_validation_error_observer = nullptr; +SerializationWarningObserverForTesting* g_serialization_warning_observer = + nullptr; +bool g_suppress_logging = false; + +} // namespace + +const char* ValidationErrorToString(ValidationError error) { + switch (error) { + case VALIDATION_ERROR_NONE: + return "VALIDATION_ERROR_NONE"; + case VALIDATION_ERROR_MISALIGNED_OBJECT: + return "VALIDATION_ERROR_MISALIGNED_OBJECT"; + case VALIDATION_ERROR_ILLEGAL_MEMORY_RANGE: + return "VALIDATION_ERROR_ILLEGAL_MEMORY_RANGE"; + case VALIDATION_ERROR_UNEXPECTED_STRUCT_HEADER: + return "VALIDATION_ERROR_UNEXPECTED_STRUCT_HEADER"; + case VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER: + return "VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER"; + case VALIDATION_ERROR_ILLEGAL_HANDLE: + return "VALIDATION_ERROR_ILLEGAL_HANDLE"; + case VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE: + return "VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE"; + case VALIDATION_ERROR_ILLEGAL_POINTER: + return "VALIDATION_ERROR_ILLEGAL_POINTER"; + case VALIDATION_ERROR_UNEXPECTED_NULL_POINTER: + return "VALIDATION_ERROR_UNEXPECTED_NULL_POINTER"; + case VALIDATION_ERROR_ILLEGAL_INTERFACE_ID: + return "VALIDATION_ERROR_ILLEGAL_INTERFACE_ID"; + case VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID: + return "VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID"; + case VALIDATION_ERROR_MESSAGE_HEADER_INVALID_FLAGS: + return "VALIDATION_ERROR_MESSAGE_HEADER_INVALID_FLAGS"; + case VALIDATION_ERROR_MESSAGE_HEADER_MISSING_REQUEST_ID: + return "VALIDATION_ERROR_MESSAGE_HEADER_MISSING_REQUEST_ID"; + case VALIDATION_ERROR_MESSAGE_HEADER_UNKNOWN_METHOD: + return "VALIDATION_ERROR_MESSAGE_HEADER_UNKNOWN_METHOD"; + case VALIDATION_ERROR_DIFFERENT_SIZED_ARRAYS_IN_MAP: + return "VALIDATION_ERROR_DIFFERENT_SIZED_ARRAYS_IN_MAP"; + case VALIDATION_ERROR_UNKNOWN_UNION_TAG: + return "VALIDATION_ERROR_UNKNOWN_UNION_TAG"; + case VALIDATION_ERROR_UNKNOWN_ENUM_VALUE: + return "VALIDATION_ERROR_UNKNOWN_ENUM_VALUE"; + case VALIDATION_ERROR_DESERIALIZATION_FAILED: + return "VALIDATION_ERROR_DESERIALIZATION_FAILED"; + case VALIDATION_ERROR_MAX_RECURSION_DEPTH: + return "VALIDATION_ERROR_MAX_RECURSION_DEPTH"; + } + + return "Unknown error"; +} + +void ReportValidationError(ValidationContext* context, + ValidationError error, + const char* description) { + if (g_validation_error_observer) { + g_validation_error_observer->set_last_error(error); + return; + } + + if (description) { + if (!g_suppress_logging) { + LOG(ERROR) << "Invalid message: " << ValidationErrorToString(error) + << " (" << description << ")"; + } + if (context->message()) { + context->message()->NotifyBadMessage( + base::StringPrintf("Validation failed for %s [%s (%s)]", + context->description().data(), + ValidationErrorToString(error), description)); + } + } else { + if (!g_suppress_logging) + LOG(ERROR) << "Invalid message: " << ValidationErrorToString(error); + if (context->message()) { + context->message()->NotifyBadMessage( + base::StringPrintf("Validation failed for %s [%s]", + context->description().data(), + ValidationErrorToString(error))); + } + } +} + +void ReportValidationErrorForMessage( + mojo::Message* message, + ValidationError error, + const char* description) { + ValidationContext validation_context(nullptr, 0, 0, 0, message, description); + ReportValidationError(&validation_context, error); +} + +ScopedSuppressValidationErrorLoggingForTests + ::ScopedSuppressValidationErrorLoggingForTests() + : was_suppressed_(g_suppress_logging) { + g_suppress_logging = true; +} + +ScopedSuppressValidationErrorLoggingForTests + ::~ScopedSuppressValidationErrorLoggingForTests() { + g_suppress_logging = was_suppressed_; +} + +ValidationErrorObserverForTesting::ValidationErrorObserverForTesting( + const base::Closure& callback) + : last_error_(VALIDATION_ERROR_NONE), callback_(callback) { + DCHECK(!g_validation_error_observer); + g_validation_error_observer = this; +} + +ValidationErrorObserverForTesting::~ValidationErrorObserverForTesting() { + DCHECK(g_validation_error_observer == this); + g_validation_error_observer = nullptr; +} + +bool ReportSerializationWarning(ValidationError error) { + if (g_serialization_warning_observer) { + g_serialization_warning_observer->set_last_warning(error); + return true; + } + + return false; +} + +SerializationWarningObserverForTesting::SerializationWarningObserverForTesting() + : last_warning_(VALIDATION_ERROR_NONE) { + DCHECK(!g_serialization_warning_observer); + g_serialization_warning_observer = this; +} + +SerializationWarningObserverForTesting:: + ~SerializationWarningObserverForTesting() { + DCHECK(g_serialization_warning_observer == this); + g_serialization_warning_observer = nullptr; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/validation_errors.h b/mojo/public/cpp/bindings/lib/validation_errors.h new file mode 100644 index 0000000000..122418d9e3 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/validation_errors.h @@ -0,0 +1,167 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_ERRORS_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_ERRORS_H_ + +#include "base/callback.h" +#include "base/logging.h" +#include "base/macros.h" +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/validation_context.h" + +namespace mojo { + +class Message; + +namespace internal { + +enum ValidationError { + // There is no validation error. + VALIDATION_ERROR_NONE, + // An object (struct or array) is not 8-byte aligned. + VALIDATION_ERROR_MISALIGNED_OBJECT, + // An object is not contained inside the message data, or it overlaps other + // objects. + VALIDATION_ERROR_ILLEGAL_MEMORY_RANGE, + // A struct header doesn't make sense, for example: + // - |num_bytes| is smaller than the size of the struct header. + // - |num_bytes| and |version| don't match. + // TODO(yzshen): Consider splitting it into two different error codes. Because + // the former indicates someone is misbehaving badly whereas the latter could + // be due to an inappropriately-modified .mojom file. + VALIDATION_ERROR_UNEXPECTED_STRUCT_HEADER, + // An array header doesn't make sense, for example: + // - |num_bytes| is smaller than the size of the header plus the size required + // to store |num_elements| elements. + // - For fixed-size arrays, |num_elements| is different than the specified + // size. + VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER, + // An encoded handle is illegal. + VALIDATION_ERROR_ILLEGAL_HANDLE, + // A non-nullable handle field is set to invalid handle. + VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, + // An encoded pointer is illegal. + VALIDATION_ERROR_ILLEGAL_POINTER, + // A non-nullable pointer field is set to null. + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + // An interface ID is illegal. + VALIDATION_ERROR_ILLEGAL_INTERFACE_ID, + // A non-nullable interface ID field is set to invalid. + VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, + // |flags| in the message header is invalid. The flags are either + // inconsistent with one another, inconsistent with other parts of the + // message, or unexpected for the message receiver. For example the + // receiver is expecting a request message but the flags indicate that + // the message is a response message. + VALIDATION_ERROR_MESSAGE_HEADER_INVALID_FLAGS, + // |flags| in the message header indicates that a request ID is required but + // there isn't one. + VALIDATION_ERROR_MESSAGE_HEADER_MISSING_REQUEST_ID, + // The |name| field in a message header contains an unexpected value. + VALIDATION_ERROR_MESSAGE_HEADER_UNKNOWN_METHOD, + // Two parallel arrays which are supposed to represent a map have different + // lengths. + VALIDATION_ERROR_DIFFERENT_SIZED_ARRAYS_IN_MAP, + // Attempted to deserialize a tagged union with an unknown tag. + VALIDATION_ERROR_UNKNOWN_UNION_TAG, + // A value of a non-extensible enum type is unknown. + VALIDATION_ERROR_UNKNOWN_ENUM_VALUE, + // Message deserialization failure, for example due to rejection by custom + // validation logic. + VALIDATION_ERROR_DESERIALIZATION_FAILED, + // The message contains a too deeply nested value, for example a recursively + // defined field which runtime value is too large. + VALIDATION_ERROR_MAX_RECURSION_DEPTH, +}; + +MOJO_CPP_BINDINGS_EXPORT const char* ValidationErrorToString( + ValidationError error); + +MOJO_CPP_BINDINGS_EXPORT void ReportValidationError( + ValidationContext* context, + ValidationError error, + const char* description = nullptr); + +MOJO_CPP_BINDINGS_EXPORT void ReportValidationErrorForMessage( + mojo::Message* message, + ValidationError error, + const char* description = nullptr); + +// This class may be used by tests to suppress validation error logging. This is +// not thread-safe and must only be instantiated on the main thread with no +// other threads using Mojo bindings at the time of construction or destruction. +class MOJO_CPP_BINDINGS_EXPORT ScopedSuppressValidationErrorLoggingForTests { + public: + ScopedSuppressValidationErrorLoggingForTests(); + ~ScopedSuppressValidationErrorLoggingForTests(); + + private: + const bool was_suppressed_; + + DISALLOW_COPY_AND_ASSIGN(ScopedSuppressValidationErrorLoggingForTests); +}; + +// Only used by validation tests and when there is only one thread doing message +// validation. +class MOJO_CPP_BINDINGS_EXPORT ValidationErrorObserverForTesting { + public: + explicit ValidationErrorObserverForTesting(const base::Closure& callback); + ~ValidationErrorObserverForTesting(); + + ValidationError last_error() const { return last_error_; } + void set_last_error(ValidationError error) { + last_error_ = error; + callback_.Run(); + } + + private: + ValidationError last_error_; + base::Closure callback_; + + DISALLOW_COPY_AND_ASSIGN(ValidationErrorObserverForTesting); +}; + +// Used only by MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING. Don't use it directly. +// +// The function returns true if the error is recorded (by a +// SerializationWarningObserverForTesting object), false otherwise. +MOJO_CPP_BINDINGS_EXPORT bool ReportSerializationWarning(ValidationError error); + +// Only used by serialization tests and when there is only one thread doing +// message serialization. +class MOJO_CPP_BINDINGS_EXPORT SerializationWarningObserverForTesting { + public: + SerializationWarningObserverForTesting(); + ~SerializationWarningObserverForTesting(); + + ValidationError last_warning() const { return last_warning_; } + void set_last_warning(ValidationError error) { last_warning_ = error; } + + private: + ValidationError last_warning_; + + DISALLOW_COPY_AND_ASSIGN(SerializationWarningObserverForTesting); +}; + +} // namespace internal +} // namespace mojo + +// In debug build, logs a serialization warning if |condition| evaluates to +// true: +// - if there is a SerializationWarningObserverForTesting object alive, +// records |error| in it; +// - otherwise, logs a fatal-level message. +// |error| is the validation error that will be triggered by the receiver +// of the serialzation result. +// +// In non-debug build, does nothing (not even compiling |condition|). +#define MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING(condition, error, \ + description) \ + DLOG_IF(FATAL, (condition) && !ReportSerializationWarning(error)) \ + << "The outgoing message will trigger " \ + << ValidationErrorToString(error) << " at the receiving side (" \ + << description << ")."; + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_ERRORS_H_ diff --git a/mojo/public/cpp/bindings/lib/validation_util.cc b/mojo/public/cpp/bindings/lib/validation_util.cc new file mode 100644 index 0000000000..7614df5cbc --- /dev/null +++ b/mojo/public/cpp/bindings/lib/validation_util.cc @@ -0,0 +1,210 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/validation_util.h" + +#include <stdint.h> + +#include <limits> + +#include "mojo/public/cpp/bindings/lib/message_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_util.h" +#include "mojo/public/cpp/bindings/lib/validation_errors.h" +#include "mojo/public/interfaces/bindings/interface_control_messages.mojom.h" + +namespace mojo { +namespace internal { + +bool ValidateStructHeaderAndClaimMemory(const void* data, + ValidationContext* validation_context) { + if (!IsAligned(data)) { + ReportValidationError(validation_context, + VALIDATION_ERROR_MISALIGNED_OBJECT); + return false; + } + if (!validation_context->IsValidRange(data, sizeof(StructHeader))) { + ReportValidationError(validation_context, + VALIDATION_ERROR_ILLEGAL_MEMORY_RANGE); + return false; + } + + const StructHeader* header = static_cast<const StructHeader*>(data); + + if (header->num_bytes < sizeof(StructHeader)) { + ReportValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_STRUCT_HEADER); + return false; + } + + if (!validation_context->ClaimMemory(data, header->num_bytes)) { + ReportValidationError(validation_context, + VALIDATION_ERROR_ILLEGAL_MEMORY_RANGE); + return false; + } + + return true; +} + +bool ValidateNonInlinedUnionHeaderAndClaimMemory( + const void* data, + ValidationContext* validation_context) { + if (!IsAligned(data)) { + ReportValidationError(validation_context, + VALIDATION_ERROR_MISALIGNED_OBJECT); + return false; + } + + if (!validation_context->ClaimMemory(data, kUnionDataSize) || + *static_cast<const uint32_t*>(data) != kUnionDataSize) { + ReportValidationError(validation_context, + VALIDATION_ERROR_ILLEGAL_MEMORY_RANGE); + return false; + } + + return true; +} + +bool ValidateMessageIsRequestWithoutResponse( + const Message* message, + ValidationContext* validation_context) { + if (message->has_flag(Message::kFlagIsResponse) || + message->has_flag(Message::kFlagExpectsResponse)) { + ReportValidationError(validation_context, + VALIDATION_ERROR_MESSAGE_HEADER_INVALID_FLAGS); + return false; + } + return true; +} + +bool ValidateMessageIsRequestExpectingResponse( + const Message* message, + ValidationContext* validation_context) { + if (message->has_flag(Message::kFlagIsResponse) || + !message->has_flag(Message::kFlagExpectsResponse)) { + ReportValidationError(validation_context, + VALIDATION_ERROR_MESSAGE_HEADER_INVALID_FLAGS); + return false; + } + return true; +} + +bool ValidateMessageIsResponse(const Message* message, + ValidationContext* validation_context) { + if (message->has_flag(Message::kFlagExpectsResponse) || + !message->has_flag(Message::kFlagIsResponse)) { + ReportValidationError(validation_context, + VALIDATION_ERROR_MESSAGE_HEADER_INVALID_FLAGS); + return false; + } + return true; +} + +bool IsHandleOrInterfaceValid(const AssociatedInterface_Data& input) { + return input.handle.is_valid(); +} + +bool IsHandleOrInterfaceValid(const AssociatedEndpointHandle_Data& input) { + return input.is_valid(); +} + +bool IsHandleOrInterfaceValid(const Interface_Data& input) { + return input.handle.is_valid(); +} + +bool IsHandleOrInterfaceValid(const Handle_Data& input) { + return input.is_valid(); +} + +bool ValidateHandleOrInterfaceNonNullable( + const AssociatedInterface_Data& input, + const char* error_message, + ValidationContext* validation_context) { + if (IsHandleOrInterfaceValid(input)) + return true; + + ReportValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, + error_message); + return false; +} + +bool ValidateHandleOrInterfaceNonNullable( + const AssociatedEndpointHandle_Data& input, + const char* error_message, + ValidationContext* validation_context) { + if (IsHandleOrInterfaceValid(input)) + return true; + + ReportValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, + error_message); + return false; +} + +bool ValidateHandleOrInterfaceNonNullable( + const Interface_Data& input, + const char* error_message, + ValidationContext* validation_context) { + if (IsHandleOrInterfaceValid(input)) + return true; + + ReportValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, + error_message); + return false; +} + +bool ValidateHandleOrInterfaceNonNullable( + const Handle_Data& input, + const char* error_message, + ValidationContext* validation_context) { + if (IsHandleOrInterfaceValid(input)) + return true; + + ReportValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, + error_message); + return false; +} + +bool ValidateHandleOrInterface(const AssociatedInterface_Data& input, + ValidationContext* validation_context) { + if (validation_context->ClaimAssociatedEndpointHandle(input.handle)) + return true; + + ReportValidationError(validation_context, + VALIDATION_ERROR_ILLEGAL_INTERFACE_ID); + return false; +} + +bool ValidateHandleOrInterface(const AssociatedEndpointHandle_Data& input, + ValidationContext* validation_context) { + if (validation_context->ClaimAssociatedEndpointHandle(input)) + return true; + + ReportValidationError(validation_context, + VALIDATION_ERROR_ILLEGAL_INTERFACE_ID); + return false; +} + +bool ValidateHandleOrInterface(const Interface_Data& input, + ValidationContext* validation_context) { + if (validation_context->ClaimHandle(input.handle)) + return true; + + ReportValidationError(validation_context, VALIDATION_ERROR_ILLEGAL_HANDLE); + return false; +} + +bool ValidateHandleOrInterface(const Handle_Data& input, + ValidationContext* validation_context) { + if (validation_context->ClaimHandle(input)) + return true; + + ReportValidationError(validation_context, VALIDATION_ERROR_ILLEGAL_HANDLE); + return false; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/validation_util.h b/mojo/public/cpp/bindings/lib/validation_util.h new file mode 100644 index 0000000000..ea5a991668 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/validation_util.h @@ -0,0 +1,206 @@ +// Copyright 2015 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_UTIL_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_UTIL_H_ + +#include <stdint.h> + +#include "mojo/public/cpp/bindings/bindings_export.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_util.h" +#include "mojo/public/cpp/bindings/lib/validate_params.h" +#include "mojo/public/cpp/bindings/lib/validation_context.h" +#include "mojo/public/cpp/bindings/lib/validation_errors.h" +#include "mojo/public/cpp/bindings/message.h" + +namespace mojo { +namespace internal { + +// Checks whether decoding the pointer will overflow and produce a pointer +// smaller than |offset|. +inline bool ValidateEncodedPointer(const uint64_t* offset) { + // - Make sure |*offset| is no more than 32-bits. + // - Cast |offset| to uintptr_t so overflow behavior is well defined across + // 32-bit and 64-bit systems. + return *offset <= std::numeric_limits<uint32_t>::max() && + (reinterpret_cast<uintptr_t>(offset) + + static_cast<uint32_t>(*offset) >= + reinterpret_cast<uintptr_t>(offset)); +} + +template <typename T> +bool ValidatePointer(const Pointer<T>& input, + ValidationContext* validation_context) { + bool result = ValidateEncodedPointer(&input.offset); + if (!result) + ReportValidationError(validation_context, VALIDATION_ERROR_ILLEGAL_POINTER); + + return result; +} + +// Validates that |data| contains a valid struct header, in terms of alignment +// and size (i.e., the |num_bytes| field of the header is sufficient for storing +// the header itself). Besides, it checks that the memory range +// [data, data + num_bytes) is not marked as occupied by other objects in +// |validation_context|. On success, the memory range is marked as occupied. +// Note: Does not verify |version| or that |num_bytes| is correct for the +// claimed version. +MOJO_CPP_BINDINGS_EXPORT bool ValidateStructHeaderAndClaimMemory( + const void* data, + ValidationContext* validation_context); + +// Validates that |data| contains a valid union header, in terms of alignment +// and size. It checks that the memory range [data, data + kUnionDataSize) is +// not marked as occupied by other objects in |validation_context|. On success, +// the memory range is marked as occupied. +MOJO_CPP_BINDINGS_EXPORT bool ValidateNonInlinedUnionHeaderAndClaimMemory( + const void* data, + ValidationContext* validation_context); + +// Validates that the message is a request which doesn't expect a response. +MOJO_CPP_BINDINGS_EXPORT bool ValidateMessageIsRequestWithoutResponse( + const Message* message, + ValidationContext* validation_context); + +// Validates that the message is a request expecting a response. +MOJO_CPP_BINDINGS_EXPORT bool ValidateMessageIsRequestExpectingResponse( + const Message* message, + ValidationContext* validation_context); + +// Validates that the message is a response. +MOJO_CPP_BINDINGS_EXPORT bool ValidateMessageIsResponse( + const Message* message, + ValidationContext* validation_context); + +// Validates that the message payload is a valid struct of type ParamsType. +template <typename ParamsType> +bool ValidateMessagePayload(const Message* message, + ValidationContext* validation_context) { + return ParamsType::Validate(message->payload(), validation_context); +} + +// The following Validate.*NonNullable() functions validate that the given +// |input| is not null/invalid. +template <typename T> +bool ValidatePointerNonNullable(const T& input, + const char* error_message, + ValidationContext* validation_context) { + if (input.offset) + return true; + + ReportValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + error_message); + return false; +} + +template <typename T> +bool ValidateInlinedUnionNonNullable(const T& input, + const char* error_message, + ValidationContext* validation_context) { + if (!input.is_null()) + return true; + + ReportValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + error_message); + return false; +} + +MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( + const AssociatedInterface_Data& input); +MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( + const AssociatedEndpointHandle_Data& input); +MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( + const Interface_Data& input); +MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( + const Handle_Data& input); + +MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( + const AssociatedInterface_Data& input, + const char* error_message, + ValidationContext* validation_context); +MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( + const AssociatedEndpointHandle_Data& input, + const char* error_message, + ValidationContext* validation_context); +MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( + const Interface_Data& input, + const char* error_message, + ValidationContext* validation_context); +MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( + const Handle_Data& input, + const char* error_message, + ValidationContext* validation_context); + +template <typename T> +bool ValidateContainer(const Pointer<T>& input, + ValidationContext* validation_context, + const ContainerValidateParams* validate_params) { + ValidationContext::ScopedDepthTracker depth_tracker(validation_context); + if (validation_context->ExceedsMaxDepth()) { + ReportValidationError(validation_context, + VALIDATION_ERROR_MAX_RECURSION_DEPTH); + return false; + } + return ValidatePointer(input, validation_context) && + T::Validate(input.Get(), validation_context, validate_params); +} + +template <typename T> +bool ValidateStruct(const Pointer<T>& input, + ValidationContext* validation_context) { + ValidationContext::ScopedDepthTracker depth_tracker(validation_context); + if (validation_context->ExceedsMaxDepth()) { + ReportValidationError(validation_context, + VALIDATION_ERROR_MAX_RECURSION_DEPTH); + return false; + } + return ValidatePointer(input, validation_context) && + T::Validate(input.Get(), validation_context); +} + +template <typename T> +bool ValidateInlinedUnion(const T& input, + ValidationContext* validation_context) { + ValidationContext::ScopedDepthTracker depth_tracker(validation_context); + if (validation_context->ExceedsMaxDepth()) { + ReportValidationError(validation_context, + VALIDATION_ERROR_MAX_RECURSION_DEPTH); + return false; + } + return T::Validate(&input, validation_context, true); +} + +template <typename T> +bool ValidateNonInlinedUnion(const Pointer<T>& input, + ValidationContext* validation_context) { + ValidationContext::ScopedDepthTracker depth_tracker(validation_context); + if (validation_context->ExceedsMaxDepth()) { + ReportValidationError(validation_context, + VALIDATION_ERROR_MAX_RECURSION_DEPTH); + return false; + } + return ValidatePointer(input, validation_context) && + T::Validate(input.Get(), validation_context, false); +} + +MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( + const AssociatedInterface_Data& input, + ValidationContext* validation_context); +MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( + const AssociatedEndpointHandle_Data& input, + ValidationContext* validation_context); +MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( + const Interface_Data& input, + ValidationContext* validation_context); +MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( + const Handle_Data& input, + ValidationContext* validation_context); + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_UTIL_H_ diff --git a/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h b/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h new file mode 100644 index 0000000000..cb24bc46ee --- /dev/null +++ b/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h @@ -0,0 +1,78 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_CLONE_EQUALS_UTIL_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_CLONE_EQUALS_UTIL_H_ + +#include <type_traits> + +#include "mojo/public/cpp/bindings/clone_traits.h" +#include "mojo/public/cpp/bindings/lib/equals_traits.h" +#include "third_party/WebKit/Source/wtf/HashMap.h" +#include "third_party/WebKit/Source/wtf/Optional.h" +#include "third_party/WebKit/Source/wtf/Vector.h" +#include "third_party/WebKit/Source/wtf/text/WTFString.h" + +namespace mojo { + +template <typename T> +struct CloneTraits<WTF::Vector<T>, false> { + static WTF::Vector<T> Clone(const WTF::Vector<T>& input) { + WTF::Vector<T> result; + result.reserveCapacity(input.size()); + for (const auto& element : input) + result.push_back(mojo::Clone(element)); + + return result; + } +}; + +template <typename K, typename V> +struct CloneTraits<WTF::HashMap<K, V>, false> { + static WTF::HashMap<K, V> Clone(const WTF::HashMap<K, V>& input) { + WTF::HashMap<K, V> result; + auto input_end = input.end(); + for (auto it = input.begin(); it != input_end; ++it) + result.add(mojo::Clone(it->key), mojo::Clone(it->value)); + return result; + } +}; + +namespace internal { + +template <typename T> +struct EqualsTraits<WTF::Vector<T>, false> { + static bool Equals(const WTF::Vector<T>& a, const WTF::Vector<T>& b) { + if (a.size() != b.size()) + return false; + for (size_t i = 0; i < a.size(); ++i) { + if (!internal::Equals(a[i], b[i])) + return false; + } + return true; + } +}; + +template <typename K, typename V> +struct EqualsTraits<WTF::HashMap<K, V>, false> { + static bool Equals(const WTF::HashMap<K, V>& a, const WTF::HashMap<K, V>& b) { + if (a.size() != b.size()) + return false; + + auto a_end = a.end(); + auto b_end = b.end(); + + for (auto iter = a.begin(); iter != a_end; ++iter) { + auto b_iter = b.find(iter->key); + if (b_iter == b_end || !internal::Equals(iter->value, b_iter->value)) + return false; + } + return true; + } +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_CLONE_EQUALS_UTIL_H_ diff --git a/mojo/public/cpp/bindings/lib/wtf_hash_util.h b/mojo/public/cpp/bindings/lib/wtf_hash_util.h new file mode 100644 index 0000000000..cc590da67a --- /dev/null +++ b/mojo/public/cpp/bindings/lib/wtf_hash_util.h @@ -0,0 +1,132 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_HASH_UTIL_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_HASH_UTIL_H_ + +#include <type_traits> + +#include "mojo/public/cpp/bindings/lib/hash_util.h" +#include "mojo/public/cpp/bindings/struct_ptr.h" +#include "third_party/WebKit/Source/wtf/HashFunctions.h" +#include "third_party/WebKit/Source/wtf/text/StringHash.h" +#include "third_party/WebKit/Source/wtf/text/WTFString.h" + +namespace mojo { +namespace internal { + +template <typename T> +size_t WTFHashCombine(size_t seed, const T& value) { + // Based on proposal in: + // http://www.open-std.org/JTC1/SC22/WG21/docs/papers/2005/n1756.pdf + // + // TODO(tibell): We'd like to use WTF::DefaultHash instead of std::hash, but + // there is no general template specialization of DefaultHash for enums + // and there can't be an instance for bool. + return seed ^ (std::hash<T>()(value) + (seed << 6) + (seed >> 2)); +} + +template <typename T, bool has_hash_method = HasHashMethod<T>::value> +struct WTFHashTraits; + +template <typename T> +size_t WTFHash(size_t seed, const T& value); + +template <typename T> +struct WTFHashTraits<T, true> { + static size_t Hash(size_t seed, const T& value) { return value.Hash(seed); } +}; + +template <typename T> +struct WTFHashTraits<T, false> { + static size_t Hash(size_t seed, const T& value) { + return WTFHashCombine(seed, value); + } +}; + +template <> +struct WTFHashTraits<WTF::String, false> { + static size_t Hash(size_t seed, const WTF::String& value) { + return HashCombine(seed, WTF::StringHash::hash(value)); + } +}; + +template <typename T> +size_t WTFHash(size_t seed, const T& value) { + return WTFHashTraits<T>::Hash(seed, value); +} + +template <typename T> +struct StructPtrHashFn { + static unsigned hash(const StructPtr<T>& value) { + return value.Hash(kHashSeed); + } + static bool equal(const StructPtr<T>& left, const StructPtr<T>& right) { + return left.Equals(right); + } + static const bool safeToCompareToEmptyOrDeleted = false; +}; + +template <typename T> +struct InlinedStructPtrHashFn { + static unsigned hash(const InlinedStructPtr<T>& value) { + return value.Hash(kHashSeed); + } + static bool equal(const InlinedStructPtr<T>& left, + const InlinedStructPtr<T>& right) { + return left.Equals(right); + } + static const bool safeToCompareToEmptyOrDeleted = false; +}; + +} // namespace internal +} // namespace mojo + +namespace WTF { + +template <typename T> +struct DefaultHash<mojo::StructPtr<T>> { + using Hash = mojo::internal::StructPtrHashFn<T>; +}; + +template <typename T> +struct HashTraits<mojo::StructPtr<T>> + : public GenericHashTraits<mojo::StructPtr<T>> { + static const bool hasIsEmptyValueFunction = true; + static bool isEmptyValue(const mojo::StructPtr<T>& value) { + return value.is_null(); + } + static void constructDeletedValue(mojo::StructPtr<T>& slot, bool) { + mojo::internal::StructPtrWTFHelper<T>::ConstructDeletedValue(slot); + } + static bool isDeletedValue(const mojo::StructPtr<T>& value) { + return mojo::internal::StructPtrWTFHelper<T>::IsHashTableDeletedValue( + value); + } +}; + +template <typename T> +struct DefaultHash<mojo::InlinedStructPtr<T>> { + using Hash = mojo::internal::InlinedStructPtrHashFn<T>; +}; + +template <typename T> +struct HashTraits<mojo::InlinedStructPtr<T>> + : public GenericHashTraits<mojo::InlinedStructPtr<T>> { + static const bool hasIsEmptyValueFunction = true; + static bool isEmptyValue(const mojo::InlinedStructPtr<T>& value) { + return value.is_null(); + } + static void constructDeletedValue(mojo::InlinedStructPtr<T>& slot, bool) { + mojo::internal::InlinedStructPtrWTFHelper<T>::ConstructDeletedValue(slot); + } + static bool isDeletedValue(const mojo::InlinedStructPtr<T>& value) { + return mojo::internal::InlinedStructPtrWTFHelper< + T>::IsHashTableDeletedValue(value); + } +}; + +} // namespace WTF + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_HASH_UTIL_H_ diff --git a/mojo/public/cpp/bindings/lib/wtf_serialization.h b/mojo/public/cpp/bindings/lib/wtf_serialization.h new file mode 100644 index 0000000000..0f112b9143 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/wtf_serialization.h @@ -0,0 +1,12 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_SERIALIZATION_H_ + +#include "mojo/public/cpp/bindings/array_traits_wtf_vector.h" +#include "mojo/public/cpp/bindings/map_traits_wtf_hash_map.h" +#include "mojo/public/cpp/bindings/string_traits_wtf.h" + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_SERIALIZATION_H_ |