diff options
Diffstat (limited to 'mojo/public/cpp/bindings')
181 files changed, 8028 insertions, 4119 deletions
diff --git a/mojo/public/cpp/bindings/BUILD.gn b/mojo/public/cpp/bindings/BUILD.gn index bd87965fc8..27152835ac 100644 --- a/mojo/public/cpp/bindings/BUILD.gn +++ b/mojo/public/cpp/bindings/BUILD.gn @@ -2,112 +2,66 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -interfaces_bindings_gen_dir = "$root_gen_dir/mojo/public/interfaces/bindings" +import("//build/buildflag_header.gni") +import("//build/config/nacl/config.gni") +import("//tools/ipc_fuzzer/ipc_fuzzer.gni") -component("bindings") { +declare_args() { + enable_mojo_tracing = false +} + +buildflag_header("mojo_buildflags") { + header = "mojo_buildflags.h" + + flags = [ "MOJO_TRACE_ENABLED=$enable_mojo_tracing" ] +} + +# Headers and sources which generated bindings can depend upon. No need for +# other targets to depend on this directly: just use the "bindings" target. +component("bindings_base") { sources = [ - # Normally, targets should depend on the source_sets generated by mojom - # targets. However, the generated source_sets use portions of the bindings - # library. In order to avoid linker warnings about locally-defined imports - # in Windows components build, this target depends on the generated C++ - # files directly so that the EXPORT macro defintions match. - "$interfaces_bindings_gen_dir/interface_control_messages.mojom-shared-internal.h", - "$interfaces_bindings_gen_dir/interface_control_messages.mojom-shared.cc", - "$interfaces_bindings_gen_dir/interface_control_messages.mojom-shared.h", - "$interfaces_bindings_gen_dir/interface_control_messages.mojom.cc", - "$interfaces_bindings_gen_dir/interface_control_messages.mojom.h", - "$interfaces_bindings_gen_dir/pipe_control_messages.mojom-shared-internal.h", - "$interfaces_bindings_gen_dir/pipe_control_messages.mojom-shared.cc", - "$interfaces_bindings_gen_dir/pipe_control_messages.mojom-shared.h", - "$interfaces_bindings_gen_dir/pipe_control_messages.mojom.cc", - "$interfaces_bindings_gen_dir/pipe_control_messages.mojom.h", "array_data_view.h", "array_traits.h", - "array_traits_carray.h", + "array_traits_span.h", "array_traits_stl.h", - "associated_binding.h", - "associated_binding_set.h", "associated_group.h", "associated_group_controller.h", - "associated_interface_ptr.h", - "associated_interface_ptr_info.h", - "associated_interface_request.h", - "binding.h", - "binding_set.h", - "bindings_export.h", "clone_traits.h", - "connection_error_callback.h", - "connector.h", "disconnect_reason.h", - "filter_chain.h", + "enum_traits.h", + "equals_traits.h", "interface_data_view.h", - "interface_endpoint_client.h", - "interface_endpoint_controller.h", "interface_id.h", - "interface_ptr.h", - "interface_ptr_info.h", - "interface_ptr_set.h", - "interface_request.h", "lib/array_internal.cc", "lib/array_internal.h", "lib/array_serialization.h", - "lib/associated_binding.cc", "lib/associated_group.cc", "lib/associated_group_controller.cc", - "lib/associated_interface_ptr.cc", - "lib/associated_interface_ptr_state.h", - "lib/binding_state.cc", - "lib/binding_state.h", "lib/bindings_internal.h", + "lib/buffer.cc", "lib/buffer.h", - "lib/connector.cc", - "lib/control_message_handler.cc", - "lib/control_message_handler.h", - "lib/control_message_proxy.cc", - "lib/control_message_proxy.h", - "lib/equals_traits.h", - "lib/filter_chain.cc", "lib/fixed_buffer.cc", "lib/fixed_buffer.h", - "lib/handle_interface_serialization.h", + "lib/handle_serialization.h", "lib/hash_util.h", - "lib/interface_endpoint_client.cc", - "lib/interface_ptr_state.h", "lib/map_data_internal.h", "lib/map_serialization.h", "lib/may_auto_lock.h", "lib/message.cc", - "lib/message_buffer.cc", - "lib/message_buffer.h", - "lib/message_builder.cc", - "lib/message_builder.h", "lib/message_header_validator.cc", + "lib/message_internal.cc", "lib/message_internal.h", - "lib/multiplex_router.cc", - "lib/multiplex_router.h", - "lib/native_enum_data.h", - "lib/native_enum_serialization.h", - "lib/native_struct.cc", - "lib/native_struct_data.cc", - "lib/native_struct_data.h", - "lib/native_struct_serialization.cc", - "lib/native_struct_serialization.h", - "lib/pipe_control_message_handler.cc", - "lib/pipe_control_message_proxy.cc", "lib/scoped_interface_endpoint_handle.cc", "lib/serialization.h", + "lib/serialization.h", "lib/serialization_context.cc", "lib/serialization_context.h", "lib/serialization_forward.h", "lib/serialization_util.h", "lib/string_serialization.h", - "lib/string_traits_string16.cc", - "lib/sync_call_restrictions.cc", - "lib/sync_event_watcher.cc", - "lib/sync_handle_registry.cc", - "lib/sync_handle_watcher.cc", "lib/template_util.h", - "lib/union_accessor.h", + "lib/unserialized_message_context.cc", + "lib/unserialized_message_context.h", "lib/validate_params.h", "lib/validation_context.cc", "lib/validation_context.h", @@ -118,47 +72,121 @@ component("bindings") { "map.h", "map_data_view.h", "map_traits.h", + "map_traits_flat_map.h", "map_traits_stl.h", "message.h", "message_header_validator.h", - "native_enum.h", - "native_struct.h", - "native_struct_data_view.h", - "pipe_control_message_handler.h", - "pipe_control_message_handler_delegate.h", - "pipe_control_message_proxy.h", - "raw_ptr_impl_ref_traits.h", "scoped_interface_endpoint_handle.h", "string_data_view.h", "string_traits.h", "string_traits_stl.h", - "string_traits_string16.h", "string_traits_string_piece.h", + "struct_ptr.h", + "struct_traits.h", + "type_converter.h", + "union_traits.h", + ] + + defines = [ "IS_MOJO_CPP_BINDINGS_BASE_IMPL" ] + + public_deps = [ + ":mojo_buildflags", + "//base", + "//mojo/public/cpp/system", + ] + + if (enable_ipc_fuzzer) { + all_dependent_configs = [ "//tools/ipc_fuzzer:ipc_fuzzer_config" ] + } +} + +component("bindings") { + sources = [ + "associated_binding.h", + "associated_binding_set.h", + "associated_interface_ptr.h", + "associated_interface_ptr_info.h", + "associated_interface_request.h", + "binding.h", + "binding_set.h", + "bindings_export.h", + "callback_helpers.h", + "connection_error_callback.h", + "connector.h", + "filter_chain.h", + "interface_endpoint_client.h", + "interface_endpoint_controller.h", + "interface_ptr.h", + "interface_ptr_info.h", + "interface_ptr_set.h", + "interface_request.h", + "lib/associated_binding.cc", + "lib/associated_interface_ptr.cc", + "lib/associated_interface_ptr_state.cc", + "lib/associated_interface_ptr_state.h", + "lib/binding_state.cc", + "lib/binding_state.h", + "lib/connector.cc", + "lib/control_message_handler.cc", + "lib/control_message_handler.h", + "lib/control_message_proxy.cc", + "lib/control_message_proxy.h", + "lib/filter_chain.cc", + "lib/interface_endpoint_client.cc", + "lib/interface_ptr_state.cc", + "lib/interface_ptr_state.h", + "lib/interface_serialization.h", + "lib/multiplex_router.cc", + "lib/multiplex_router.h", + "lib/native_enum_data.h", + "lib/native_enum_serialization.h", + "lib/native_struct_serialization.cc", + "lib/native_struct_serialization.h", + "lib/pipe_control_message_handler.cc", + "lib/pipe_control_message_proxy.cc", + "lib/sequence_local_sync_event_watcher.cc", + "lib/sync_call_restrictions.cc", + "lib/sync_event_watcher.cc", + "lib/sync_handle_registry.cc", + "lib/sync_handle_watcher.cc", + "lib/task_runner_helper.cc", + "lib/task_runner_helper.h", + "native_enum.h", + "pipe_control_message_handler.h", + "pipe_control_message_handler_delegate.h", + "pipe_control_message_proxy.h", + "raw_ptr_impl_ref_traits.h", + "sequence_local_sync_event_watcher.h", "strong_associated_binding.h", "strong_binding.h", "strong_binding_set.h", - "struct_ptr.h", "sync_call_restrictions.h", "sync_event_watcher.h", "sync_handle_registry.h", "sync_handle_watcher.h", "thread_safe_interface_ptr.h", - "type_converter.h", - "union_traits.h", "unique_ptr_impl_ref_traits.h", ] + if (enable_ipc_fuzzer && !is_nacl_nonsfi) { + sources += [ + "lib/message_dumper.cc", + "message_dumper.h", + ] + } + public_deps = [ + ":bindings_base", ":struct_traits", "//base", + "//ipc:message_support", "//ipc:param_traits", "//mojo/public/cpp/system", + "//mojo/public/interfaces/bindings", ] deps = [ - "//base", - "//mojo/public/interfaces/bindings:bindings__generator", - "//mojo/public/interfaces/bindings:bindings_shared__generator", + "//ipc:native_handle_type_converters", ] defines = [ "MOJO_CPP_BINDINGS_IMPLEMENTATION" ] @@ -166,8 +194,13 @@ component("bindings") { source_set("struct_traits") { sources = [ + "array_traits.h", "enum_traits.h", + "lib/template_util.h", + "map_traits.h", + "string_traits.h", "struct_traits.h", + "union_traits.h", ] } @@ -186,9 +219,10 @@ if (!is_ios) { public_deps = [ ":bindings", - "//third_party/WebKit/Source/wtf", + "//third_party/blink/renderer/platform:platform_export", + "//third_party/blink/renderer/platform/wtf", ] - public_configs = [ "//third_party/WebKit/Source:config" ] + public_configs = [ "//third_party/blink/renderer:config" ] } } diff --git a/mojo/public/cpp/bindings/DEPS b/mojo/public/cpp/bindings/DEPS deleted file mode 100644 index 36eba448e8..0000000000 --- a/mojo/public/cpp/bindings/DEPS +++ /dev/null @@ -1,3 +0,0 @@ -include_rules = [ - "+third_party/WebKit/Source/wtf", -] diff --git a/mojo/public/cpp/bindings/README.md b/mojo/public/cpp/bindings/README.md index b37267a338..89e5d95b15 100644 --- a/mojo/public/cpp/bindings/README.md +++ b/mojo/public/cpp/bindings/README.md @@ -1,19 +1,23 @@ -# ![Mojo Graphic](https://goo.gl/6CdlbH) Mojo C++ Bindings API -This document is a subset of the [Mojo documentation](/mojo). +# Mojo C++ Bindings API +This document is a subset of the [Mojo documentation](/mojo/README.md). [TOC] ## Overview The Mojo C++ Bindings API leverages the -[C++ System API](/mojo/public/cpp/system) to provide a more natural set of -primitives for communicating over Mojo message pipes. Combined with generated -code from the [Mojom IDL and bindings generator](/mojo/public/tools/bindings), -users can easily connect interface clients and implementations across arbitrary -intra- and inter-process bounaries. +[C++ System API](/mojo/public/cpp/system/README.md) to provide a more natural +set of primitives for communicating over Mojo message pipes. Combined with +generated code from the +[Mojom IDL and bindings generator](/mojo/public/tools/bindings/README.md), users +can easily connect interface clients and implementations across arbitrary intra- +and inter-process bounaries. This document provides a detailed guide to bindings API usage with example code snippets. For a detailed API references please consult the headers in -[//mojo/public/cpp/bindings](https://cs.chromium.org/chromium/src/mojo/public/cpp/bindings/). +[//mojo/public/cpp/bindings](https://cs.chromium.org/chromium/src/mojo/public/cpp/bindings/README.md). + +For a simplified guide targeted at Chromium developers, see [this +link](/docs/mojo_guide.md). ## Getting Started @@ -47,6 +51,12 @@ mojom("interfaces") { } ``` +Ensure that any target that needs this interface depends on it, e.g. with a line like: + +``` + deps += [ '//services/db/public/interfaces' ] +``` + If we then build this target: ``` @@ -57,8 +67,8 @@ This will produce several generated source files, some of which are relevant to C++ bindings. Two of these files are: ``` -out/gen/services/business/public/interfaces/factory.mojom.cc -out/gen/services/business/public/interfaces/factory.mojom.h +out/gen/services/db/public/interfaces/db.mojom.cc +out/gen/services/db/public/interfaces/db.mojom.h ``` You can include the above generated header in your sources in order to use the @@ -143,58 +153,49 @@ routed to some implementation which will **bind** it. The `InterfaceRequest<T>` doesn't actually *do* anything other than hold onto a pipe endpoint and carry useful compile-time type information. -![Diagram illustrating InterfacePtr and InterfaceRequest on either end of a message pipe](https://docs.google.com/drawings/d/17d5gvErbQ6DthEBMS7I1WhCh9bz0n12pvNjydzuRfTI/pub?w=600&h=100) +![Diagram illustrating InterfacePtr and InterfaceRequest on either end of a message pipe](https://docs.google.com/drawings/d/1_Ocprq7EGgTKcSE_WlOn_RBfXcr5C3FJyIbWhwzwNX8/pub?w=608&h=100) So how do we create a strongly-typed message pipe? ### Creating Interface Pipes -One way to do this is by manually creating a pipe and binding each end: +One way to do this is by manually creating a pipe and wrapping each end with a +strongly-typed object: ``` cpp #include "sample/logger.mojom.h" mojo::MessagePipe pipe; -sample::mojom::LoggerPtr logger; -sample::mojom::LoggerRequest request; - -logger.Bind(sample::mojom::LoggerPtrInfo(std::move(pipe.handle0), 0u)); -request.Bind(std::move(pipe.handle1)); +sample::mojom::LoggerPtr logger( + sample::mojom::LoggerPtrInfo(std::move(pipe.handle0), 0)); +sample::mojom::LoggerRequest request(std::move(pipe.handle1)); ``` -That's pretty verbose, but the C++ Bindings library provides more convenient -ways to accomplish the same thing. [interface_request.h](https://cs.chromium.org/chromium/src/mojo/public/cpp/bindings/interface_request.h) +That's pretty verbose, but the C++ Bindings library provides a more convenient +way to accomplish the same thing. [interface_request.h](https://cs.chromium.org/chromium/src/mojo/public/cpp/bindings/interface_request.h) defines a `MakeRequest` function: ``` cpp sample::mojom::LoggerPtr logger; -sample::mojom::LoggerRequest request = mojo::MakeRequest(&logger); +auto request = mojo::MakeRequest(&logger); ``` -and the `InterfaceRequest<T>` constructor can also take an explicit -`InterfacePtr<T>*` output argument: - -``` cpp -sample::mojom::LoggerPtr logger; -sample::mojom::LoggerRequest request(&logger); -``` - -Both of these last two snippets are equivalent to the first one. +This second snippet is equivalent to the first one. *** note **NOTE:** In the first example above you may notice usage of the `LoggerPtrInfo` type, which is a generated alias for `mojo::InterfacePtrInfo<Logger>`. This is similar to an `InterfaceRequest<T>` in that it merely holds onto a pipe handle and cannot actually read or write messages on the pipe. Both this type and -`InterfaceRequest<T>` are safe to move freely from thread to thread, whereas a -bound `InterfacePtr<T>` is bound to a single thread. +`InterfaceRequest<T>` are safe to move freely from sequence to sequence, whereas +a bound `InterfacePtr<T>` is bound to a single sequence. An `InterfacePtr<T>` may be unbound by calling its `PassInterface()` method, which returns a new `InterfacePtrInfo<T>`. Conversely, an `InterfacePtr<T>` may bind (and thus take ownership of) an `InterfacePtrInfo<T>` so that interface calls can be made on the pipe. -The thread-bound nature of `InterfacePtr<T>` is necessary to support safe +The sequence-bound nature of `InterfacePtr<T>` is necessary to support safe dispatch of its [message responses](#Receiving-Responses) and [connection error notifications](#Connection-Errors). *** @@ -210,7 +211,7 @@ logger->Log("Hello!"); This actually writes a `Log` message to the pipe. -![Diagram illustrating a message traveling on a pipe from LoggerPtr to LoggerRequest](https://docs.google.com/a/google.com/drawings/d/1jWEc6jJIP2ed77Gg4JJ3EVC7hvnwcImNqQJywFwpT8g/pub?w=648&h=123) +![Diagram illustrating a message traveling on a pipe from LoggerPtr to LoggerRequest](https://docs.google.com/drawings/d/11vnOpNP3UBLlWg4KplQuIU3r_e1XqwDFETD-O_bV-2w/pub?w=635&h=112) But as mentioned above, `InterfaceRequest` *doesn't actually do anything*, so that message will just sit on the pipe forever. We need a way to read messages @@ -260,7 +261,7 @@ class LoggerImpl : public sample::mojom::Logger { Now we can construct a `LoggerImpl` over our pending `LoggerRequest`, and the previously queued `Log` message will be dispatched ASAP on the `LoggerImpl`'s -thread: +sequence: ``` cpp LoggerImpl impl(std::move(request)); @@ -277,7 +278,7 @@ motion by the above line of code: 3. The `Log` message is read and deserialized, causing the `Binding` to invoke the `Logger::Log` implementation on its bound `LoggerImpl`. -![Diagram illustrating the progression of binding a request, reading a pending message, and dispatching it](https://docs.google.com/drawings/d/1c73-PegT4lmjfHoxhWrHTQXRvzxgb0wdeBa35WBwZ3Q/pub?w=550&h=500) +![Diagram illustrating the progression of binding a request, reading a pending message, and dispatching it](https://docs.google.com/drawings/d/1F2VvfoOINGuNibomqeEU8KekYCtxYVFC00146CFGGQY/pub?w=550&h=500) As a result, our implementation will eventually log the client's `"Hello!"` message via `LOG(ERROR)`. @@ -314,9 +315,9 @@ class Logger { virtual void Log(const std::string& message) = 0; - using GetTailCallback = base::Callback<void(const std::string& message)>; + using GetTailCallback = base::OnceCallback<void(const std::string& message)>; - virtual void GetTail(const GetTailCallback& callback) = 0; + virtual void GetTail(GetTailCallback callback) = 0; } } // namespace mojom @@ -344,8 +345,8 @@ class LoggerImpl : public sample::mojom::Logger { lines_.push_back(message); } - void GetTail(const GetTailCallback& callback) override { - callback.Run(lines_.back()); + void GetTail(GetTailCallback callback) override { + std::move(callback).Run(lines_.back()); } private: @@ -363,7 +364,7 @@ void OnGetTail(const std::string& message) { LOG(ERROR) << "Tail was: " << message; } -logger->GetTail(base::Bind(&OnGetTail)); +logger->GetTail(base::BindOnce(&OnGetTail)); ``` Behind the scenes, the implementation-side callback is actually serializing the @@ -375,10 +376,18 @@ parameters. ### Connection Errors -If there are no remaining messages available on a pipe and the remote end has -been closed, a connection error will be triggered on the local end. Connection -errors may also be triggered by automatic forced local pipe closure due to -*e.g.* a validation error when processing a received message. +If a pipe is disconnected, both endpoints will be able to observe the connection +error (unless the disconnection is caused by closing/destroying an endpoint, in +which case that endpoint won't get such a notification). If there are remaining +incoming messages for an endpoint on disconnection, the connection error won't +be triggered until the messages are drained. + +Pipe disconnecition may be caused by: +* Mojo system-level causes: process terminated, resource exhausted, etc. +* The bindings close the pipe due to a validation error when processing a + received message. +* The peer endpoint is closed. For example, the remote side is a bound + `mojo::InterfacePtr<T>` and it is destroyed. Regardless of the underlying cause, when a connection error is encountered on a binding endpoint, that endpoint's **connection error handler** (if set) is @@ -398,7 +407,7 @@ invocation: ``` cpp sample::mojom::LoggerPtr logger; LoggerImpl impl(mojo::MakeRequest(&logger)); -impl.set_connection_error_handler(base::Bind([] { LOG(ERROR) << "Bye."; })); +impl.set_connection_error_handler(base::BindOnce([] { LOG(ERROR) << "Bye."; })); logger->Log("OK cool"); logger.reset(); // Closes the client end. ``` @@ -415,7 +424,7 @@ handler within its constructor: LoggerImpl::LoggerImpl(sample::mojom::LoggerRequest request) : binding_(this, std::move(request)) { binding_.set_connection_error_handler( - base::Bind(&LoggerImpl::OnError, base::Unretained(this))); + base::BindOnce(&LoggerImpl::OnError, base::Unretained(this))); } void LoggerImpl::OnError() { @@ -427,6 +436,39 @@ void LoggerImpl::OnError() { The use of `base::Unretained` is *safe* because the error handler will never be invoked beyond the lifetime of `binding_`, and `this` owns `binding_`. +### A Note About Endpoint Lifetime and Callbacks +Once a `mojo::InterfacePtr<T>` is destroyed, it is guaranteed that pending +callbacks as well as the connection error handler (if registered) won't be +called. + +Once a `mojo::Binding<T>` is destroyed, it is guaranteed that no more method +calls are dispatched to the implementation and the connection error handler (if +registered) won't be called. + +### Best practices for dealing with process crashes and callbacks +A common situation when calling mojo interface methods that take a callback is +that the caller wants to know if the other endpoint is torn down (e.g. because +of a crash). In that case, the consumer usually wants to know if the response +callback won't be run. There are different solutions for this problem, depending +on how the `InterfacePtr<T>` is held: +1. The consumer owns the `InterfacePtr<T>`: `set_connection_error_handler` + should be used. +2. The consumer doesn't own the `InterfacePtr<T>`: there are two helpers + depending on the behavior that the caller wants. If the caller wants to + ensure that an error handler is run, then + [**`mojo::WrapCallbackWithDropHandler`**](https://cs.chromium.org/chromium/src/mojo/public/cpp/bindings/callback_helpers.h?l=46) + should be used. If the caller wants the callback to always be run, then + [**`mojo::WrapCallbackWithDefaultInvokeIfNotRun`**](https://cs.chromium.org/chromium/src/mojo/public/cpp/bindings/callback_helpers.h?l=40) + helper should be used. With both of these helpers, usual callback care should + be followed to ensure that the callbacks don't run after the consumer is + destructed (e.g. because the owner of the `InterfacePtr<T>` outlives the + consumer). This includes using + [**`base::WeakPtr`**](https://cs.chromium.org/chromium/src/base/memory/weak_ptr.h?l=5) + or + [**`base::RefCounted`**](https://cs.chromium.org/chromium/src/base/memory/ref_counted.h?l=246). + It should also be noted that with these helpers, the callbacks could be run + synchronously while the InterfacePtr<T> is reset or destroyed. + ### A Note About Ordering As mentioned in the previous section, closing one end of a pipe will eventually @@ -457,13 +499,228 @@ pipe, but the impl-side won't notice this until it receives the sent `Log` message. Thus the `impl` above will first log our message and *then* see a connection error and break out of the run loop. +## Types + +### Enums + +[Mojom enums](/mojo/public/tools/bindings/README.md#Enumeration-Types) translate +directly to equivalent strongly-typed C++11 enum classes with `int32_t` as the +underlying type. The typename and value names are identical between Mojom and +C++. Mojo also always defines a special enumerator `kMaxValue` that shares the +value of the highest enumerator: this makes it easy to record Mojo enums in +histograms and interoperate with legacy IPC. + +For example, consider the following Mojom definition: + +```cpp +module business.mojom; + +enum Department { + kEngineering, + kMarketing, + kSales, +}; +``` + +This translates to the following C++ definition: + +```cpp +namespace business { +namespace mojom { + +enum class Department : int32_t { + kEngineering, + kMarketing, + kSales, + kMaxValue = kSales, +}; + +} // namespace mojom +} // namespace business +``` + +### Structs + +[Mojom structs](mojo/public/tools/bindings/README.md#Structs) can be used to +define logical groupings of fields into a new composite type. Every Mojom struct +elicits the generation of an identically named, representative C++ class, with +identically named public fields of corresponding C++ types, and several helpful +public methods. + +For example, consider the following Mojom struct: + +```cpp +module business.mojom; + +struct Employee { + int64 id; + string username; + Department department; +}; +``` + +This would generate a C++ class like so: + +```cpp +namespace business { +namespace mojom { + +class Employee; + +using EmployeePtr = mojo::StructPtr<Employee>; + +class Employee { + public: + // Default constructor - applies default values, potentially ones specified + // explicitly within the Mojom. + Employee(); + + // Value constructor - an explicit argument for every field in the struct, in + // lexical Mojom definition order. + Employee(int64_t id, const std::string& username, Department department); + + // Creates a new copy of this struct value + EmployeePtr Clone(); + + // Tests for equality with another struct value of the same type. + bool Equals(const Employee& other); + + // Equivalent public fields with names identical to the Mojom. + int64_t id; + std::string username; + Department department; +}; + +} // namespace mojom +} // namespace business +``` + +Note when used as a message parameter or as a field within another Mojom struct, +a `struct` type is wrapped by the move-only `mojo::StructPtr` helper, which is +roughly equivalent to a `std::unique_ptr` with some additional utility methods. +This allows struct values to be nullable and struct types to be potentially +self-referential. + +Every genereated struct class has a static `New()` method which returns a new +`mojo::StructPtr<T>` wrapping a new instance of the class constructed by +forwarding the arguments from `New`. For example: + +```cpp +mojom::EmployeePtr e1 = mojom::Employee::New(); +e1->id = 42; +e1->username = "mojo"; +e1->department = mojom::Department::kEngineering; +``` + +is equivalent to + +```cpp +auto e1 = mojom::Employee::New(42, "mojo", mojom::Department::kEngineering); +``` + +Now if we define an interface like: + +```cpp +interface EmployeeManager { + AddEmployee(Employee e); +}; +``` + +We'll get this C++ interface to implement: + +```cpp +class EmployeeManager { + public: + virtual ~EmployeManager() {} + + virtual void AddEmployee(EmployeePtr e) = 0; +}; +``` + +And we can send this message from C++ code as follows: + +```cpp +mojom::EmployeManagerPtr manager = ...; +manager->AddEmployee( + Employee::New(42, "mojo", mojom::Department::kEngineering)); + +// or +auto e = Employee::New(42, "mojo", mojom::Department::kEngineering); +manager->AddEmployee(std::move(e)); +``` + +### Unions + +Similarly to [structs](#Structs), tagged unions generate an identically named, +representative C++ class which is typically wrapped in a `mojo::StructPtr<T>`. + +Unlike structs, all generated union fields are private and must be retrieved and +manipulated using accessors. A field `foo` is accessible by `foo()` and +settable by `set_foo()`. There is also a boolean `is_foo()` for each field which +indicates whether the union is currently taking on the value of field `foo` in +exclusion to all other union fields. + +Finally, every generated union class also has a nested `Tag` enum class which +enumerates all of the named union fields. A Mojom union value's current type can +be determined by calling the `which()` method which returns a `Tag`. + +For example, consider the following Mojom definitions: + +```cpp +union Value { + int64 int_value; + float32 float_value; + string string_value; +}; + +interface Dictionary { + AddValue(string key, Value value); +}; +``` + +This generates a the following C++ interface: + +```cpp +class Value { + public: + virtual ~Value() {} + + virtual void AddValue(const std::string& key, ValuePtr value) = 0; +}; +``` + +And we can use it like so: + +```cpp +ValuePtr value = Value::New(); +value->set_int_value(42); +CHECK(value->is_int_value()); +CHECK_EQ(value->which(), Value::Tag::INT_VALUE); + +value->set_float_value(42); +CHECK(value->is_float_value()); +CHECK_EQ(value->which(), Value::Tag::FLOAT_VALUE); + +value->set_string_value("bananas"); +CHECK(value->is_string_value()); +CHECK_EQ(value->which(), Value::Tag::STRING_VALUE); +``` + +Finally, note that if a union value is not currently occupied by a given field, +attempts to access that field will DCHECK: + +```cpp +ValuePtr value = Value::New(); +value->set_int_value(42); +LOG(INFO) << "Value is " << value->string_value(); // DCHECK! +``` + ### Sending Interfaces Over Interfaces -Now we know how to create interface pipes and use their Ptr and Request -endpoints in some interesting ways. This still doesn't add up to interesting -IPC! The bread and butter of Mojo IPC is the ability to transfer interface -endpoints across other interfaces, so let's take a look at how to accomplish -that. +We know how to create interface pipes and use their Ptr and Request endpoints +in some interesting ways. This still doesn't add up to interesting IPC! The +bread and butter of Mojo IPC is the ability to transfer interface endpoints +across other interfaces, so let's take a look at how to accomplish that. #### Sending Interface Requests @@ -482,7 +739,7 @@ interface Database { ``` As noted in the -[Mojom IDL documentation](/mojo/public/tools/bindings#Primitive-Types), +[Mojom IDL documentation](/mojo/public/tools/bindings/README.md#Primitive-Types), the `Table&` syntax denotes a `Table` interface request. This corresponds precisely to the `InterfaceRequest<T>` type discussed in the sections above, and in fact the generated code for these interfaces is approximately: @@ -545,7 +802,7 @@ class DatabaseImpl : public db::mojom::Database { // db::mojom::Database: void AddTable(db::mojom::TableRequest table) { - tables_.emplace_back(base::MakeUnique<TableImpl>(std::move(table))); + tables_.emplace_back(std::make_unique<TableImpl>(std::move(table))); } private: @@ -620,7 +877,7 @@ pipes. A **strong binding** exists as a standalone object which owns its interface implementation and automatically cleans itself up when its bound interface endpoint detects an error. The -[**`MakeStrongBinding`**](https://cs.chromim.org/chromium/src//mojo/public/cpp/bindings/strong_binding.h) +[**`MakeStrongBinding`**](https://cs.chromium.org/chromium/src/mojo/public/cpp/bindings/strong_binding.h) function is used to create such a binding. . @@ -640,14 +897,14 @@ class LoggerImpl : public sample::mojom::Logger { }; db::mojom::LoggerPtr logger; -mojo::MakeStrongBinding(base::MakeUnique<DatabaseImpl>(), +mojo::MakeStrongBinding(std::make_unique<LoggerImpl>(), mojo::MakeRequest(&logger)); logger->Log("NOM NOM NOM MESSAGES"); ``` Now as long as `logger` remains open somewhere in the system, the bound -`DatabaseImpl` on the other end will remain alive. +`LoggerImpl` on the other end will remain alive. ### Binding Sets @@ -744,9 +1001,182 @@ class TableImpl : public db::mojom::Table { ## Associated Interfaces -See [this document](https://www.chromium.org/developers/design-documents/mojo/associated-interfaces). +Associated interfaces are interfaces which: -TODO: Move the above doc into the repository markdown docs. +* enable running multiple interfaces over a single message pipe while + preserving message ordering. +* make it possible for the bindings to access a single message pipe from + multiple sequences. + +### Mojom + +A new keyword `associated` is introduced for interface pointer/request +fields. For example: + +``` cpp +interface Bar {}; + +struct Qux { + associated Bar bar3; +}; + +interface Foo { + // Uses associated interface pointer. + SetBar(associated Bar bar1); + // Uses associated interface request. + GetBar(associated Bar& bar2); + // Passes a struct with associated interface pointer. + PassQux(Qux qux); + // Uses associated interface pointer in callback. + AsyncGetBar() => (associated Bar bar4); +}; +``` + +It means the interface impl/client will communicate using the same +message pipe over which the associated interface pointer/request is +passed. + +### Using associated interfaces in C++ + +When generating C++ bindings, the associated interface pointer of `Bar` is +mapped to `BarAssociatedPtrInfo` (which is an alias of +`mojo::AssociatedInterfacePtrInfo<Bar>`); associated interface request to +`BarAssociatedRequest` (which is an alias of +`mojo::AssociatedInterfaceRequest<Bar>`). + +``` cpp +// In mojom: +interface Foo { + ... + SetBar(associated Bar bar1); + GetBar(associated Bar& bar2); + ... +}; + +// In C++: +class Foo { + ... + virtual void SetBar(BarAssociatedPtrInfo bar1) = 0; + virtual void GetBar(BarAssociatedRequest bar2) = 0; + ... +}; +``` + +#### Passing associated interface requests + +Assume you have already got an `InterfacePtr<Foo> foo_ptr`, and you would like +to call `GetBar()` on it. You can do: + +``` cpp +BarAssociatedPtrInfo bar_ptr_info; +BarAssociatedRequest bar_request = MakeRequest(&bar_ptr_info); +foo_ptr->GetBar(std::move(bar_request)); + +// BarAssociatedPtr is an alias of AssociatedInterfacePtr<Bar>. +BarAssociatedPtr bar_ptr; +bar_ptr.Bind(std::move(bar_ptr_info)); +bar_ptr->DoSomething(); +``` + +First, the code creates an associated interface of type `Bar`. It looks very +similar to what you would do to setup a non-associated interface. An +important difference is that one of the two associated endpoints (either +`bar_request` or `bar_ptr_info`) must be sent over another interface. That is +how the interface is associated with an existing message pipe. + +It should be noted that you cannot call `bar_ptr->DoSomething()` before passing +`bar_request`. This is required by the FIFO-ness guarantee: at the receiver +side, when the message of `DoSomething` call arrives, we want to dispatch it to +the corresponding `AssociatedBinding<Bar>` before processing any subsequent +messages. If `bar_request` is in a subsequent message, message dispatching gets +into a deadlock. On the other hand, as soon as `bar_request` is sent, `bar_ptr` +is usable. There is no need to wait until `bar_request` is bound to an +implementation at the remote side. + +A `MakeRequest` overload which takes an `AssociatedInterfacePtr` pointer +(instead of an `AssociatedInterfacePtrInfo` pointer) is provided to make the +code a little shorter. The following code achieves the same purpose: + +``` cpp +BarAssociatedPtr bar_ptr; +foo_ptr->GetBar(MakeRequest(&bar_ptr)); +bar_ptr->DoSomething(); +``` + +The implementation of `Foo` looks like this: + +``` cpp +class FooImpl : public Foo { + ... + void GetBar(BarAssociatedRequest bar2) override { + bar_binding_.Bind(std::move(bar2)); + ... + } + ... + + Binding<Foo> foo_binding_; + AssociatedBinding<Bar> bar_binding_; +}; +``` + +In this example, `bar_binding_`'s lifespan is tied to that of `FooImpl`. But you +don't have to do that. You can, for example, pass `bar2` to another sequence to +bind to an `AssociatedBinding<Bar>` there. + +When the underlying message pipe is disconnected (e.g., `foo_ptr` or +`foo_binding_` is destroyed), all associated interface endpoints (e.g., +`bar_ptr` and `bar_binding_`) will receive a connection error. + +#### Passing associated interface pointers + +Similarly, assume you have already got an `InterfacePtr<Foo> foo_ptr`, and you +would like to call `SetBar()` on it. You can do: + +``` cpp +AssociatedBind<Bar> bar_binding(some_bar_impl); +BarAssociatedPtrInfo bar_ptr_info; +BarAssociatedRequest bar_request = MakeRequest(&bar_ptr_info); +foo_ptr->SetBar(std::move(bar_ptr_info)); +bar_binding.Bind(std::move(bar_request)); +``` + +The following code achieves the same purpose: + +``` cpp +AssociatedBind<Bar> bar_binding(some_bar_impl); +BarAssociatedPtrInfo bar_ptr_info; +bar_binding.Bind(&bar_ptr_info); +foo_ptr->SetBar(std::move(bar_ptr_info)); +``` + +### Performance considerations + +When using associated interfaces on different sequences than the master sequence +(where the master interface lives): + +* Sending messages: send happens directly on the calling sequence. So there + isn't sequence hopping. +* Receiving messages: associated interfaces bound on a different sequence from + the master interface incur an extra sequence hop during dispatch. + +Therefore, performance-wise associated interfaces are better suited for +scenarios where message receiving happens on the master sequence. + +### Testing + +Associated interfaces need to be associated with a master interface before +they can be used. This means one end of the associated interface must be sent +over one end of the master interface, or over one end of another associated +interface which itself already has a master interface. + +If you want to test an associated interface endpoint without first +associating it, you can use `mojo::MakeIsolatedRequest()`. This will create +working associated interface endpoints which are not actually associated with +anything else. + +### Read more + +* [Design: Mojo Associated Interfaces](https://docs.google.com/document/d/1nq3J_HbS-gvVfIoEhcVyxm1uY-9G_7lhD-4Kyxb1WIY/edit) ## Synchronous Calls @@ -808,9 +1238,9 @@ viral concept: if `gfx::mojom::Rect` is mapped to `gfx::Rect` anywhere, the mapping needs to apply *everywhere*. For this reason we have a few global typemap configurations defined in -[chromium_bindings_configuration.gni](https://cs.chromium.com/chromium/src/mojo/public/tools/bindings/chromium_bindings_configuration.gni) +[chromium_bindings_configuration.gni](https://cs.chromium.org/chromium/src/mojo/public/tools/bindings/chromium_bindings_configuration.gni) and -[blink_bindings_configuration.gni](https://cs.chromium.com/chromium/src/mojo/public/tools/bindings/blink_bindings_configuration.gni). These configure the two supported [variants](#Variants) of Mojom generated +[blink_bindings_configuration.gni](https://cs.chromium.org/chromium/src/mojo/public/tools/bindings/blink_bindings_configuration.gni). These configure the two supported [variants](#Variants) of Mojom generated bindings in the repository. Read more on this in the sections that follow. For now, let's take a look at how to express the mapping from `gfx::mojom::Rect` @@ -910,7 +1340,10 @@ Let's place this `geometry.typemap` file alongside our Mojom file: mojom = "//ui/gfx/geometry/mojo/geometry.mojom" public_headers = [ "//ui/gfx/geometry/rect.h" ] traits_headers = [ "//ui/gfx/geometry/mojo/geometry_struct_traits.h" ] -sources = [ "//ui/gfx/geometry/mojo/geometry_struct_traits.cc" ] +sources = [ + "//ui/gfx/geometry/mojo/geometry_struct_traits.cc", + "//ui/gfx/geometry/mojo/geometry_struct_traits.h", +] public_deps = [ "//ui/gfx/geometry" ] type_mappings = [ "gfx.mojom.Rect=gfx::Rect", @@ -928,8 +1361,9 @@ Let's look at each of the variables above: here. * `traits_headers`: Headers which contain the relevant `StructTraits` specialization(s) for any type mappings described by this file. -* `sources`: Any private implementation sources needed for the `StructTraits` - definition. +* `sources`: Any implementation sources and headers needed for the + `StructTraits` definition. These sources are compiled directly into the + generated C++ bindings target for a `mojom` file applying this typemap. * `public_deps`: Target dependencies exposed by the `public_headers` and `traits_headers`. * `deps`: Target dependencies exposed by `sources` but not already covered by @@ -952,6 +1386,10 @@ Let's look at each of the variables above: `StructTraits` definition for this type mapping must define additional `IsNull` and `SetToNull` methods. See [Specializing Nullability](#Specializing-Nullability) below. + * `force_serialize`: The typemap is incompatible with lazy serialization + (e.g. consider a typemap to a `base::StringPiece`, where retaining a + copy is unsafe). Any messages carrying the type will be forced down the + eager serailization path. Now that we have the typemap file we need to add it to a local list of typemaps @@ -966,7 +1404,7 @@ typemaps = [ And finally we can reference this file in the global default (Chromium) bindings configuration by adding it to `_typemap_imports` in -[chromium_bindings_configuration.gni](https://cs.chromium.com/chromium/src/mojo/public/tools/bindings/chromium_bindings_configuration.gni): +[chromium_bindings_configuration.gni](https://cs.chromium.org/chromium/src/mojo/public/tools/bindings/chromium_bindings_configuration.gni): ``` _typemap_imports = [ @@ -1103,6 +1541,7 @@ class StructTraits Generated `ReadFoo` methods always convert `multi_word_field_name` fields to `ReadMultiWordFieldName` methods. +<a name="Blink-Type-Mapping"></a> ### Variants By now you may have noticed that additional C++ sources are generated when a @@ -1152,8 +1591,8 @@ out/gen/sample/db.mojom-shared-internal.h ``` Including either variant's header (`db.mojom.h` or `db.mojom-blink.h`) -implicitly includes the shared header, but you have on some occasions wish to -include *only* the shared header in some instances. +implicitly includes the shared header, but may wish to include *only* the shared +header in some instances. Finally, note that for `mojom` GN targets, there is implicitly a corresponding `mojom_{variant}` target defined for any supported bindings configuration. So @@ -1175,7 +1614,7 @@ depend on `"//sample:interfaces_blink"`. ## Versioning Considerations For general documentation of versioning in the Mojom IDL see -[Versioning](/mojo/public/tools/bindings#Versioning). +[Versioning](/mojo/public/tools/bindings/README.md#Versiwoning). This section briefly discusses some C++-specific considerations relevant to versioned Mojom types. @@ -1224,6 +1663,10 @@ generates the function in the same namespace as the generated C++ enum type: inline bool IsKnownEnumValue(Department value); ``` +### Using Mojo Bindings in Chrome + +See [Converting Legacy Chrome IPC To Mojo](/ipc/README.md). + ### Additional Documentation [Calling Mojo From Blink](https://www.chromium.org/developers/design-documents/mojo/calling-mojo-from-blink) diff --git a/mojo/public/cpp/bindings/array_traits.h b/mojo/public/cpp/bindings/array_traits.h index 594b2e0789..3bf232875c 100644 --- a/mojo/public/cpp/bindings/array_traits.h +++ b/mojo/public/cpp/bindings/array_traits.h @@ -5,6 +5,8 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_ARRAY_TRAITS_H_ #define MOJO_PUBLIC_CPP_BINDINGS_ARRAY_TRAITS_H_ +#include "mojo/public/cpp/bindings/lib/template_util.h" + namespace mojo { // This must be specialized for any type |T| to be serialized/deserialized as @@ -24,6 +26,8 @@ namespace mojo { // // using ConstIterator = T::const_iterator; // // // These two methods are optional. Please see comments in struct_traits.h +// // Note that unlike with StructTraits, IsNull() is called *twice* during +// // serialization for ArrayTraits. // static bool IsNull(const Container<T>& input); // static void SetToNull(Container<T>* output); // @@ -64,7 +68,11 @@ namespace mojo { // }; // template <typename T> -struct ArrayTraits; +struct ArrayTraits { + static_assert(internal::AlwaysFalse<T>::value, + "Cannot find the mojo::ArrayTraits specialization. Did you " + "forget to include the corresponding header file?"); +}; } // namespace mojo diff --git a/mojo/public/cpp/bindings/array_traits_carray.h b/mojo/public/cpp/bindings/array_traits_carray.h deleted file mode 100644 index 3ff694b882..0000000000 --- a/mojo/public/cpp/bindings/array_traits_carray.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_ARRAY_TRAITS_CARRAY_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_ARRAY_TRAITS_CARRAY_H_ - -#include "mojo/public/cpp/bindings/array_traits.h" - -namespace mojo { - -template <typename T> -struct CArray { - CArray() : size(0), max_size(0), data(nullptr) {} - CArray(size_t size, size_t max_size, T* data) - : size(size), max_size(max_size), data(data) {} - size_t size; - const size_t max_size; - T* data; -}; - -template <typename T> -struct ConstCArray { - ConstCArray() : size(0), data(nullptr) {} - ConstCArray(size_t size, const T* data) : size(size), data(data) {} - size_t size; - const T* data; -}; - -template <typename T> -struct ArrayTraits<CArray<T>> { - using Element = T; - - static bool IsNull(const CArray<T>& input) { return !input.data; } - - static void SetToNull(CArray<T>* output) { output->data = nullptr; } - - static size_t GetSize(const CArray<T>& input) { return input.size; } - - static T* GetData(CArray<T>& input) { return input.data; } - - static const T* GetData(const CArray<T>& input) { return input.data; } - - static T& GetAt(CArray<T>& input, size_t index) { return input.data[index]; } - - static const T& GetAt(const CArray<T>& input, size_t index) { - return input.data[index]; - } - - static bool Resize(CArray<T>& input, size_t size) { - if (size > input.max_size) - return false; - - input.size = size; - return true; - } -}; - -template <typename T> -struct ArrayTraits<ConstCArray<T>> { - using Element = T; - - static bool IsNull(const ConstCArray<T>& input) { return !input.data; } - - static size_t GetSize(const ConstCArray<T>& input) { return input.size; } - - static const T* GetData(const ConstCArray<T>& input) { return input.data; } - - static const T& GetAt(const ConstCArray<T>& input, size_t index) { - return input.data[index]; - } -}; - -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_ARRAY_TRAITS_CARRAY_H_ diff --git a/mojo/public/cpp/bindings/array_traits_span.h b/mojo/public/cpp/bindings/array_traits_span.h new file mode 100644 index 0000000000..d8364030f3 --- /dev/null +++ b/mojo/public/cpp/bindings/array_traits_span.h @@ -0,0 +1,47 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_ARRAY_TRAITS_CARRAY_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_ARRAY_TRAITS_CARRAY_H_ + +#include <cstddef> + +#include "base/containers/span.h" +#include "mojo/public/cpp/bindings/array_traits.h" + +namespace mojo { + +template <typename T> +struct ArrayTraits<base::span<T>> { + using Element = T; + + // There is no concept of a null span, as it is indistinguishable from the + // empty span. + static bool IsNull(const base::span<T>& input) { return false; } + + static size_t GetSize(const base::span<T>& input) { return input.size(); } + + static T* GetData(base::span<T>& input) { return input.data(); } + + static const T* GetData(const base::span<T>& input) { return input.data(); } + + static T& GetAt(base::span<T>& input, size_t index) { + return input.data()[index]; + } + + static const T& GetAt(const base::span<T>& input, size_t index) { + return input.data()[index]; + } + + static bool Resize(base::span<T>& input, size_t size) { + if (size > input.size()) + return false; + input = input.subspan(0, size); + return true; + } +}; + +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_ARRAY_TRAITS_CARRAY_H_ diff --git a/mojo/public/cpp/bindings/array_traits_wtf_vector.h b/mojo/public/cpp/bindings/array_traits_wtf_vector.h index 6e207351fd..83d2ba3f51 100644 --- a/mojo/public/cpp/bindings/array_traits_wtf_vector.h +++ b/mojo/public/cpp/bindings/array_traits_wtf_vector.h @@ -6,37 +6,46 @@ #define MOJO_PUBLIC_CPP_BINDINGS_ARRAY_TRAITS_WTF_VECTOR_H_ #include "mojo/public/cpp/bindings/array_traits.h" -#include "third_party/WebKit/Source/wtf/Vector.h" +#include "third_party/blink/renderer/platform/wtf/vector.h" namespace mojo { -template <typename U> -struct ArrayTraits<WTF::Vector<U>> { +template <typename U, size_t InlineCapacity> +struct ArrayTraits<WTF::Vector<U, InlineCapacity>> { using Element = U; - static bool IsNull(const WTF::Vector<U>& input) { + static bool IsNull(const WTF::Vector<U, InlineCapacity>& input) { // WTF::Vector<> is always converted to non-null mojom array. return false; } - static void SetToNull(WTF::Vector<U>* output) { + static void SetToNull(WTF::Vector<U, InlineCapacity>* output) { // WTF::Vector<> doesn't support null state. Set it to empty instead. output->clear(); } - static size_t GetSize(const WTF::Vector<U>& input) { return input.size(); } + static size_t GetSize(const WTF::Vector<U, InlineCapacity>& input) { + return input.size(); + } - static U* GetData(WTF::Vector<U>& input) { return input.data(); } + static U* GetData(WTF::Vector<U, InlineCapacity>& input) { + return input.data(); + } - static const U* GetData(const WTF::Vector<U>& input) { return input.data(); } + static const U* GetData(const WTF::Vector<U, InlineCapacity>& input) { + return input.data(); + } - static U& GetAt(WTF::Vector<U>& input, size_t index) { return input[index]; } + static U& GetAt(WTF::Vector<U, InlineCapacity>& input, size_t index) { + return input[index]; + } - static const U& GetAt(const WTF::Vector<U>& input, size_t index) { + static const U& GetAt(const WTF::Vector<U, InlineCapacity>& input, + size_t index) { return input[index]; } - static bool Resize(WTF::Vector<U>& input, size_t size) { + static bool Resize(WTF::Vector<U, InlineCapacity>& input, size_t size) { input.resize(size); return true; } diff --git a/mojo/public/cpp/bindings/associated_binding.h b/mojo/public/cpp/bindings/associated_binding.h index 59411666f5..e8e0cb1e25 100644 --- a/mojo/public/cpp/bindings/associated_binding.h +++ b/mojo/public/cpp/bindings/associated_binding.h @@ -16,7 +16,6 @@ #include "base/memory/ptr_util.h" #include "base/memory/ref_counted.h" #include "base/single_thread_task_runner.h" -#include "base/threading/thread_task_runner_handle.h" #include "mojo/public/cpp/bindings/associated_interface_ptr_info.h" #include "mojo/public/cpp/bindings/associated_interface_request.h" #include "mojo/public/cpp/bindings/bindings_export.h" @@ -53,14 +52,16 @@ class MOJO_CPP_BINDINGS_EXPORT AssociatedBindingBase { // This method may only be called after this AssociatedBinding has been bound // to a message pipe. The error handler will be reset when this // AssociatedBinding is unbound or closed. - void set_connection_error_handler(const base::Closure& error_handler); + void set_connection_error_handler(base::OnceClosure error_handler); void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler); + ConnectionErrorWithReasonCallback error_handler); // Indicates whether the associated binding has been completed. bool is_bound() const { return !!endpoint_client_; } + explicit operator bool() const { return !!endpoint_client_; } + // Sends a message on the underlying message pipe and runs the current // message loop until its response is received. This can be used in tests to // verify that no message was sent on a message pipe in response to some @@ -85,9 +86,9 @@ class MOJO_CPP_BINDINGS_EXPORT AssociatedBindingBase { // base::SingleThreadTaskRunner. This task runner must belong to the same // thread. It will be used to dispatch incoming method calls and connection // error notification. It is useful when you attach multiple task runners to a -// single thread for the purposes of task scheduling. Please note that incoming -// synchrounous method calls may not be run from this task runner, when they -// reenter outgoing synchrounous calls on the same thread. +// single thread for the purposes of task scheduling. Please note that +// incoming synchronous method calls may not be run from this task runner, when +// they reenter outgoing synchronous calls on the same thread. template <typename Interface, typename ImplRefTraits = RawPtrImplRefTraits<Interface>> class AssociatedBinding : public AssociatedBindingBase { @@ -96,47 +97,26 @@ class AssociatedBinding : public AssociatedBindingBase { // Constructs an incomplete associated binding that will use the // implementation |impl|. It may be completed with a subsequent call to the - // |Bind| method. Does not take ownership of |impl|, which must outlive this - // object. - explicit AssociatedBinding(ImplPointerType impl) { stub_.set_sink(impl); } - - // Constructs a completed associated binding of |impl|. The output |ptr_info| - // should be sent by another interface. |impl| must outlive this object. - AssociatedBinding(ImplPointerType impl, - AssociatedInterfacePtrInfo<Interface>* ptr_info, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) - : AssociatedBinding(std::move(impl)) { - Bind(ptr_info, std::move(runner)); + // |Bind| method. + explicit AssociatedBinding(ImplPointerType impl) { + stub_.set_sink(std::move(impl)); } // Constructs a completed associated binding of |impl|. |impl| must outlive // the binding. - AssociatedBinding(ImplPointerType impl, - AssociatedInterfaceRequest<Interface> request, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) + AssociatedBinding( + ImplPointerType impl, + AssociatedInterfaceRequest<Interface> request, + scoped_refptr<base::SingleThreadTaskRunner> runner = nullptr) : AssociatedBinding(std::move(impl)) { Bind(std::move(request), std::move(runner)); } ~AssociatedBinding() {} - // Creates an associated inteface and sets up this object as the - // implementation side. The output |ptr_info| should be sent by another - // interface. - void Bind(AssociatedInterfacePtrInfo<Interface>* ptr_info, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { - auto request = MakeRequest(ptr_info); - ptr_info->set_version(Interface::Version_); - Bind(std::move(request), std::move(runner)); - } - // Sets up this object as the implementation side of an associated interface. void Bind(AssociatedInterfaceRequest<Interface> request, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { + scoped_refptr<base::SingleThreadTaskRunner> runner = nullptr) { BindImpl(request.PassHandle(), &stub_, base::WrapUnique(new typename Interface::RequestValidator_()), Interface::HasSyncMethods_, std::move(runner), @@ -144,22 +124,26 @@ class AssociatedBinding : public AssociatedBindingBase { } // Unbinds and returns the associated interface request so it can be - // used in another context, such as on another thread or with a different + // used in another context, such as on another sequence or with a different // implementation. Puts this object into a state where it can be rebound. AssociatedInterfaceRequest<Interface> Unbind() { DCHECK(endpoint_client_); - - AssociatedInterfaceRequest<Interface> request; - request.Bind(endpoint_client_->PassHandle()); - + AssociatedInterfaceRequest<Interface> request( + endpoint_client_->PassHandle()); endpoint_client_.reset(); - return request; } // Returns the interface implementation that was previously specified. Interface* impl() { return ImplRefTraits::GetRawPointer(&stub_.sink()); } + // Allows test code to swap the interface implementation. + ImplPointerType SwapImplForTesting(ImplPointerType new_impl) { + Interface* old_impl = impl(); + stub_.set_sink(std::move(new_impl)); + return old_impl; + } + private: typename Interface::template Stub_<ImplRefTraits> stub_; diff --git a/mojo/public/cpp/bindings/associated_group.h b/mojo/public/cpp/bindings/associated_group.h index 14e78ec3f9..1a31f904ea 100644 --- a/mojo/public/cpp/bindings/associated_group.h +++ b/mojo/public/cpp/bindings/associated_group.h @@ -6,8 +6,8 @@ #define MOJO_PUBLIC_CPP_BINDINGS_ASSOCIATED_GROUP_H_ #include "base/callback.h" +#include "base/component_export.h" #include "base/memory/ref_counted.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" namespace mojo { @@ -17,7 +17,7 @@ class AssociatedGroupController; // AssociatedGroup refers to all the interface endpoints running at one end of a // message pipe. // It is thread safe and cheap to make copies. -class MOJO_CPP_BINDINGS_EXPORT AssociatedGroup { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) AssociatedGroup { public: AssociatedGroup(); diff --git a/mojo/public/cpp/bindings/associated_group_controller.h b/mojo/public/cpp/bindings/associated_group_controller.h index d33c2776d5..386ebdf860 100644 --- a/mojo/public/cpp/bindings/associated_group_controller.h +++ b/mojo/public/cpp/bindings/associated_group_controller.h @@ -7,11 +7,11 @@ #include <memory> +#include "base/component_export.h" #include "base/macros.h" #include "base/memory/ref_counted.h" #include "base/optional.h" -#include "base/single_thread_task_runner.h" -#include "mojo/public/cpp/bindings/bindings_export.h" +#include "base/sequenced_task_runner.h" #include "mojo/public/cpp/bindings/disconnect_reason.h" #include "mojo/public/cpp/bindings/interface_id.h" #include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" @@ -23,7 +23,7 @@ class InterfaceEndpointController; // An internal interface used to manage endpoints within an associated group, // which corresponds to one end of a message pipe. -class MOJO_CPP_BINDINGS_EXPORT AssociatedGroupController +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) AssociatedGroupController : public base::RefCountedThreadSafe<AssociatedGroupController> { public: // Associates an interface with this AssociatedGroupController's message pipe. @@ -53,15 +53,15 @@ class MOJO_CPP_BINDINGS_EXPORT AssociatedGroupController // Attaches a client to the specified endpoint to send and receive messages. // The returned object is still owned by the controller. It must only be used - // on the same thread as this call, and only before the client is detached + // on the same sequence as this call, and only before the client is detached // using DetachEndpointClient(). virtual InterfaceEndpointController* AttachEndpointClient( const ScopedInterfaceEndpointHandle& handle, InterfaceEndpointClient* endpoint_client, - scoped_refptr<base::SingleThreadTaskRunner> runner) = 0; + scoped_refptr<base::SequencedTaskRunner> runner) = 0; // Detaches the client attached to the specified endpoint. It must be called - // on the same thread as the corresponding AttachEndpointClient() call. + // on the same sequence as the corresponding AttachEndpointClient() call. virtual void DetachEndpointClient( const ScopedInterfaceEndpointHandle& handle) = 0; @@ -69,6 +69,10 @@ class MOJO_CPP_BINDINGS_EXPORT AssociatedGroupController // and notifies all interfaces running on this pipe. virtual void RaiseError() = 0; + // Indicates whether or this endpoint prefers to accept outgoing messages in + // serializaed form only. + virtual bool PrefersSerializedMessages() = 0; + protected: friend class base::RefCountedThreadSafe<AssociatedGroupController>; diff --git a/mojo/public/cpp/bindings/associated_interface_ptr.h b/mojo/public/cpp/bindings/associated_interface_ptr.h index 8806a3e090..3d08001230 100644 --- a/mojo/public/cpp/bindings/associated_interface_ptr.h +++ b/mojo/public/cpp/bindings/associated_interface_ptr.h @@ -14,8 +14,7 @@ #include "base/logging.h" #include "base/macros.h" #include "base/memory/ref_counted.h" -#include "base/single_thread_task_runner.h" -#include "base/threading/thread_task_runner_handle.h" +#include "base/sequenced_task_runner.h" #include "mojo/public/cpp/bindings/associated_interface_ptr_info.h" #include "mojo/public/cpp/bindings/associated_interface_request.h" #include "mojo/public/cpp/bindings/bindings_export.h" @@ -33,6 +32,7 @@ class AssociatedInterfacePtr { public: using InterfaceType = Interface; using PtrInfoType = AssociatedInterfacePtrInfo<Interface>; + using Proxy = typename Interface::Proxy_; // Constructs an unbound AssociatedInterfacePtr. AssociatedInterfacePtr() {} @@ -42,6 +42,8 @@ class AssociatedInterfacePtr { internal_state_.Swap(&other.internal_state_); } + explicit AssociatedInterfacePtr(PtrInfoType&& info) { Bind(std::move(info)); } + AssociatedInterfacePtr& operator=(AssociatedInterfacePtr&& other) { reset(); internal_state_.Swap(&other.internal_state_); @@ -61,18 +63,17 @@ class AssociatedInterfacePtr { // Calling with an invalid |info| has the same effect as reset(). In this // case, the AssociatedInterfacePtr is not considered as bound. // - // |runner| must belong to the same thread. It will be used to dispatch all - // callbacks and connection error notification. It is useful when you attach - // multiple task runners to a single thread for the purposes of task - // scheduling. + // Optionally, |runner| is a SequencedTaskRunner bound to the current sequence + // on which all callbacks and connection error notifications will be + // dispatched. It is only useful to specify this to use a different + // SequencedTaskRunner than SequencedTaskRunnerHandle::Get(). // // NOTE: The corresponding AssociatedInterfaceRequest must be sent over // another interface before using this object to make calls. Please see the // comments of MakeRequest(AssociatedInterfacePtr<Interface>*) for more // details. void Bind(AssociatedInterfacePtrInfo<Interface> info, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { + scoped_refptr<base::SequencedTaskRunner> runner = nullptr) { reset(); if (info.is_valid()) @@ -81,11 +82,11 @@ class AssociatedInterfacePtr { bool is_bound() const { return internal_state_.is_bound(); } - Interface* get() const { return internal_state_.instance(); } + Proxy* get() const { return internal_state_.instance(); } // Functions like a pointer to Interface. Must already be bound. - Interface* operator->() const { return get(); } - Interface& operator*() const { return *get(); } + Proxy* operator->() const { return get(); } + Proxy& operator*() const { return *get(); } // Returns the version number of the interface that the remote side supports. uint32_t version() const { return internal_state_.version(); } @@ -136,18 +137,19 @@ class AssociatedInterfacePtr { // // This method may only be called after the AssociatedInterfacePtr has been // bound. - void set_connection_error_handler(const base::Closure& error_handler) { - internal_state_.set_connection_error_handler(error_handler); + void set_connection_error_handler(base::OnceClosure error_handler) { + internal_state_.set_connection_error_handler(std::move(error_handler)); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { - internal_state_.set_connection_error_with_reason_handler(error_handler); + ConnectionErrorWithReasonCallback error_handler) { + internal_state_.set_connection_error_with_reason_handler( + std::move(error_handler)); } // Unbinds and returns the associated interface pointer information which // could be used to setup an AssociatedInterfacePtr again. This method may be - // used to move the proxy to a different thread. + // used to move the proxy to a different sequence. // // It is an error to call PassInterface() while there are pending responses. // TODO: fix this restriction, it's not always obvious when there is a @@ -165,27 +167,10 @@ class AssociatedInterfacePtr { return &internal_state_; } - // Allow AssociatedInterfacePtr<> to be used in boolean expressions, but not - // implicitly convertible to a real bool (which is dangerous). - private: - // TODO(dcheng): Use an explicit conversion operator. - typedef internal::AssociatedInterfacePtrState<Interface> - AssociatedInterfacePtr::*Testable; - - public: - operator Testable() const { - return internal_state_.is_bound() ? &AssociatedInterfacePtr::internal_state_ - : nullptr; - } + // Allow AssociatedInterfacePtr<> to be used in boolean expressions. + explicit operator bool() const { return internal_state_.is_bound(); } private: - // Forbid the == and != operators explicitly, otherwise AssociatedInterfacePtr - // will be converted to Testable to do == or != comparison. - template <typename T> - bool operator==(const AssociatedInterfacePtr<T>& other) const = delete; - template <typename T> - bool operator!=(const AssociatedInterfacePtr<T>& other) const = delete; - typedef internal::AssociatedInterfacePtrState<Interface> State; mutable State internal_state_; @@ -202,8 +187,7 @@ class AssociatedInterfacePtr { template <typename Interface> AssociatedInterfaceRequest<Interface> MakeRequest( AssociatedInterfacePtr<Interface>* ptr, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { + scoped_refptr<base::SequencedTaskRunner> runner = nullptr) { AssociatedInterfacePtrInfo<Interface> ptr_info; auto request = MakeRequest(&ptr_info); ptr->Bind(std::move(ptr_info), std::move(runner)); @@ -228,9 +212,7 @@ AssociatedInterfaceRequest<Interface> MakeRequest( ptr_info->set_handle(std::move(handle0)); ptr_info->set_version(0); - AssociatedInterfaceRequest<Interface> request; - request.Bind(std::move(handle1)); - return request; + return AssociatedInterfaceRequest<Interface>(std::move(handle1)); } // Like MakeRequest() above, but it creates a dedicated message pipe. The @@ -245,17 +227,17 @@ AssociatedInterfaceRequest<Interface> MakeRequest( // * When discarding messages sent on an interface, which can be done by // discarding the returned request. template <typename Interface> -AssociatedInterfaceRequest<Interface> MakeIsolatedRequest( +AssociatedInterfaceRequest<Interface> MakeRequestAssociatedWithDedicatedPipe( AssociatedInterfacePtr<Interface>* ptr) { MessagePipe pipe; scoped_refptr<internal::MultiplexRouter> router0 = - new internal::MultiplexRouter(std::move(pipe.handle0), - internal::MultiplexRouter::MULTI_INTERFACE, - false, base::ThreadTaskRunnerHandle::Get()); + new internal::MultiplexRouter( + std::move(pipe.handle0), internal::MultiplexRouter::MULTI_INTERFACE, + false, base::SequencedTaskRunnerHandle::Get()); scoped_refptr<internal::MultiplexRouter> router1 = - new internal::MultiplexRouter(std::move(pipe.handle1), - internal::MultiplexRouter::MULTI_INTERFACE, - true, base::ThreadTaskRunnerHandle::Get()); + new internal::MultiplexRouter( + std::move(pipe.handle1), internal::MultiplexRouter::MULTI_INTERFACE, + true, base::SequencedTaskRunnerHandle::Get()); ScopedInterfaceEndpointHandle endpoint0, endpoint1; ScopedInterfaceEndpointHandle::CreatePairPendingAssociation(&endpoint0, @@ -265,17 +247,14 @@ AssociatedInterfaceRequest<Interface> MakeIsolatedRequest( ptr->Bind(AssociatedInterfacePtrInfo<Interface>(std::move(endpoint0), Interface::Version_)); - - AssociatedInterfaceRequest<Interface> request; - request.Bind(std::move(endpoint1)); - return request; + return AssociatedInterfaceRequest<Interface>(std::move(endpoint1)); } // |handle| is supposed to be the request of an associated interface. This // method associates the interface with a dedicated, disconnected message pipe. // That way, the corresponding associated interface pointer of |handle| can // safely make calls (although those calls are silently dropped). -MOJO_CPP_BINDINGS_EXPORT void GetIsolatedInterface( +MOJO_CPP_BINDINGS_EXPORT void AssociateWithDisconnectedPipe( ScopedInterfaceEndpointHandle handle); } // namespace mojo diff --git a/mojo/public/cpp/bindings/associated_interface_ptr_info.h b/mojo/public/cpp/bindings/associated_interface_ptr_info.h index 3c6ca54603..cc3f627167 100644 --- a/mojo/public/cpp/bindings/associated_interface_ptr_info.h +++ b/mojo/public/cpp/bindings/associated_interface_ptr_info.h @@ -45,6 +45,8 @@ class AssociatedInterfacePtrInfo { bool is_valid() const { return handle_.is_valid(); } + explicit operator bool() const { return handle_.is_valid(); } + ScopedInterfaceEndpointHandle PassHandle() { return std::move(handle_); } diff --git a/mojo/public/cpp/bindings/associated_interface_request.h b/mojo/public/cpp/bindings/associated_interface_request.h index c37636c9f3..0926f3df92 100644 --- a/mojo/public/cpp/bindings/associated_interface_request.h +++ b/mojo/public/cpp/bindings/associated_interface_request.h @@ -23,11 +23,15 @@ class AssociatedInterfaceRequest { AssociatedInterfaceRequest() {} AssociatedInterfaceRequest(decltype(nullptr)) {} + explicit AssociatedInterfaceRequest(ScopedInterfaceEndpointHandle handle) + : handle_(std::move(handle)) {} + // Takes the interface endpoint handle from another // AssociatedInterfaceRequest. AssociatedInterfaceRequest(AssociatedInterfaceRequest&& other) { handle_ = std::move(other.handle_); } + AssociatedInterfaceRequest& operator=(AssociatedInterfaceRequest&& other) { if (this != &other) handle_ = std::move(other.handle_); @@ -46,13 +50,9 @@ class AssociatedInterfaceRequest { // handle. bool is_pending() const { return handle_.is_valid(); } - void Bind(ScopedInterfaceEndpointHandle handle) { - handle_ = std::move(handle); - } + explicit operator bool() const { return handle_.is_valid(); } - ScopedInterfaceEndpointHandle PassHandle() { - return std::move(handle_); - } + ScopedInterfaceEndpointHandle PassHandle() { return std::move(handle_); } const ScopedInterfaceEndpointHandle& handle() const { return handle_; } @@ -75,16 +75,6 @@ class AssociatedInterfaceRequest { DISALLOW_COPY_AND_ASSIGN(AssociatedInterfaceRequest); }; -// Makes an AssociatedInterfaceRequest bound to the specified associated -// endpoint. -template <typename Interface> -AssociatedInterfaceRequest<Interface> MakeAssociatedRequest( - ScopedInterfaceEndpointHandle handle) { - AssociatedInterfaceRequest<Interface> request; - request.Bind(std::move(handle)); - return request; -} - } // namespace mojo #endif // MOJO_PUBLIC_CPP_BINDINGS_ASSOCIATED_INTERFACE_REQUEST_H_ diff --git a/mojo/public/cpp/bindings/binding.h b/mojo/public/cpp/bindings/binding.h index 88d2f4ba3e..5b119b4130 100644 --- a/mojo/public/cpp/bindings/binding.h +++ b/mojo/public/cpp/bindings/binding.h @@ -12,7 +12,6 @@ #include "base/macros.h" #include "base/memory/ref_counted.h" #include "base/single_thread_task_runner.h" -#include "base/threading/thread_task_runner_handle.h" #include "mojo/public/cpp/bindings/connection_error_callback.h" #include "mojo/public/cpp/bindings/interface_ptr.h" #include "mojo/public/cpp/bindings/interface_ptr_info.h" @@ -28,7 +27,10 @@ class MessageReceiver; // Represents the binding of an interface implementation to a message pipe. // When the |Binding| object is destroyed, the binding between the message pipe // and the interface is torn down and the message pipe is closed, leaving the -// interface implementation in an unbound state. +// interface implementation in an unbound state. Once the |Binding| object is +// destroyed, it is guaranteed that no more method calls are dispatched to the +// implementation and the connection error handler (if registered) won't be +// called. // // Example: // @@ -55,17 +57,17 @@ class MessageReceiver; // }; // // This class is thread hostile while bound to a message pipe. All calls to this -// class must be from the thread that bound it. The interface implementation's -// methods will be called from the thread that bound this. If a Binding is not -// bound to a message pipe, it may be bound or destroyed on any thread. +// class must be from the sequence that bound it. The interface implementation's +// methods will be called from the sequence that bound this. If a Binding is not +// bound to a message pipe, it may be bound or destroyed on any sequence. // // When you bind this class to a message pipe, optionally you can specify a // base::SingleThreadTaskRunner. This task runner must belong to the same // thread. It will be used to dispatch incoming method calls and connection // error notification. It is useful when you attach multiple task runners to a -// single thread for the purposes of task scheduling. Please note that incoming -// synchrounous method calls may not be run from this task runner, when they -// reenter outgoing synchrounous calls on the same thread. +// single thread for the purposes of task scheduling. Please note that +// incoming synchrounous method calls may not be run from this task runner, when +// they reenter outgoing synchrounous calls on the same thread. template <typename Interface, typename ImplRefTraits = RawPtrImplRefTraits<Interface>> class Binding { @@ -77,85 +79,26 @@ class Binding { // Does not take ownership of |impl|, which must outlive the binding. explicit Binding(ImplPointerType impl) : internal_state_(std::move(impl)) {} - // Constructs a completed binding of message pipe |handle| to implementation - // |impl|. Does not take ownership of |impl|, which must outlive the binding. - Binding(ImplPointerType impl, - ScopedMessagePipeHandle handle, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) - : Binding(std::move(impl)) { - Bind(std::move(handle), std::move(runner)); - } - - // Constructs a completed binding of |impl| to a new message pipe, passing the - // client end to |ptr|, which takes ownership of it. The caller is expected to - // pass |ptr| on to the client of the service. Does not take ownership of any - // of the parameters. |impl| must outlive the binding. |ptr| only needs to - // last until the constructor returns. - Binding(ImplPointerType impl, - InterfacePtr<Interface>* ptr, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) - : Binding(std::move(impl)) { - Bind(ptr, std::move(runner)); - } - // Constructs a completed binding of |impl| to the message pipe endpoint in // |request|, taking ownership of the endpoint. Does not take ownership of // |impl|, which must outlive the binding. Binding(ImplPointerType impl, InterfaceRequest<Interface> request, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) + scoped_refptr<base::SingleThreadTaskRunner> runner = nullptr) : Binding(std::move(impl)) { - Bind(request.PassMessagePipe(), std::move(runner)); + Bind(std::move(request), std::move(runner)); } // Tears down the binding, closing the message pipe and leaving the interface // implementation unbound. ~Binding() {} - // Returns an InterfacePtr bound to one end of a pipe whose other end is - // bound to |this|. - InterfacePtr<Interface> CreateInterfacePtrAndBind( - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { - InterfacePtr<Interface> interface_ptr; - Bind(&interface_ptr, std::move(runner)); - return interface_ptr; - } - - // Completes a binding that was constructed with only an interface - // implementation. Takes ownership of |handle| and binds it to the previously - // specified implementation. - void Bind(ScopedMessagePipeHandle handle, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { - internal_state_.Bind(std::move(handle), std::move(runner)); - } - - // Completes a binding that was constructed with only an interface - // implementation by creating a new message pipe, binding one end of it to the - // previously specified implementation, and passing the other to |ptr|, which - // takes ownership of it. The caller is expected to pass |ptr| on to the - // eventual client of the service. Does not take ownership of |ptr|. - void Bind(InterfacePtr<Interface>* ptr, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { - MessagePipe pipe; - ptr->Bind(InterfacePtrInfo<Interface>(std::move(pipe.handle0), - Interface::Version_), - runner); - Bind(std::move(pipe.handle1), std::move(runner)); - } - // Completes a binding that was constructed with only an interface // implementation by removing the message pipe endpoint from |request| and // binding it to the previously specified implementation. void Bind(InterfaceRequest<Interface> request, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { - Bind(request.PassMessagePipe(), std::move(runner)); + scoped_refptr<base::SingleThreadTaskRunner> runner = nullptr) { + internal_state_.Bind(request.PassMessagePipe(), std::move(runner)); } // Adds a message filter to be notified of each incoming message before @@ -187,7 +130,7 @@ class Binding { internal_state_.ResumeIncomingMethodCallProcessing(); } - // Blocks the calling thread until either a call arrives on the previously + // Blocks the calling sequence until either a call arrives on the previously // bound message pipe, the deadline is exceeded, or an error occurs. Returns // true if a method was successfully read and dispatched. // @@ -209,7 +152,7 @@ class Binding { } // Unbinds the underlying pipe from this binding and returns it so it can be - // used in another context, such as on another thread or with a different + // used in another context, such as on another sequence or with a different // implementation. Put this object into a state where it can be rebound to a // new pipe. // @@ -231,15 +174,16 @@ class Binding { // This method may only be called after this Binding has been bound to a // message pipe. The error handler will be reset when this Binding is unbound // or closed. - void set_connection_error_handler(const base::Closure& error_handler) { + void set_connection_error_handler(base::OnceClosure error_handler) { DCHECK(is_bound()); - internal_state_.set_connection_error_handler(error_handler); + internal_state_.set_connection_error_handler(std::move(error_handler)); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { DCHECK(is_bound()); - internal_state_.set_connection_error_with_reason_handler(error_handler); + internal_state_.set_connection_error_with_reason_handler( + std::move(error_handler)); } // Returns the interface implementation that was previously specified. Caller @@ -250,12 +194,33 @@ class Binding { // pipe has been bound to the implementation). bool is_bound() const { return internal_state_.is_bound(); } + explicit operator bool() const { return internal_state_.is_bound(); } + // Returns the value of the handle currently bound to this Binding which can // be used to make explicit Wait/WaitMany calls. Requires that the Binding be // bound. Ownership of the handle is retained by the Binding, it is not // transferred to the caller. MessagePipeHandle handle() const { return internal_state_.handle(); } + // Reports the currently dispatching Message as bad and closes this binding. + // Note that this is only legal to call from directly within the stack frame + // of a message dispatch. If you need to do asynchronous work before you can + // determine the legitimacy of a message, use GetBadMessageCallback() and + // retain its result until you're ready to invoke or discard it. + void ReportBadMessage(const std::string& error) { + GetBadMessageCallback().Run(error); + } + + // Acquires a callback which may be run to report the currently dispatching + // Message as bad and close this binding. Note that this is only legal to call + // from directly within the stack frame of a message dispatch, but the + // returned callback may be called exactly once any time thereafter to report + // the message as bad. This may only be called once per message. The returned + // callback must be called on the Binding's own sequence. + ReportBadMessageCallback GetBadMessageCallback() { + return internal_state_.GetBadMessageCallback(); + } + // Sends a no-op message on the underlying message pipe and runs the current // message loop until its response is received. This can be used in tests to // verify that no message was sent on a message pipe in response to some @@ -265,6 +230,20 @@ class Binding { // Exposed for testing, should not generally be used. void EnableTestingMode() { internal_state_.EnableTestingMode(); } + scoped_refptr<internal::MultiplexRouter> RouterForTesting() { + return internal_state_.RouterForTesting(); + } + + // Allows test code to swap the interface implementation. + ImplPointerType SwapImplForTesting(ImplPointerType new_impl) { + return internal_state_.SwapImplForTesting(new_impl); + } + + // DO NOT USE. Exposed only for internal use and for testing. + internal::BindingState<Interface, ImplRefTraits>* internal_state() { + return &internal_state_; + } + private: internal::BindingState<Interface, ImplRefTraits> internal_state_; diff --git a/mojo/public/cpp/bindings/binding_set.h b/mojo/public/cpp/bindings/binding_set.h index 919f9c09ad..414583bbd7 100644 --- a/mojo/public/cpp/bindings/binding_set.h +++ b/mojo/public/cpp/bindings/binding_set.h @@ -71,16 +71,16 @@ class BindingSetBase { using RequestType = typename Traits::RequestType; using ImplPointerType = typename Traits::ImplPointerType; - BindingSetBase() {} + BindingSetBase() : weak_ptr_factory_(this) {} - void set_connection_error_handler(const base::Closure& error_handler) { - error_handler_ = error_handler; + void set_connection_error_handler(base::RepeatingClosure error_handler) { + error_handler_ = std::move(error_handler); error_with_reason_handler_.Reset(); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { - error_with_reason_handler_ = error_handler; + RepeatingConnectionErrorWithReasonCallback error_handler) { + error_with_reason_handler_ = std::move(error_handler); error_handler_.Reset(); } @@ -123,22 +123,24 @@ class BindingSetBase { return true; } - // Returns a proxy bound to one end of a pipe whose other end is bound to - // |this|. If |id_storage| is not null, |*id_storage| will be set to the ID - // of the added binding. - ProxyType CreateInterfacePtrAndBind(ImplPointerType impl, - BindingId* id_storage = nullptr) { - ProxyType proxy; - BindingId id = AddBinding(std::move(impl), Traits::MakeRequest(&proxy)); - if (id_storage) - *id_storage = id; - return proxy; + // Swaps the interface implementation with a different one, to allow tests + // to modify behavior. + // + // Returns the existing interface implementation to the caller. + ImplPointerType SwapImplForTesting(BindingId id, ImplPointerType new_impl) { + auto it = bindings_.find(id); + if (it == bindings_.end()) + return nullptr; + + return it->second->SwapImplForTesting(new_impl); } void CloseAllBindings() { bindings_.clear(); } bool empty() const { return bindings_.empty(); } + size_t size() const { return bindings_.size(); } + // Implementations may call this when processing a dispatched message or // error. During the extent of message or error dispatch, this will return the // context associated with the specific binding which received the message or @@ -150,9 +152,59 @@ class BindingSetBase { return *dispatch_context_; } + // Implementations may call this when processing a dispatched message or + // error. During the extent of message or error dispatch, this will return the + // BindingId of the specific binding which received the message or error. + BindingId dispatch_binding() const { + DCHECK(dispatch_context_); + return dispatch_binding_; + } + + // Reports the currently dispatching Message as bad and closes the binding the + // message was received from. Note that this is only legal to call from + // directly within the stack frame of a message dispatch. If you need to do + // asynchronous work before you can determine the legitimacy of a message, use + // GetBadMessageCallback() and retain its result until you're ready to invoke + // or discard it. + void ReportBadMessage(const std::string& error) { + GetBadMessageCallback().Run(error); + } + + // Acquires a callback which may be run to report the currently dispatching + // Message as bad and close the binding the message was received from. Note + // that this is only legal to call from directly within the stack frame of a + // message dispatch, but the returned callback may be called exactly once any + // time thereafter as long as the binding set itself hasn't been destroyed yet + // to report the message as bad. This may only be called once per message. + // The returned callback must be called on the BindingSet's own sequence. + ReportBadMessageCallback GetBadMessageCallback() { + DCHECK(dispatch_context_); + return base::BindOnce( + [](ReportBadMessageCallback error_callback, + base::WeakPtr<BindingSetBase> binding_set, BindingId binding_id, + const std::string& error) { + std::move(error_callback).Run(error); + if (binding_set) + binding_set->RemoveBinding(binding_id); + }, + mojo::GetBadMessageCallback(), weak_ptr_factory_.GetWeakPtr(), + dispatch_binding()); + } + void FlushForTesting() { + DCHECK(!is_flushing_); + is_flushing_ = true; for (auto& binding : bindings_) - binding.second->FlushForTesting(); + if (binding.second) + binding.second->FlushForTesting(); + is_flushing_ = false; + // Clean up any bindings that were destroyed. + for (auto it = bindings_.begin(); it != bindings_.end();) { + if (!it->second) + it = bindings_.erase(it); + else + ++it; + } } private: @@ -169,14 +221,17 @@ class BindingSetBase { binding_set_(binding_set), binding_id_(binding_id), context_(std::move(context)) { - if (ContextTraits::SupportsContext()) - binding_.AddFilter(base::MakeUnique<DispatchFilter>(this)); + binding_.AddFilter(std::make_unique<DispatchFilter>(this)); binding_.set_connection_error_with_reason_handler( - base::Bind(&Entry::OnConnectionError, base::Unretained(this))); + base::BindOnce(&Entry::OnConnectionError, base::Unretained(this))); } void FlushForTesting() { binding_.FlushForTesting(); } + ImplPointerType SwapImplForTesting(ImplPointerType new_impl) { + return binding_.SwapImplForTesting(new_impl); + } + private: class DispatchFilter : public MessageReceiver { public: @@ -196,14 +251,12 @@ class BindingSetBase { }; void WillDispatch() { - DCHECK(ContextTraits::SupportsContext()); - binding_set_->SetDispatchContext(&context_); + binding_set_->SetDispatchContext(&context_, binding_id_); } void OnConnectionError(uint32_t custom_reason, const std::string& description) { - if (ContextTraits::SupportsContext()) - WillDispatch(); + WillDispatch(); binding_set_->OnConnectionError(binding_id_, custom_reason, description); } @@ -215,9 +268,9 @@ class BindingSetBase { DISALLOW_COPY_AND_ASSIGN(Entry); }; - void SetDispatchContext(const Context* context) { - DCHECK(ContextTraits::SupportsContext()); + void SetDispatchContext(const Context* context, BindingId binding_id) { dispatch_context_ = context; + dispatch_binding_ = binding_id; if (!pre_dispatch_handler_.is_null()) pre_dispatch_handler_.Run(*context); } @@ -227,7 +280,7 @@ class BindingSetBase { Context context) { BindingId id = next_binding_id_++; DCHECK_GE(next_binding_id_, 0u); - auto entry = base::MakeUnique<Entry>(std::move(impl), std::move(request), + auto entry = std::make_unique<Entry>(std::move(impl), std::move(request), this, id, std::move(context)); bindings_.insert(std::make_pair(id, std::move(entry))); return id; @@ -241,20 +294,25 @@ class BindingSetBase { // We keep the Entry alive throughout error dispatch. std::unique_ptr<Entry> entry = std::move(it->second); - bindings_.erase(it); + if (!is_flushing_) + bindings_.erase(it); - if (!error_handler_.is_null()) + if (error_handler_) { error_handler_.Run(); - else if (!error_with_reason_handler_.is_null()) + } else if (error_with_reason_handler_) { error_with_reason_handler_.Run(custom_reason, description); + } } - base::Closure error_handler_; - ConnectionErrorWithReasonCallback error_with_reason_handler_; + base::RepeatingClosure error_handler_; + RepeatingConnectionErrorWithReasonCallback error_with_reason_handler_; PreDispatchCallback pre_dispatch_handler_; BindingId next_binding_id_ = 0; std::map<BindingId, std::unique_ptr<Entry>> bindings_; + bool is_flushing_ = false; const Context* dispatch_context_ = nullptr; + BindingId dispatch_binding_; + base::WeakPtrFactory<BindingSetBase> weak_ptr_factory_; DISALLOW_COPY_AND_ASSIGN(BindingSetBase); }; diff --git a/mojo/public/cpp/bindings/callback_helpers.h b/mojo/public/cpp/bindings/callback_helpers.h new file mode 100644 index 0000000000..be4d97bb18 --- /dev/null +++ b/mojo/public/cpp/bindings/callback_helpers.h @@ -0,0 +1,124 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_CALLBACK_HELPERS_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_CALLBACK_HELPERS_H_ + +#include <memory> +#include <utility> + +#include "base/bind.h" +#include "base/callback.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" + +// This is a helper utility to wrap a base::OnceCallback such that if the +// callback is destructed before it has a chance to run (e.g. the callback is +// bound into a task and the task is dropped), it will be run the with +// default arguments passed into WrapCallbackWithDefaultInvokeIfNotRun. +// Alternatively, it will run the delete closure passed to +// WrapCallbackWithDropHandler. +// +// These helpers are intended for use on the client side of a mojo interface, +// where users want to know if their individual callback was dropped (e.g. +// due to connection error). This can save the burden of tracking pending +// mojo callbacks in a map so they can be cleaned up in the interface's +// connection error callback. +// +// Caveats: +// 1) The default form of the callback, called when the original was dropped +// before running, may not run on the thread you expected. If this is a problem +// for your code, DO NOT USE these helpers. +// 2) There is no type information that indicates the wrapped object has special +// destructor behavior. It is therefore not recommended to pass these wrapped +// callbacks into deep call graphs where code readers could be confused whether +// or not the Run() mehtod should be invoked. +// +// Example: +// foo->DoWorkAndReturnResult( +// WrapCallbackWithDefaultInvokeIfNotRun( +// base::BindOnce(&Foo::OnResult, this), false)); +// +// If the callback is destructed without running, it'll be run with "false". +// +// foo->DoWorkAndReturnResult( +// WrapCallbackWithDropHandler(base::BindOnce(&Foo::OnResult, this), +// base::BindOnce(&Foo::LogError, this, WAS_DROPPED))); + +namespace mojo { +namespace internal { + +// First, tell the compiler CallbackWithDeleteHelper is a class template with +// one type parameter. Then define specializations where the type is a function +// returning void and taking zero or more arguments. +template <typename Signature> +class CallbackWithDeleteHelper; + +// Only support callbacks that return void because otherwise it is odd to call +// the callback in the destructor and drop the return value immediately. +template <typename... Args> +class CallbackWithDeleteHelper<void(Args...)> { + public: + using CallbackType = base::OnceCallback<void(Args...)>; + + // Bound arguments may be different to the callback signature when wrappers + // are used, e.g. in base::Owned and base::Unretained case, they are + // OwnedWrapper and UnretainedWrapper. Use BoundArgs to help handle this. + template <typename... BoundArgs> + explicit CallbackWithDeleteHelper(CallbackType callback, BoundArgs&&... args) + : callback_(std::move(callback)) { + delete_callback_ = + base::BindOnce(&CallbackWithDeleteHelper::Run, base::Unretained(this), + std::forward<BoundArgs>(args)...); + } + + // The first int param acts to disambiguate this constructor from the template + // constructor above. The precendent is C++'s own operator++(int) vs + // operator++() to distinguish post-increment and pre-increment. + CallbackWithDeleteHelper(int ignored, + CallbackType callback, + base::OnceClosure delete_callback) + : callback_(std::move(callback)), + delete_callback_(std::move(delete_callback)) {} + + ~CallbackWithDeleteHelper() { + if (delete_callback_) + std::move(delete_callback_).Run(); + } + + void Run(Args... args) { + delete_callback_.Reset(); + std::move(callback_).Run(std::forward<Args>(args)...); + } + + private: + CallbackType callback_; + base::OnceClosure delete_callback_; + + DISALLOW_COPY_AND_ASSIGN(CallbackWithDeleteHelper); +}; + +} // namespace internal + +template <typename T, typename... Args> +inline base::OnceCallback<T> WrapCallbackWithDropHandler( + base::OnceCallback<T> cb, + base::OnceClosure delete_cb) { + return base::BindOnce(&internal::CallbackWithDeleteHelper<T>::Run, + std::make_unique<internal::CallbackWithDeleteHelper<T>>( + 0, std::move(cb), std::move(delete_cb))); +} + +template <typename T, typename... Args> +inline base::OnceCallback<T> WrapCallbackWithDefaultInvokeIfNotRun( + base::OnceCallback<T> cb, + Args&&... args) { + return base::BindOnce(&internal::CallbackWithDeleteHelper<T>::Run, + std::make_unique<internal::CallbackWithDeleteHelper<T>>( + std::move(cb), std::forward<Args>(args)...)); +} + +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_CALLBACK_HELPERS_H_ diff --git a/mojo/public/cpp/bindings/clone_traits.h b/mojo/public/cpp/bindings/clone_traits.h index 203ab34189..e7e0a14cf0 100644 --- a/mojo/public/cpp/bindings/clone_traits.h +++ b/mojo/public/cpp/bindings/clone_traits.h @@ -6,9 +6,9 @@ #define MOJO_PUBLIC_CPP_BINDINGS_CLONE_TRAITS_H_ #include <type_traits> -#include <unordered_map> #include <vector> +#include "base/containers/flat_map.h" #include "base/optional.h" #include "mojo/public/cpp/bindings/lib/template_util.h" @@ -65,9 +65,9 @@ struct CloneTraits<std::vector<T>, false> { }; template <typename K, typename V> -struct CloneTraits<std::unordered_map<K, V>, false> { - static std::unordered_map<K, V> Clone(const std::unordered_map<K, V>& input) { - std::unordered_map<K, V> result; +struct CloneTraits<base::flat_map<K, V>, false> { + static base::flat_map<K, V> Clone(const base::flat_map<K, V>& input) { + base::flat_map<K, V> result; for (const auto& element : input) { result.insert(std::make_pair(mojo::Clone(element.first), mojo::Clone(element.second))); diff --git a/mojo/public/cpp/bindings/connection_error_callback.h b/mojo/public/cpp/bindings/connection_error_callback.h index 306e99e45b..0b9759e6d9 100644 --- a/mojo/public/cpp/bindings/connection_error_callback.h +++ b/mojo/public/cpp/bindings/connection_error_callback.h @@ -9,12 +9,15 @@ namespace mojo { -// This callback type accepts user-defined disconnect reason and description. If -// the other side specifies a reason on closing the connection, it will be +// These callback types accept user-defined disconnect reason and description. +// If the other side specifies a reason on closing the connection, it will be // passed to the error handler. using ConnectionErrorWithReasonCallback = - base::Callback<void(uint32_t /* custom_reason */, - const std::string& /* description */)>; + base::OnceCallback<void(uint32_t /* custom_reason */, + const std::string& /* description */)>; +using RepeatingConnectionErrorWithReasonCallback = + base::RepeatingCallback<void(uint32_t /* custom_reason */, + const std::string& /* description */)>; } // namespace mojo diff --git a/mojo/public/cpp/bindings/connector.h b/mojo/public/cpp/bindings/connector.h index cb065c174d..fa71604b1b 100644 --- a/mojo/public/cpp/bindings/connector.h +++ b/mojo/public/cpp/bindings/connector.h @@ -5,19 +5,22 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_CONNECTOR_H_ #define MOJO_PUBLIC_CPP_BINDINGS_CONNECTOR_H_ +#include <atomic> #include <memory> +#include <utility> #include "base/callback.h" #include "base/compiler_specific.h" #include "base/memory/ref_counted.h" #include "base/memory/weak_ptr.h" #include "base/optional.h" -#include "base/single_thread_task_runner.h" -#include "base/threading/thread_checker.h" +#include "base/sequence_checker.h" +#include "base/sequenced_task_runner.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/message.h" #include "mojo/public/cpp/bindings/sync_handle_watcher.h" #include "mojo/public/cpp/system/core.h" +#include "mojo/public/cpp/system/handle_signal_tracker.h" #include "mojo/public/cpp/system/simple_watcher.h" namespace base { @@ -35,28 +38,60 @@ namespace mojo { // - MessagePipe I/O is non-blocking. // - Sending messages can be configured to be thread safe (please see comments // of the constructor). Other than that, the object should only be accessed -// on the creating thread. -class MOJO_CPP_BINDINGS_EXPORT Connector - : NON_EXPORTED_BASE(public MessageReceiver) { +// on the creating sequence. +class MOJO_CPP_BINDINGS_EXPORT Connector : public MessageReceiver { public: enum ConnectorConfig { - // Connector::Accept() is only called from a single thread. + // Connector::Accept() is only called from a single sequence. SINGLE_THREADED_SEND, - // Connector::Accept() is allowed to be called from multiple threads. + // Connector::Accept() is allowed to be called from multiple sequences. MULTI_THREADED_SEND }; + // Determines how this Connector should behave with respect to serialization + // of outgoing messages. + enum class OutgoingSerializationMode { + // Lazy serialization. The Connector prefers to transmit serialized messages + // only when it knows its peer endpoint is remote. This ensures outgoing + // requests are unserialized by default (when possible, i.e. when generated + // bindings support it) and serialized only if and when necessary. + kLazy, + + // Eager serialization. The Connector always prefers serialized messages, + // ensuring that interface calls will be serialized immediately before + // sending on the Connector. + kEager, + }; + + // Determines how this Connector should behave with respect to serialization + // of incoming messages. + enum class IncomingSerializationMode { + // Accepts and dispatches either serialized or unserialized messages. This + // is the only mode that should be used in production. + kDispatchAsIs, + + // Accepts either serialized or unserialized messages, but always forces + // serialization (if applicable) before dispatch. Should be used only in + // test environments to coerce the lazy serialization of a message after + // transmission. + kSerializeBeforeDispatchForTesting, + }; + // The Connector takes ownership of |message_pipe|. Connector(ScopedMessagePipeHandle message_pipe, ConnectorConfig config, - scoped_refptr<base::SingleThreadTaskRunner> runner); + scoped_refptr<base::SequencedTaskRunner> runner); ~Connector() override; + // Sets outgoing serialization mode. + void SetOutgoingSerializationMode(OutgoingSerializationMode mode); + void SetIncomingSerializationMode(IncomingSerializationMode mode); + // Sets the receiver to handle messages read from the message pipe. The // Connector will read messages from the pipe regardless of whether or not an // incoming receiver has been set. void set_incoming_receiver(MessageReceiver* receiver) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); incoming_receiver_ = receiver; } @@ -64,21 +99,21 @@ class MOJO_CPP_BINDINGS_EXPORT Connector // state, where no more messages will be processed. This method is used // during testing to prevent that from happening. void set_enforce_errors_from_incoming_receiver(bool enforce) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); enforce_errors_from_incoming_receiver_ = enforce; } // Sets the error handler to receive notifications when an error is // encountered while reading from the pipe or waiting to read from the pipe. - void set_connection_error_handler(const base::Closure& error_handler) { - DCHECK(thread_checker_.CalledOnValidThread()); - connection_error_handler_ = error_handler; + void set_connection_error_handler(base::OnceClosure error_handler) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + connection_error_handler_ = std::move(error_handler); } // Returns true if an error was encountered while reading from the pipe or // waiting to read from the pipe. bool encountered_error() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return error_; } @@ -106,7 +141,7 @@ class MOJO_CPP_BINDINGS_EXPORT Connector // Is the connector bound to a MessagePipe handle? bool is_valid() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return message_pipe_.is_valid(); } @@ -120,15 +155,16 @@ class MOJO_CPP_BINDINGS_EXPORT Connector void ResumeIncomingMethodCallProcessing(); // MessageReceiver implementation: + bool PrefersSerializedMessages() override; bool Accept(Message* message) override; MessagePipeHandle handle() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return message_pipe_.get(); } // Allows |message_pipe_| to be watched while others perform sync handle - // watching on the same thread. Please see comments of + // watching on the same sequence. Please see comments of // SyncHandleWatcher::AllowWokenUpBySyncWatchOnSameThread(). void AllowWokenUpBySyncWatchOnSameThread(); @@ -147,17 +183,22 @@ class MOJO_CPP_BINDINGS_EXPORT Connector return sync_handle_watcher_callback_count_ > 0; } - base::SingleThreadTaskRunner* task_runner() const { - return task_runner_.get(); - } + base::SequencedTaskRunner* task_runner() const { return task_runner_.get(); } // Sets the tag used by the heap profiler. // |tag| must be a const string literal. void SetWatcherHeapProfilerTag(const char* tag); + // Allows testing environments to override the default serialization behavior + // of newly constructed Connector instances. Must be called before any + // Connector instances are constructed. + static void OverrideDefaultSerializationBehaviorForTesting( + OutgoingSerializationMode outgoing_mode, + IncomingSerializationMode incoming_mode); + private: class ActiveDispatchTracker; - class MessageLoopNestingObserver; + class RunLoopNestingObserver; // Callback of mojo::SimpleWatcher. void OnWatcherHandleReady(MojoResult result); @@ -185,21 +226,25 @@ class MOJO_CPP_BINDINGS_EXPORT Connector void EnsureSyncWatcherExists(); - base::Closure connection_error_handler_; + base::OnceClosure connection_error_handler_; ScopedMessagePipeHandle message_pipe_; MessageReceiver* incoming_receiver_ = nullptr; - scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; std::unique_ptr<SimpleWatcher> handle_watcher_; + base::Optional<HandleSignalTracker> peer_remoteness_tracker_; - bool error_ = false; + std::atomic<bool> error_; bool drop_writes_ = false; bool enforce_errors_from_incoming_receiver_ = true; bool paused_ = false; - // If sending messages is allowed from multiple threads, |lock_| is used to + OutgoingSerializationMode outgoing_serialization_mode_; + IncomingSerializationMode incoming_serialization_mode_; + + // If sending messages is allowed from multiple sequences, |lock_| is used to // protect modifications to |message_pipe_| and |drop_writes_|. base::Optional<base::Lock> lock_; @@ -209,23 +254,27 @@ class MOJO_CPP_BINDINGS_EXPORT Connector // callback. size_t sync_handle_watcher_callback_count_ = 0; - base::ThreadChecker thread_checker_; + SEQUENCE_CHECKER(sequence_checker_); base::Lock connected_lock_; bool connected_ = true; // The tag used to track heap allocations that originated from a Watcher // notification. - const char* heap_profiler_tag_ = nullptr; + const char* heap_profiler_tag_ = "unknown interface"; - // A cached pointer to the MessageLoopNestingObserver for the MessageLoop on - // which this Connector was created. - MessageLoopNestingObserver* const nesting_observer_; + // A cached pointer to the RunLoopNestingObserver for the thread on which this + // Connector was created. + RunLoopNestingObserver* const nesting_observer_; // |true| iff the Connector is currently dispatching a message. Used to detect // nested dispatch operations. bool is_dispatching_ = false; +#if defined(ENABLE_IPC_FUZZER) + std::unique_ptr<MessageReceiver> message_dumper_; +#endif + // Create a single weak ptr and use it everywhere, to avoid the malloc/free // cost of creating a new weak ptr whenever it is needed. // NOTE: This weak pointer is invalidated when the message pipe is closed or diff --git a/mojo/public/cpp/bindings/enum_traits.h b/mojo/public/cpp/bindings/enum_traits.h index 2c528f3226..f4ba5a241b 100644 --- a/mojo/public/cpp/bindings/enum_traits.h +++ b/mojo/public/cpp/bindings/enum_traits.h @@ -5,6 +5,8 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_ENUM_TRAITS_H_ #define MOJO_PUBLIC_CPP_BINDINGS_ENUM_TRAITS_H_ +#include "mojo/public/cpp/bindings/lib/template_util.h" + namespace mojo { // This must be specialized for any type |T| to be serialized/deserialized as a @@ -20,7 +22,11 @@ namespace mojo { // }; // template <typename MojomType, typename T> -struct EnumTraits; +struct EnumTraits { + static_assert(internal::AlwaysFalse<T>::value, + "Cannot find the mojo::EnumTraits specialization. Did you " + "forget to include the corresponding header file?"); +}; } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/equals_traits.h b/mojo/public/cpp/bindings/equals_traits.h index 53c7dce693..d0bf7c1f3c 100644 --- a/mojo/public/cpp/bindings/lib/equals_traits.h +++ b/mojo/public/cpp/bindings/equals_traits.h @@ -2,18 +2,21 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_EQUALS_TRAITS_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_LIB_EQUALS_TRAITS_H_ +#ifndef MOJO_PUBLIC_CPP_BINDINGS_EQUALS_TRAITS_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_EQUALS_TRAITS_H_ #include <type_traits> -#include <unordered_map> #include <vector> +#include "base/containers/flat_map.h" #include "base/optional.h" #include "mojo/public/cpp/bindings/lib/template_util.h" namespace mojo { -namespace internal { + +// EqualsTraits<> allows you to specify comparison functions for mapped mojo +// objects. By default objects can be compared if they implement operator==() +// or have a method named Equals(). template <typename T> struct HasEqualsMethod { @@ -24,7 +27,7 @@ struct HasEqualsMethod { static const bool value = sizeof(Test<T>(0)) == sizeof(char); private: - EnsureTypeIsComplete<T> check_t_; + internal::EnsureTypeIsComplete<T> check_t_; }; template <typename T, bool has_equals_method = HasEqualsMethod<T>::value> @@ -51,7 +54,9 @@ struct EqualsTraits<base::Optional<T>, false> { if (!a || !b) return false; - return internal::Equals(*a, *b); + // NOTE: Not just Equals() because that's EqualsTraits<>::Equals() and we + // want mojo::Equals() for things like base::Optional<std::vector<T>>. + return mojo::Equals(*a, *b); } }; @@ -61,7 +66,7 @@ struct EqualsTraits<std::vector<T>, false> { if (a.size() != b.size()) return false; for (size_t i = 0; i < a.size(); ++i) { - if (!internal::Equals(a[i], b[i])) + if (!mojo::Equals(a[i], b[i])) return false; } return true; @@ -69,14 +74,14 @@ struct EqualsTraits<std::vector<T>, false> { }; template <typename K, typename V> -struct EqualsTraits<std::unordered_map<K, V>, false> { - static bool Equals(const std::unordered_map<K, V>& a, - const std::unordered_map<K, V>& b) { +struct EqualsTraits<base::flat_map<K, V>, false> { + static bool Equals(const base::flat_map<K, V>& a, + const base::flat_map<K, V>& b) { if (a.size() != b.size()) return false; for (const auto& element : a) { auto iter = b.find(element.first); - if (iter == b.end() || !internal::Equals(element.second, iter->second)) + if (iter == b.end() || !mojo::Equals(element.second, iter->second)) return false; } return true; @@ -88,7 +93,6 @@ bool Equals(const T& a, const T& b) { return EqualsTraits<T>::Equals(a, b); } -} // namespace internal } // namespace mojo -#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_EQUALS_TRAITS_H_ +#endif // MOJO_PUBLIC_CPP_BINDINGS_EQUALS_TRAITS_H_ diff --git a/mojo/public/cpp/bindings/filter_chain.h b/mojo/public/cpp/bindings/filter_chain.h index 1262f39b80..9d3f2c08c7 100644 --- a/mojo/public/cpp/bindings/filter_chain.h +++ b/mojo/public/cpp/bindings/filter_chain.h @@ -16,8 +16,7 @@ namespace mojo { -class MOJO_CPP_BINDINGS_EXPORT FilterChain - : NON_EXPORTED_BASE(public MessageReceiver) { +class MOJO_CPP_BINDINGS_EXPORT FilterChain : public MessageReceiver { public: // Doesn't take ownership of |sink|. Therefore |sink| has to stay alive while // this object is alive. @@ -49,7 +48,7 @@ class MOJO_CPP_BINDINGS_EXPORT FilterChain template <typename FilterType, typename... Args> inline void FilterChain::Append(Args&&... args) { - Append(base::MakeUnique<FilterType>(std::forward<Args>(args)...)); + Append(std::make_unique<FilterType>(std::forward<Args>(args)...)); } template <> diff --git a/mojo/public/cpp/bindings/interface_endpoint_client.h b/mojo/public/cpp/bindings/interface_endpoint_client.h index b519fe92bb..6842c7c322 100644 --- a/mojo/public/cpp/bindings/interface_endpoint_client.h +++ b/mojo/public/cpp/bindings/interface_endpoint_client.h @@ -9,6 +9,7 @@ #include <map> #include <memory> +#include <utility> #include "base/callback.h" #include "base/compiler_specific.h" @@ -17,8 +18,8 @@ #include "base/memory/ref_counted.h" #include "base/memory/weak_ptr.h" #include "base/optional.h" -#include "base/single_thread_task_runner.h" -#include "base/threading/thread_checker.h" +#include "base/sequence_checker.h" +#include "base/sequenced_task_runner.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/connection_error_callback.h" #include "mojo/public/cpp/bindings/disconnect_reason.h" @@ -35,9 +36,9 @@ class InterfaceEndpointController; // InterfaceEndpointClient handles message sending and receiving of an interface // endpoint, either the implementation side or the client side. -// It should only be accessed and destructed on the creating thread. +// It should only be accessed and destructed on the creating sequence. class MOJO_CPP_BINDINGS_EXPORT InterfaceEndpointClient - : NON_EXPORTED_BASE(public MessageReceiverWithResponder) { + : public MessageReceiverWithResponder { public: // |receiver| is okay to be null. If it is not null, it must outlive this // object. @@ -45,34 +46,34 @@ class MOJO_CPP_BINDINGS_EXPORT InterfaceEndpointClient MessageReceiverWithResponderStatus* receiver, std::unique_ptr<MessageReceiver> payload_validator, bool expect_sync_requests, - scoped_refptr<base::SingleThreadTaskRunner> runner, + scoped_refptr<base::SequencedTaskRunner> runner, uint32_t interface_version); ~InterfaceEndpointClient() override; // Sets the error handler to receive notifications when an error is // encountered. - void set_connection_error_handler(const base::Closure& error_handler) { - DCHECK(thread_checker_.CalledOnValidThread()); - error_handler_ = error_handler; + void set_connection_error_handler(base::OnceClosure error_handler) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + error_handler_ = std::move(error_handler); error_with_reason_handler_.Reset(); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { - DCHECK(thread_checker_.CalledOnValidThread()); - error_with_reason_handler_ = error_handler; + ConnectionErrorWithReasonCallback error_handler) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + error_with_reason_handler_ = std::move(error_handler); error_handler_.Reset(); } // Returns true if an error was encountered. bool encountered_error() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return encountered_error_; } // Returns true if this endpoint has any pending callbacks. bool has_pending_responders() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return !async_responders_.empty() || !sync_responses_.empty(); } @@ -94,6 +95,7 @@ class MOJO_CPP_BINDINGS_EXPORT InterfaceEndpointClient // MessageReceiverWithResponder implementation: // They must only be called when the handle is not in pending association // state. + bool PrefersSerializedMessages() override; bool Accept(Message* message) override; bool AcceptWithResponder(Message* message, std::unique_ptr<MessageReceiver> responder) override; @@ -172,16 +174,16 @@ class MOJO_CPP_BINDINGS_EXPORT InterfaceEndpointClient uint64_t next_request_id_ = 1; - base::Closure error_handler_; + base::OnceClosure error_handler_; ConnectionErrorWithReasonCallback error_with_reason_handler_; bool encountered_error_ = false; - scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; internal::ControlMessageProxy control_message_proxy_; internal::ControlMessageHandler control_message_handler_; - base::ThreadChecker thread_checker_; + SEQUENCE_CHECKER(sequence_checker_); base::WeakPtrFactory<InterfaceEndpointClient> weak_ptr_factory_; diff --git a/mojo/public/cpp/bindings/interface_endpoint_controller.h b/mojo/public/cpp/bindings/interface_endpoint_controller.h index 8d99d4a45f..dabc3d8800 100644 --- a/mojo/public/cpp/bindings/interface_endpoint_controller.h +++ b/mojo/public/cpp/bindings/interface_endpoint_controller.h @@ -18,8 +18,8 @@ class InterfaceEndpointController { virtual bool SendMessage(Message* message) = 0; // Allows the interface endpoint to watch for incoming sync messages while - // others perform sync handle watching on the same thread. Please see comments - // of SyncHandleWatcher::AllowWokenUpBySyncWatchOnSameThread(). + // others perform sync handle watching on the same sequence. Please see + // comments of SyncHandleWatcher::AllowWokenUpBySyncWatchOnSameThread(). virtual void AllowWokenUpBySyncWatchOnSameThread() = 0; // Watches the interface endpoint for incoming sync messages. (It also watches diff --git a/mojo/public/cpp/bindings/interface_id.h b/mojo/public/cpp/bindings/interface_id.h index 53475d6f78..d6128537d2 100644 --- a/mojo/public/cpp/bindings/interface_id.h +++ b/mojo/public/cpp/bindings/interface_id.h @@ -30,6 +30,10 @@ inline bool IsValidInterfaceId(InterfaceId id) { return id != kInvalidInterfaceId; } +inline bool HasInterfaceIdNamespaceBitSet(InterfaceId id) { + return (id & kInterfaceIdNamespaceMask) != 0; +} + } // namespace mojo #endif // MOJO_PUBLIC_CPP_BINDINGS_INTERFACE_ID_H_ diff --git a/mojo/public/cpp/bindings/interface_ptr.h b/mojo/public/cpp/bindings/interface_ptr.h index e88be7436f..3ec522d472 100644 --- a/mojo/public/cpp/bindings/interface_ptr.h +++ b/mojo/public/cpp/bindings/interface_ptr.h @@ -14,8 +14,7 @@ #include "base/logging.h" #include "base/macros.h" #include "base/memory/ref_counted.h" -#include "base/single_thread_task_runner.h" -#include "base/threading/thread_task_runner_handle.h" +#include "base/sequenced_task_runner.h" #include "mojo/public/cpp/bindings/connection_error_callback.h" #include "mojo/public/cpp/bindings/interface_ptr_info.h" #include "mojo/public/cpp/bindings/lib/interface_ptr_state.h" @@ -25,34 +24,39 @@ namespace mojo { // A pointer to a local proxy of a remote Interface implementation. Uses a // message pipe to communicate with the remote implementation, and automatically // closes the pipe and deletes the proxy on destruction. The pointer must be -// bound to a message pipe before the interface methods can be called. +// bound to a message pipe before the interface methods can be called. Once a +// pointer is destroyed, it is guaranteed that pending callbacks as well as the +// connection error handler (if registered) won't be called. // // This class is thread hostile, as is the local proxy it manages, while bound // to a message pipe. All calls to this class or the proxy should be from the -// same thread that bound it. If you need to move the proxy to a different -// thread, extract the InterfacePtrInfo (containing just the message pipe and -// any version information) using PassInterface(), pass it to a different -// thread, and create and bind a new InterfacePtr from that thread. If an -// InterfacePtr is not bound to a message pipe, it may be bound or destroyed on -// any thread. +// same sequence that bound it. If you need to move the proxy to a different +// sequence, extract the InterfacePtrInfo (containing just the message pipe and +// any version information) using PassInterface() on the original sequence, pass +// it to a different sequence, and create and bind a new InterfacePtr from that +// sequence. If an InterfacePtr is not bound to a message pipe, it may be bound +// or destroyed on any sequence. template <typename Interface> class InterfacePtr { public: using InterfaceType = Interface; using PtrInfoType = InterfacePtrInfo<Interface>; + using Proxy = typename Interface::Proxy_; // Constructs an unbound InterfacePtr. InterfacePtr() {} InterfacePtr(decltype(nullptr)) {} // Takes over the binding of another InterfacePtr. - InterfacePtr(InterfacePtr&& other) { + InterfacePtr(InterfacePtr&& other) noexcept { internal_state_.Swap(&other.internal_state_); } + explicit InterfacePtr(PtrInfoType&& info) noexcept { Bind(std::move(info)); } + // Takes over the binding of another InterfacePtr, and closes any message pipe // already bound to this pointer. - InterfacePtr& operator=(InterfacePtr&& other) { + InterfacePtr& operator=(InterfacePtr&& other) noexcept { reset(); internal_state_.Swap(&other.internal_state_); return *this; @@ -74,13 +78,12 @@ class InterfacePtr { // has the same effect as reset(). In this case, the InterfacePtr is not // considered as bound. // - // |runner| must belong to the same thread. It will be used to dispatch all - // callbacks and connection error notification. It is useful when you attach - // multiple task runners to a single thread for the purposes of task - // scheduling. + // Optionally, |runner| is a SequencedTaskRunner bound to the current sequence + // on which all callbacks and connection error notifications will be + // dispatched. It is only useful to specify this to use a different + // SequencedTaskRunner than SequencedTaskRunnerHandle::Get(). void Bind(InterfacePtrInfo<Interface> info, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { + scoped_refptr<base::SequencedTaskRunner> runner = nullptr) { reset(); if (info.is_valid()) internal_state_.Bind(std::move(info), std::move(runner)); @@ -91,11 +94,11 @@ class InterfacePtr { // Returns a raw pointer to the local proxy. Caller does not take ownership. // Note that the local proxy is thread hostile, as stated above. - Interface* get() const { return internal_state_.instance(); } + Proxy* get() const { return internal_state_.instance(); } // Functions like a pointer to Interface. Must already be bound. - Interface* operator->() const { return get(); } - Interface& operator*() const { return *get(); } + Proxy* operator->() const { return get(); } + Proxy& operator*() const { return *get(); } // Returns the version number of the interface that the remote side supports. uint32_t version() const { return internal_state_.version(); } @@ -124,8 +127,7 @@ class InterfacePtr { // stimulus. void FlushForTesting() { internal_state_.FlushForTesting(); } - // Closes the bound message pipe (if any) and returns the pointer to the - // unbound state. + // Closes the bound message pipe, if any. void reset() { State doomed; internal_state_.Swap(&doomed); @@ -143,28 +145,32 @@ class InterfacePtr { return internal_state_.HasAssociatedInterfaces(); } + // Returns true if bound and awaiting a response to a message. + bool IsExpectingResponse() { return internal_state_.has_pending_callbacks(); } + // Indicates whether the message pipe has encountered an error. If true, // method calls made on this interface will be dropped (and may already have // been dropped). bool encountered_error() const { return internal_state_.encountered_error(); } // Registers a handler to receive error notifications. The handler will be - // called from the thread that owns this InterfacePtr. + // called from the sequence that owns this InterfacePtr. // // This method may only be called after the InterfacePtr has been bound to a // message pipe. - void set_connection_error_handler(const base::Closure& error_handler) { - internal_state_.set_connection_error_handler(error_handler); + void set_connection_error_handler(base::OnceClosure error_handler) { + internal_state_.set_connection_error_handler(std::move(error_handler)); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { - internal_state_.set_connection_error_with_reason_handler(error_handler); + ConnectionErrorWithReasonCallback error_handler) { + internal_state_.set_connection_error_with_reason_handler( + std::move(error_handler)); } // Unbinds the InterfacePtr and returns the information which could be used // to setup an InterfacePtr again. This method may be used to move the proxy - // to a different thread (see class comments for details). + // to a different sequence (see class comments for details). // // It is an error to call PassInterface() while: // - there are pending responses; or @@ -198,26 +204,10 @@ class InterfacePtr { return &internal_state_; } - // Allow InterfacePtr<> to be used in boolean expressions, but not - // implicitly convertible to a real bool (which is dangerous). - private: - // TODO(dcheng): Use an explicit conversion operator. - typedef internal::InterfacePtrState<Interface> InterfacePtr::*Testable; - - public: - operator Testable() const { - return internal_state_.is_bound() ? &InterfacePtr::internal_state_ - : nullptr; - } + // Allow InterfacePtr<> to be used in boolean expressions. + explicit operator bool() const { return internal_state_.is_bound(); } private: - // Forbid the == and != operators explicitly, otherwise InterfacePtr will be - // converted to Testable to do == or != comparison. - template <typename T> - bool operator==(const InterfacePtr<T>& other) const = delete; - template <typename T> - bool operator!=(const InterfacePtr<T>& other) const = delete; - typedef internal::InterfacePtrState<Interface> State; mutable State internal_state_; @@ -229,8 +219,7 @@ class InterfacePtr { template <typename Interface> InterfacePtr<Interface> MakeProxy( InterfacePtrInfo<Interface> info, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { + scoped_refptr<base::SequencedTaskRunner> runner = nullptr) { InterfacePtr<Interface> ptr; if (info.is_valid()) ptr.Bind(std::move(info), std::move(runner)); diff --git a/mojo/public/cpp/bindings/interface_ptr_info.h b/mojo/public/cpp/bindings/interface_ptr_info.h index 0b2d8089c4..7003043417 100644 --- a/mojo/public/cpp/bindings/interface_ptr_info.h +++ b/mojo/public/cpp/bindings/interface_ptr_info.h @@ -5,7 +5,8 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_INTERFACE_PTR_INFO_H_ #define MOJO_PUBLIC_CPP_BINDINGS_INTERFACE_PTR_INFO_H_ -#include <stdint.h> +#include <cstddef> +#include <cstdint> #include <utility> #include "base/macros.h" @@ -19,6 +20,7 @@ template <typename Interface> class InterfacePtrInfo { public: InterfacePtrInfo() : version_(0u) {} + InterfacePtrInfo(std::nullptr_t) : InterfacePtrInfo() {} InterfacePtrInfo(ScopedMessagePipeHandle handle, uint32_t version) : handle_(std::move(handle)), version_(version) {} @@ -51,6 +53,9 @@ class InterfacePtrInfo { uint32_t version() const { return version_; } void set_version(uint32_t version) { version_ = version; } + // Allow InterfacePtrInfo<> to be used in boolean expressions. + explicit operator bool() const { return handle_.is_valid(); } + private: ScopedMessagePipeHandle handle_; uint32_t version_; diff --git a/mojo/public/cpp/bindings/interface_ptr_set.h b/mojo/public/cpp/bindings/interface_ptr_set.h index 09a268229d..17f90b1e7f 100644 --- a/mojo/public/cpp/bindings/interface_ptr_set.h +++ b/mojo/public/cpp/bindings/interface_ptr_set.h @@ -5,15 +5,19 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_INTERFACE_PTR_SET_H_ #define MOJO_PUBLIC_CPP_BINDINGS_INTERFACE_PTR_SET_H_ +#include <map> #include <utility> -#include <vector> #include "base/macros.h" #include "base/memory/weak_ptr.h" +#include "base/stl_util.h" #include "mojo/public/cpp/bindings/associated_interface_ptr.h" #include "mojo/public/cpp/bindings/interface_ptr.h" namespace mojo { + +using InterfacePtrSetElementId = size_t; + namespace internal { // TODO(blundell): This class should be rewritten to be structured @@ -26,29 +30,61 @@ class PtrSet { PtrSet() {} ~PtrSet() { CloseAll(); } - void AddPtr(Ptr<Interface> ptr) { + InterfacePtrSetElementId AddPtr(Ptr<Interface> ptr) { + InterfacePtrSetElementId id = next_ptr_id_++; auto weak_interface_ptr = new Element(std::move(ptr)); - ptrs_.push_back(weak_interface_ptr->GetWeakPtr()); + ptrs_.emplace(std::piecewise_construct, std::forward_as_tuple(id), + std::forward_as_tuple(weak_interface_ptr->GetWeakPtr())); ClearNullPtrs(); + return id; } template <typename FunctionType> void ForAllPtrs(FunctionType function) { for (const auto& it : ptrs_) { - if (it) - function(it->get()); + if (it.second) + function(it.second->get()); } ClearNullPtrs(); } void CloseAll() { for (const auto& it : ptrs_) { - if (it) - it->Close(); + if (it.second) + it.second->Close(); } ptrs_.clear(); } + bool empty() const { return ptrs_.empty(); } + + // Calls FlushForTesting on all Ptrs sequentially. Since each call is a + // blocking operation, may be very slow as the number of pointers increases. + void FlushForTesting() { + for (const auto& it : ptrs_) { + if (it.second) + it.second->FlushForTesting(); + } + ClearNullPtrs(); + } + + bool HasPtr(InterfacePtrSetElementId id) { + return ptrs_.find(id) != ptrs_.end(); + } + + Ptr<Interface> RemovePtr(InterfacePtrSetElementId id) { + auto it = ptrs_.find(id); + if (it == ptrs_.end()) + return Ptr<Interface>(); + Ptr<Interface> ptr; + if (it->second) { + ptr = it->second->Take(); + delete it->second.get(); + } + ptrs_.erase(it); + return ptr; + } + private: class Element { public: @@ -69,10 +105,14 @@ class PtrSet { Interface* get() { return ptr_.get(); } + Ptr<Interface> Take() { return std::move(ptr_); } + base::WeakPtr<Element> GetWeakPtr() { return weak_ptr_factory_.GetWeakPtr(); } + void FlushForTesting() { ptr_.FlushForTesting(); } + private: static void DeleteElement(Element* element) { delete element; } @@ -83,14 +123,11 @@ class PtrSet { }; void ClearNullPtrs() { - ptrs_.erase(std::remove_if(ptrs_.begin(), ptrs_.end(), - [](const base::WeakPtr<Element>& p) { - return p.get() == nullptr; - }), - ptrs_.end()); + base::EraseIf(ptrs_, [](const auto& pair) { return !(pair.second); }); } - std::vector<base::WeakPtr<Element>> ptrs_; + InterfacePtrSetElementId next_ptr_id_ = 0; + std::map<InterfacePtrSetElementId, base::WeakPtr<Element>> ptrs_; }; } // namespace internal diff --git a/mojo/public/cpp/bindings/interface_request.h b/mojo/public/cpp/bindings/interface_request.h index 29d883615e..ccfdb3716e 100644 --- a/mojo/public/cpp/bindings/interface_request.h +++ b/mojo/public/cpp/bindings/interface_request.h @@ -11,7 +11,6 @@ #include "base/macros.h" #include "base/optional.h" #include "base/single_thread_task_runner.h" -#include "base/threading/thread_task_runner_handle.h" #include "mojo/public/cpp/bindings/disconnect_reason.h" #include "mojo/public/cpp/bindings/interface_ptr.h" #include "mojo/public/cpp/bindings/pipe_control_message_proxy.h" @@ -33,18 +32,8 @@ class InterfaceRequest { InterfaceRequest() {} InterfaceRequest(decltype(nullptr)) {} - // Creates a new message pipe over which Interface is to be served, binding - // the specified InterfacePtr to one end of the message pipe and this - // InterfaceRequest to the other. For example usage, see comments on - // MakeRequest(InterfacePtr*) below. - explicit InterfaceRequest(InterfacePtr<Interface>* ptr, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { - MessagePipe pipe; - ptr->Bind(InterfacePtrInfo<Interface>(std::move(pipe.handle0), 0u), - std::move(runner)); - Bind(std::move(pipe.handle1)); - } + explicit InterfaceRequest(ScopedMessagePipeHandle handle) + : handle_(std::move(handle)) {} // Takes the message pipe from another InterfaceRequest. InterfaceRequest(InterfaceRequest&& other) { @@ -62,14 +51,11 @@ class InterfaceRequest { return *this; } - // Binds the request to a message pipe over which Interface is to be - // requested. If the request is already bound to a message pipe, the current - // message pipe will be closed. - void Bind(ScopedMessagePipeHandle handle) { handle_ = std::move(handle); } - // Indicates whether the request currently contains a valid message pipe. bool is_pending() const { return handle_.is_valid(); } + explicit operator bool() const { return handle_.is_valid(); } + // Removes the message pipe from the request and returns it. ScopedMessagePipeHandle PassMessagePipe() { return std::move(handle_); } @@ -102,16 +88,6 @@ class InterfaceRequest { DISALLOW_COPY_AND_ASSIGN(InterfaceRequest); }; -// Makes an InterfaceRequest bound to the specified message pipe. If |handle| -// is empty or invalid, the resulting InterfaceRequest will represent the -// absence of a request. -template <typename Interface> -InterfaceRequest<Interface> MakeRequest(ScopedMessagePipeHandle handle) { - InterfaceRequest<Interface> request; - request.Bind(std::move(handle)); - return std::move(request); -} - // Creates a new message pipe over which Interface is to be served. Binds the // specified InterfacePtr to one end of the message pipe, and returns an // InterfaceRequest bound to the other. The InterfacePtr should be passed to @@ -158,9 +134,21 @@ InterfaceRequest<Interface> MakeRequest(ScopedMessagePipeHandle handle) { template <typename Interface> InterfaceRequest<Interface> MakeRequest( InterfacePtr<Interface>* ptr, - scoped_refptr<base::SingleThreadTaskRunner> runner = - base::ThreadTaskRunnerHandle::Get()) { - return InterfaceRequest<Interface>(ptr, runner); + scoped_refptr<base::SingleThreadTaskRunner> runner = nullptr) { + MessagePipe pipe; + ptr->Bind(InterfacePtrInfo<Interface>(std::move(pipe.handle0), 0u), + std::move(runner)); + return InterfaceRequest<Interface>(std::move(pipe.handle1)); +} + +// Similar to the constructor above, but binds one end of the message pipe to +// an InterfacePtrInfo instance. +template <typename Interface> +InterfaceRequest<Interface> MakeRequest(InterfacePtrInfo<Interface>* ptr_info) { + MessagePipe pipe; + ptr_info->set_handle(std::move(pipe.handle0)); + ptr_info->set_version(0u); + return InterfaceRequest<Interface>(std::move(pipe.handle1)); } // Fuses an InterfaceRequest<T> endpoint with an InterfacePtrInfo<T> endpoint. diff --git a/mojo/public/cpp/bindings/lib/array_internal.h b/mojo/public/cpp/bindings/lib/array_internal.h index eecfcfbc28..574be9b6f5 100644 --- a/mojo/public/cpp/bindings/lib/array_internal.h +++ b/mojo/public/cpp/bindings/lib/array_internal.h @@ -11,9 +11,10 @@ #include <limits> #include <new> +#include "base/component_export.h" #include "base/logging.h" +#include "base/macros.h" #include "mojo/public/c/system/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" #include "mojo/public/cpp/bindings/lib/buffer.h" #include "mojo/public/cpp/bindings/lib/serialization_util.h" @@ -29,13 +30,15 @@ namespace internal { template <typename K, typename V> class Map_Data; -MOJO_CPP_BINDINGS_EXPORT std::string -MakeMessageWithArrayIndex(const char* message, size_t size, size_t index); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +std::string MakeMessageWithArrayIndex(const char* message, + size_t size, + size_t index); -MOJO_CPP_BINDINGS_EXPORT std::string MakeMessageWithExpectedArraySize( - const char* message, - size_t size, - size_t expected_size); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +std::string MakeMessageWithExpectedArraySize(const char* message, + size_t size, + size_t expected_size); template <typename T> struct ArrayDataTraits { @@ -68,7 +71,7 @@ template <> struct ArrayDataTraits<bool> { // Helper class to emulate a reference to a bool, used for direct element // access. - class MOJO_CPP_BINDINGS_EXPORT BitRef { + class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) BitRef { public: ~BitRef(); BitRef& operator=(bool value); @@ -268,17 +271,35 @@ class Array_Data { std::is_same<T, Handle_Data>::value>; using Element = T; - // Returns null if |num_elements| or the corresponding storage size cannot be - // stored in uint32_t. - static Array_Data<T>* New(size_t num_elements, Buffer* buf) { - if (num_elements > Traits::kMaxNumElements) - return nullptr; + class BufferWriter { + public: + BufferWriter() = default; + + void Allocate(size_t num_elements, Buffer* buffer) { + if (num_elements > Traits::kMaxNumElements) + return; + + uint32_t num_bytes = + Traits::GetStorageSize(static_cast<uint32_t>(num_elements)); + buffer_ = buffer; + index_ = buffer_->Allocate(num_bytes); + new (data()) + Array_Data<T>(num_bytes, static_cast<uint32_t>(num_elements)); + } - uint32_t num_bytes = - Traits::GetStorageSize(static_cast<uint32_t>(num_elements)); - return new (buf->Allocate(num_bytes)) - Array_Data<T>(num_bytes, static_cast<uint32_t>(num_elements)); - } + bool is_null() const { return !buffer_; } + Array_Data<T>* data() { + DCHECK(!is_null()); + return buffer_->Get<Array_Data<T>>(index_); + } + Array_Data<T>* operator->() { return data(); } + + private: + Buffer* buffer_ = nullptr; + size_t index_ = 0; + + DISALLOW_COPY_AND_ASSIGN(BufferWriter); + }; static bool Validate(const void* data, ValidationContext* validation_context, diff --git a/mojo/public/cpp/bindings/lib/array_serialization.h b/mojo/public/cpp/bindings/lib/array_serialization.h index d2f8ecfd72..8323b5f9a1 100644 --- a/mojo/public/cpp/bindings/lib/array_serialization.h +++ b/mojo/public/cpp/bindings/lib/array_serialization.h @@ -112,20 +112,19 @@ struct ArraySerializer< using DataElement = typename Data::Element; using Element = typename MojomType::Element; using Traits = ArrayTraits<UserType>; + using BufferWriter = typename Data::BufferWriter; static_assert(std::is_same<Element, DataElement>::value, "Incorrect array serializer"); - static_assert(std::is_same<Element, typename Traits::Element>::value, - "Incorrect array serializer"); - - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - return sizeof(Data) + Align(input->GetSize() * sizeof(DataElement)); - } + static_assert( + std::is_same< + Element, + typename std::remove_const<typename Traits::Element>::type>::value, + "Incorrect array serializer"); static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { DCHECK(!validate_params->element_is_nullable) @@ -138,6 +137,7 @@ struct ArraySerializer< return; auto data = input->GetDataIfExists(); + Data* output = writer->data(); if (data) { memcpy(output->storage(), data, size * sizeof(DataElement)); } else { @@ -180,18 +180,14 @@ struct ArraySerializer< using DataElement = typename Data::Element; using Element = typename MojomType::Element; using Traits = ArrayTraits<UserType>; + using BufferWriter = typename Data::BufferWriter; static_assert(sizeof(Element) == sizeof(DataElement), "Incorrect array serializer"); - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - return sizeof(Data) + Align(input->GetSize() * sizeof(DataElement)); - } - static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { DCHECK(!validate_params->element_is_nullable) @@ -199,6 +195,7 @@ struct ArraySerializer< DCHECK(!validate_params->element_validate_params) << "Primitive type should not have array validate params"; + Data* output = writer->data(); size_t size = input->GetSize(); for (size_t i = 0; i < size; ++i) Serialize<Element>(input->GetNext(), output->storage() + i); @@ -231,18 +228,14 @@ struct ArraySerializer<MojomType, using UserType = typename std::remove_const<MaybeConstUserType>::type; using Traits = ArrayTraits<UserType>; using Data = typename MojomTypeTraits<MojomType>::Data; + using BufferWriter = typename Data::BufferWriter; static_assert(std::is_same<bool, typename Traits::Element>::value, "Incorrect array serializer"); - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - return sizeof(Data) + Align((input->GetSize() + 7) / 8); - } - static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { DCHECK(!validate_params->element_is_nullable) @@ -250,6 +243,7 @@ struct ArraySerializer<MojomType, DCHECK(!validate_params->element_validate_params) << "Primitive type should not have array validate params"; + Data* output = writer->data(); size_t size = input->GetSize(); for (size_t i = 0; i < size; ++i) output->at(i) = input->GetNext(); @@ -278,37 +272,23 @@ struct ArraySerializer< BelongsTo<typename MojomType::Element, MojomTypeCategory::ASSOCIATED_INTERFACE | MojomTypeCategory::ASSOCIATED_INTERFACE_REQUEST | - MojomTypeCategory::HANDLE | - MojomTypeCategory::INTERFACE | + MojomTypeCategory::HANDLE | MojomTypeCategory::INTERFACE | MojomTypeCategory::INTERFACE_REQUEST>::value>::type> { using UserType = typename std::remove_const<MaybeConstUserType>::type; using Data = typename MojomTypeTraits<MojomType>::Data; using Element = typename MojomType::Element; using Traits = ArrayTraits<UserType>; - - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - size_t element_count = input->GetSize(); - if (BelongsTo<Element, - MojomTypeCategory::ASSOCIATED_INTERFACE | - MojomTypeCategory::ASSOCIATED_INTERFACE_REQUEST>::value) { - for (size_t i = 0; i < element_count; ++i) { - typename UserTypeIterator::GetNextResult next = input->GetNext(); - size_t size = PrepareToSerialize<Element>(next, context); - DCHECK_EQ(size, 0u); - } - } - return sizeof(Data) + Align(element_count * sizeof(typename Data::Element)); - } + using BufferWriter = typename Data::BufferWriter; static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { DCHECK(!validate_params->element_validate_params) << "Handle or interface type should not have array validate params"; + Data* output = writer->data(); size_t size = input->GetSize(); for (size_t i = 0; i < size; ++i) { typename UserTypeIterator::GetNextResult next = input->GetNext(); @@ -360,35 +340,27 @@ struct ArraySerializer<MojomType, using UserType = typename std::remove_const<MaybeConstUserType>::type; using Data = typename MojomTypeTraits<MojomType>::Data; using Element = typename MojomType::Element; - using DataElementPtr = typename MojomTypeTraits<Element>::Data*; + using DataElementWriter = + typename MojomTypeTraits<Element>::Data::BufferWriter; using Traits = ArrayTraits<UserType>; - - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - size_t element_count = input->GetSize(); - size_t size = sizeof(Data) + element_count * sizeof(typename Data::Element); - for (size_t i = 0; i < element_count; ++i) { - typename UserTypeIterator::GetNextResult next = input->GetNext(); - size += PrepareToSerialize<Element>(next, context); - } - return size; - } + using BufferWriter = typename Data::BufferWriter; static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { size_t size = input->GetSize(); for (size_t i = 0; i < size; ++i) { - DataElementPtr data_ptr; + DataElementWriter data_writer; typename UserTypeIterator::GetNextResult next = input->GetNext(); - SerializeCaller<Element>::Run(next, buf, &data_ptr, + SerializeCaller<Element>::Run(next, buf, &data_writer, validate_params->element_validate_params, context); - output->at(i).Set(data_ptr); + writer->data()->at(i).Set(data_writer.is_null() ? nullptr + : data_writer.data()); MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( - !validate_params->element_is_nullable && !data_ptr, + !validate_params->element_is_nullable && data_writer.is_null(), VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, MakeMessageWithArrayIndex("null in array expecting valid pointers", size, i)); @@ -417,10 +389,10 @@ struct ArraySerializer<MojomType, template <typename InputElementType> static void Run(InputElementType&& input, Buffer* buf, - DataElementPtr* output, + DataElementWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { - Serialize<T>(std::forward<InputElementType>(input), buf, output, context); + Serialize<T>(std::forward<InputElementType>(input), buf, writer, context); } }; @@ -429,10 +401,10 @@ struct ArraySerializer<MojomType, template <typename InputElementType> static void Run(InputElementType&& input, Buffer* buf, - DataElementPtr* output, + DataElementWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { - Serialize<T>(std::forward<InputElementType>(input), buf, output, + Serialize<T>(std::forward<InputElementType>(input), buf, writer, validate_params, context); } }; @@ -451,33 +423,24 @@ struct ArraySerializer< using UserType = typename std::remove_const<MaybeConstUserType>::type; using Data = typename MojomTypeTraits<MojomType>::Data; using Element = typename MojomType::Element; + using ElementWriter = typename Data::Element::BufferWriter; using Traits = ArrayTraits<UserType>; - - static size_t GetSerializedSize(UserTypeIterator* input, - SerializationContext* context) { - size_t element_count = input->GetSize(); - size_t size = sizeof(Data); - for (size_t i = 0; i < element_count; ++i) { - // Call with |inlined| set to false, so that it will account for both the - // data in the union and the space in the array used to hold the union. - typename UserTypeIterator::GetNextResult next = input->GetNext(); - size += PrepareToSerialize<Element>(next, false, context); - } - return size; - } + using BufferWriter = typename Data::BufferWriter; static void SerializeElements(UserTypeIterator* input, Buffer* buf, - Data* output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { size_t size = input->GetSize(); for (size_t i = 0; i < size; ++i) { - typename Data::Element* result = output->storage() + i; + ElementWriter result; + result.AllocateInline(buf, writer->data()->storage() + i); typename UserTypeIterator::GetNextResult next = input->GetNext(); Serialize<Element>(next, buf, &result, true, context); MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( - !validate_params->element_is_nullable && output->at(i).is_null(), + !validate_params->element_is_nullable && + writer->data()->at(i).is_null(), VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, MakeMessageWithArrayIndex("null in array expecting valid unions", size, i)); @@ -506,38 +469,27 @@ struct Serializer<ArrayDataView<Element>, MaybeConstUserType> { MaybeConstUserType, ArrayIterator<Traits, MaybeConstUserType>>; using Data = typename MojomTypeTraits<ArrayDataView<Element>>::Data; - - static size_t PrepareToSerialize(MaybeConstUserType& input, - SerializationContext* context) { - if (CallIsNullIfExists<Traits>(input)) - return 0; - ArrayIterator<Traits, MaybeConstUserType> iterator(input); - return Impl::GetSerializedSize(&iterator, context); - } + using BufferWriter = typename Data::BufferWriter; static void Serialize(MaybeConstUserType& input, Buffer* buf, - Data** output, + BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { - if (!CallIsNullIfExists<Traits>(input)) { - MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( - validate_params->expected_num_elements != 0 && - Traits::GetSize(input) != validate_params->expected_num_elements, - internal::VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER, - internal::MakeMessageWithExpectedArraySize( - "fixed-size array has wrong number of elements", - Traits::GetSize(input), validate_params->expected_num_elements)); - Data* result = Data::New(Traits::GetSize(input), buf); - if (result) { - ArrayIterator<Traits, MaybeConstUserType> iterator(input); - Impl::SerializeElements(&iterator, buf, result, validate_params, - context); - } - *output = result; - } else { - *output = nullptr; - } + if (CallIsNullIfExists<Traits>(input)) + return; + + const size_t size = Traits::GetSize(input); + MOJO_INTERNAL_DLOG_SERIALIZATION_WARNING( + validate_params->expected_num_elements != 0 && + size != validate_params->expected_num_elements, + internal::VALIDATION_ERROR_UNEXPECTED_ARRAY_HEADER, + internal::MakeMessageWithExpectedArraySize( + "fixed-size array has wrong number of elements", size, + validate_params->expected_num_elements)); + writer->Allocate(size, buf); + ArrayIterator<Traits, MaybeConstUserType> iterator(input); + Impl::SerializeElements(&iterator, buf, writer, validate_params, context); } static bool Deserialize(Data* input, diff --git a/mojo/public/cpp/bindings/lib/associated_binding.cc b/mojo/public/cpp/bindings/lib/associated_binding.cc index 6788e68e07..c7eddc2372 100644 --- a/mojo/public/cpp/bindings/lib/associated_binding.cc +++ b/mojo/public/cpp/bindings/lib/associated_binding.cc @@ -4,6 +4,9 @@ #include "mojo/public/cpp/bindings/associated_binding.h" +#include "base/single_thread_task_runner.h" +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" + namespace mojo { AssociatedBindingBase::AssociatedBindingBase() {} @@ -27,15 +30,16 @@ void AssociatedBindingBase::CloseWithReason(uint32_t custom_reason, } void AssociatedBindingBase::set_connection_error_handler( - const base::Closure& error_handler) { + base::OnceClosure error_handler) { DCHECK(is_bound()); - endpoint_client_->set_connection_error_handler(error_handler); + endpoint_client_->set_connection_error_handler(std::move(error_handler)); } void AssociatedBindingBase::set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { DCHECK(is_bound()); - endpoint_client_->set_connection_error_with_reason_handler(error_handler); + endpoint_client_->set_connection_error_with_reason_handler( + std::move(error_handler)); } void AssociatedBindingBase::FlushForTesting() { @@ -56,7 +60,9 @@ void AssociatedBindingBase::BindImpl( endpoint_client_.reset(new InterfaceEndpointClient( std::move(handle), receiver, std::move(payload_validator), - expect_sync_requests, std::move(runner), interface_version)); + expect_sync_requests, + internal::GetTaskRunnerToUseFromUserProvidedTaskRunner(std::move(runner)), + interface_version)); } } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc b/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc index 78281eda9a..453e47a995 100644 --- a/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc +++ b/mojo/public/cpp/bindings/lib/associated_interface_ptr.cc @@ -6,12 +6,12 @@ namespace mojo { -void GetIsolatedInterface(ScopedInterfaceEndpointHandle handle) { +void AssociateWithDisconnectedPipe(ScopedInterfaceEndpointHandle handle) { MessagePipe pipe; scoped_refptr<internal::MultiplexRouter> router = - new internal::MultiplexRouter(std::move(pipe.handle0), - internal::MultiplexRouter::MULTI_INTERFACE, - false, base::ThreadTaskRunnerHandle::Get()); + new internal::MultiplexRouter( + std::move(pipe.handle0), internal::MultiplexRouter::MULTI_INTERFACE, + false, base::SequencedTaskRunnerHandle::Get()); router->AssociateInterface(std::move(handle)); } diff --git a/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.cc b/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.cc new file mode 100644 index 0000000000..dd3a2510f1 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.cc @@ -0,0 +1,81 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h" + +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" + +namespace mojo { +namespace internal { + +AssociatedInterfacePtrStateBase::AssociatedInterfacePtrStateBase() = default; + +AssociatedInterfacePtrStateBase::~AssociatedInterfacePtrStateBase() = default; + +void AssociatedInterfacePtrStateBase::QueryVersion( + const base::Callback<void(uint32_t)>& callback) { + // It is safe to capture |this| because the callback won't be run after this + // object goes away. + endpoint_client_->QueryVersion( + base::Bind(&AssociatedInterfacePtrStateBase::OnQueryVersion, + base::Unretained(this), callback)); +} + +void AssociatedInterfacePtrStateBase::RequireVersion(uint32_t version) { + if (version <= version_) + return; + + version_ = version; + endpoint_client_->RequireVersion(version); +} + +void AssociatedInterfacePtrStateBase::OnQueryVersion( + const base::Callback<void(uint32_t)>& callback, + uint32_t version) { + version_ = version; + callback.Run(version); +} + +void AssociatedInterfacePtrStateBase::FlushForTesting() { + endpoint_client_->FlushForTesting(); +} + +void AssociatedInterfacePtrStateBase::CloseWithReason( + uint32_t custom_reason, + const std::string& description) { + endpoint_client_->CloseWithReason(custom_reason, description); +} + +void AssociatedInterfacePtrStateBase::Swap( + AssociatedInterfacePtrStateBase* other) { + using std::swap; + swap(other->endpoint_client_, endpoint_client_); + swap(other->version_, version_); +} + +void AssociatedInterfacePtrStateBase::Bind( + ScopedInterfaceEndpointHandle handle, + uint32_t version, + std::unique_ptr<MessageReceiver> validator, + scoped_refptr<base::SequencedTaskRunner> runner) { + DCHECK(!endpoint_client_); + DCHECK_EQ(0u, version_); + DCHECK(handle.is_valid()); + + version_ = version; + // The version is only queried from the client so the value passed here + // will not be used. + endpoint_client_ = std::make_unique<InterfaceEndpointClient>( + std::move(handle), nullptr, std::move(validator), false, + GetTaskRunnerToUseFromUserProvidedTaskRunner(std::move(runner)), 0u); +} + +ScopedInterfaceEndpointHandle AssociatedInterfacePtrStateBase::PassHandle() { + auto handle = endpoint_client_->PassHandle(); + endpoint_client_.reset(); + return handle; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h b/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h index a4b51882d2..79ec2bec93 100644 --- a/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h +++ b/mojo/public/cpp/bindings/lib/associated_interface_ptr_state.h @@ -17,9 +17,10 @@ #include "base/macros.h" #include "base/memory/ptr_util.h" #include "base/memory/ref_counted.h" -#include "base/single_thread_task_runner.h" +#include "base/sequenced_task_runner.h" #include "mojo/public/cpp/bindings/associated_group.h" #include "mojo/public/cpp/bindings/associated_interface_ptr_info.h" +#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/connection_error_callback.h" #include "mojo/public/cpp/bindings/interface_endpoint_client.h" #include "mojo/public/cpp/bindings/interface_id.h" @@ -29,77 +30,17 @@ namespace mojo { namespace internal { -template <typename Interface> -class AssociatedInterfacePtrState { +class MOJO_CPP_BINDINGS_EXPORT AssociatedInterfacePtrStateBase { public: - AssociatedInterfacePtrState() : version_(0u) {} - - ~AssociatedInterfacePtrState() { - endpoint_client_.reset(); - proxy_.reset(); - } - - Interface* instance() { - // This will be null if the object is not bound. - return proxy_.get(); - } + AssociatedInterfacePtrStateBase(); + ~AssociatedInterfacePtrStateBase(); uint32_t version() const { return version_; } - void QueryVersion(const base::Callback<void(uint32_t)>& callback) { - // It is safe to capture |this| because the callback won't be run after this - // object goes away. - endpoint_client_->QueryVersion( - base::Bind(&AssociatedInterfacePtrState::OnQueryVersion, - base::Unretained(this), callback)); - } - - void RequireVersion(uint32_t version) { - if (version <= version_) - return; - - version_ = version; - endpoint_client_->RequireVersion(version); - } - - void FlushForTesting() { endpoint_client_->FlushForTesting(); } - - void CloseWithReason(uint32_t custom_reason, const std::string& description) { - endpoint_client_->CloseWithReason(custom_reason, description); - } - - void Swap(AssociatedInterfacePtrState* other) { - using std::swap; - swap(other->endpoint_client_, endpoint_client_); - swap(other->proxy_, proxy_); - swap(other->version_, version_); - } - - void Bind(AssociatedInterfacePtrInfo<Interface> info, - scoped_refptr<base::SingleThreadTaskRunner> runner) { - DCHECK(!endpoint_client_); - DCHECK(!proxy_); - DCHECK_EQ(0u, version_); - DCHECK(info.is_valid()); - - version_ = info.version(); - // The version is only queried from the client so the value passed here - // will not be used. - endpoint_client_.reset(new InterfaceEndpointClient( - info.PassHandle(), nullptr, - base::WrapUnique(new typename Interface::ResponseValidator_()), false, - std::move(runner), 0u)); - proxy_.reset(new Proxy(endpoint_client_.get())); - } - - // After this method is called, the object is in an invalid state and - // shouldn't be reused. - AssociatedInterfacePtrInfo<Interface> PassInterface() { - ScopedInterfaceEndpointHandle handle = endpoint_client_->PassHandle(); - endpoint_client_.reset(); - proxy_.reset(); - return AssociatedInterfacePtrInfo<Interface>(std::move(handle), version_); - } + void QueryVersion(const base::Callback<void(uint32_t)>& callback); + void RequireVersion(uint32_t version); + void FlushForTesting(); + void CloseWithReason(uint32_t custom_reason, const std::string& description); bool is_bound() const { return !!endpoint_client_; } @@ -107,15 +48,16 @@ class AssociatedInterfacePtrState { return endpoint_client_ ? endpoint_client_->encountered_error() : false; } - void set_connection_error_handler(const base::Closure& error_handler) { + void set_connection_error_handler(base::OnceClosure error_handler) { DCHECK(endpoint_client_); - endpoint_client_->set_connection_error_handler(error_handler); + endpoint_client_->set_connection_error_handler(std::move(error_handler)); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { DCHECK(endpoint_client_); - endpoint_client_->set_connection_error_with_reason_handler(error_handler); + endpoint_client_->set_connection_error_with_reason_handler( + std::move(error_handler)); } // Returns true if bound and awaiting a response to a message. @@ -134,19 +76,62 @@ class AssociatedInterfacePtrState { endpoint_client_->AcceptWithResponder(&message, std::move(responder)); } + protected: + void Swap(AssociatedInterfacePtrStateBase* other); + void Bind(ScopedInterfaceEndpointHandle handle, + uint32_t version, + std::unique_ptr<MessageReceiver> validator, + scoped_refptr<base::SequencedTaskRunner> runner); + ScopedInterfaceEndpointHandle PassHandle(); + + InterfaceEndpointClient* endpoint_client() { return endpoint_client_.get(); } + private: + void OnQueryVersion(const base::Callback<void(uint32_t)>& callback, + uint32_t version); + + std::unique_ptr<InterfaceEndpointClient> endpoint_client_; + uint32_t version_ = 0; +}; + +template <typename Interface> +class AssociatedInterfacePtrState : public AssociatedInterfacePtrStateBase { + public: using Proxy = typename Interface::Proxy_; - void OnQueryVersion(const base::Callback<void(uint32_t)>& callback, - uint32_t version) { - version_ = version; - callback.Run(version); + AssociatedInterfacePtrState() {} + ~AssociatedInterfacePtrState() = default; + + Proxy* instance() { + // This will be null if the object is not bound. + return proxy_.get(); } - std::unique_ptr<InterfaceEndpointClient> endpoint_client_; - std::unique_ptr<Proxy> proxy_; + void Swap(AssociatedInterfacePtrState* other) { + AssociatedInterfacePtrStateBase::Swap(other); + std::swap(other->proxy_, proxy_); + } + + void Bind(AssociatedInterfacePtrInfo<Interface> info, + scoped_refptr<base::SequencedTaskRunner> runner) { + DCHECK(!proxy_); + AssociatedInterfacePtrStateBase::Bind( + info.PassHandle(), info.version(), + std::make_unique<typename Interface::ResponseValidator_>(), + std::move(runner)); + proxy_.reset(new Proxy(endpoint_client())); + } + + // After this method is called, the object is in an invalid state and + // shouldn't be reused. + AssociatedInterfacePtrInfo<Interface> PassInterface() { + AssociatedInterfacePtrInfo<Interface> info(PassHandle(), version()); + proxy_.reset(); + return info; + } - uint32_t version_; + private: + std::unique_ptr<Proxy> proxy_; DISALLOW_COPY_AND_ASSIGN(AssociatedInterfacePtrState); }; diff --git a/mojo/public/cpp/bindings/lib/binding_state.cc b/mojo/public/cpp/bindings/lib/binding_state.cc index b34cb47e28..bb4a20f39b 100644 --- a/mojo/public/cpp/bindings/lib/binding_state.cc +++ b/mojo/public/cpp/bindings/lib/binding_state.cc @@ -4,10 +4,12 @@ #include "mojo/public/cpp/bindings/lib/binding_state.h" +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" + namespace mojo { namespace internal { -BindingStateBase::BindingStateBase() = default; +BindingStateBase::BindingStateBase() : weak_ptr_factory_(this) {} BindingStateBase::~BindingStateBase() = default; @@ -24,6 +26,7 @@ void BindingStateBase::PauseIncomingMethodCallProcessing() { DCHECK(router_); router_->PauseIncomingMethodCallProcessing(); } + void BindingStateBase::ResumeIncomingMethodCallProcessing() { DCHECK(router_); router_->ResumeIncomingMethodCallProcessing(); @@ -51,6 +54,17 @@ void BindingStateBase::CloseWithReason(uint32_t custom_reason, Close(); } +ReportBadMessageCallback BindingStateBase::GetBadMessageCallback() { + return base::BindOnce( + [](ReportBadMessageCallback inner_callback, + base::WeakPtr<BindingStateBase> binding, const std::string& error) { + std::move(inner_callback).Run(error); + if (binding) + binding->Close(); + }, + mojo::GetBadMessageCallback(), weak_ptr_factory_.GetWeakPtr()); +} + void BindingStateBase::FlushForTesting() { endpoint_client_->FlushForTesting(); } @@ -60,6 +74,10 @@ void BindingStateBase::EnableTestingMode() { router_->EnableTestingMode(); } +scoped_refptr<internal::MultiplexRouter> BindingStateBase::RouterForTesting() { + return router_; +} + void BindingStateBase::BindInternal( ScopedMessagePipeHandle handle, scoped_refptr<base::SingleThreadTaskRunner> runner, @@ -69,21 +87,25 @@ void BindingStateBase::BindInternal( bool has_sync_methods, MessageReceiverWithResponderStatus* stub, uint32_t interface_version) { - DCHECK(!router_); + DCHECK(!is_bound()) << "Attempting to bind interface that is already bound: " + << interface_name; + auto sequenced_runner = + GetTaskRunnerToUseFromUserProvidedTaskRunner(std::move(runner)); MultiplexRouter::Config config = passes_associated_kinds ? MultiplexRouter::MULTI_INTERFACE : (has_sync_methods ? MultiplexRouter::SINGLE_INTERFACE_WITH_SYNC_METHODS : MultiplexRouter::SINGLE_INTERFACE); - router_ = new MultiplexRouter(std::move(handle), config, false, runner); + router_ = + new MultiplexRouter(std::move(handle), config, false, sequenced_runner); router_->SetMasterInterfaceName(interface_name); endpoint_client_.reset(new InterfaceEndpointClient( router_->CreateLocalEndpointHandle(kMasterInterfaceId), stub, - std::move(request_validator), has_sync_methods, std::move(runner), - interface_version)); + std::move(request_validator), has_sync_methods, + std::move(sequenced_runner), interface_version)); } } // namesapce internal diff --git a/mojo/public/cpp/bindings/lib/binding_state.h b/mojo/public/cpp/bindings/lib/binding_state.h index 0b0dbee002..d1c561c748 100644 --- a/mojo/public/cpp/bindings/lib/binding_state.h +++ b/mojo/public/cpp/bindings/lib/binding_state.h @@ -15,6 +15,7 @@ #include "base/macros.h" #include "base/memory/ptr_util.h" #include "base/memory/ref_counted.h" +#include "base/sequenced_task_runner.h" #include "base/single_thread_task_runner.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/connection_error_callback.h" @@ -50,15 +51,18 @@ class MOJO_CPP_BINDINGS_EXPORT BindingStateBase { void Close(); void CloseWithReason(uint32_t custom_reason, const std::string& description); - void set_connection_error_handler(const base::Closure& error_handler) { + void RaiseError() { endpoint_client_->RaiseError(); } + + void set_connection_error_handler(base::OnceClosure error_handler) { DCHECK(is_bound()); - endpoint_client_->set_connection_error_handler(error_handler); + endpoint_client_->set_connection_error_handler(std::move(error_handler)); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { DCHECK(is_bound()); - endpoint_client_->set_connection_error_with_reason_handler(error_handler); + endpoint_client_->set_connection_error_with_reason_handler( + std::move(error_handler)); } bool is_bound() const { return !!router_; } @@ -68,10 +72,14 @@ class MOJO_CPP_BINDINGS_EXPORT BindingStateBase { return router_->handle(); } + ReportBadMessageCallback GetBadMessageCallback(); + void FlushForTesting(); void EnableTestingMode(); + scoped_refptr<internal::MultiplexRouter> RouterForTesting(); + protected: void BindInternal(ScopedMessagePipeHandle handle, scoped_refptr<base::SingleThreadTaskRunner> runner, @@ -84,6 +92,8 @@ class MOJO_CPP_BINDINGS_EXPORT BindingStateBase { scoped_refptr<internal::MultiplexRouter> router_; std::unique_ptr<InterfaceEndpointClient> endpoint_client_; + + base::WeakPtrFactory<BindingStateBase> weak_ptr_factory_; }; template <typename Interface, typename ImplRefTraits> @@ -101,20 +111,24 @@ class BindingState : public BindingStateBase { scoped_refptr<base::SingleThreadTaskRunner> runner) { BindingStateBase::BindInternal( std::move(handle), runner, Interface::Name_, - base::MakeUnique<typename Interface::RequestValidator_>(), + std::make_unique<typename Interface::RequestValidator_>(), Interface::PassesAssociatedKinds_, Interface::HasSyncMethods_, &stub_, Interface::Version_); } InterfaceRequest<Interface> Unbind() { endpoint_client_.reset(); - InterfaceRequest<Interface> request = - MakeRequest<Interface>(router_->PassMessagePipe()); + InterfaceRequest<Interface> request(router_->PassMessagePipe()); router_ = nullptr; return request; } Interface* impl() { return ImplRefTraits::GetRawPointer(&stub_.sink()); } + ImplPointerType SwapImplForTesting(ImplPointerType new_impl) { + Interface* old_impl = impl(); + stub_.set_sink(std::move(new_impl)); + return old_impl; + } private: typename Interface::template Stub_<ImplRefTraits> stub_; diff --git a/mojo/public/cpp/bindings/lib/bindings_internal.h b/mojo/public/cpp/bindings/lib/bindings_internal.h index 631daec392..8bdb9c7b77 100644 --- a/mojo/public/cpp/bindings/lib/bindings_internal.h +++ b/mojo/public/cpp/bindings/lib/bindings_internal.h @@ -8,8 +8,9 @@ #include <stdint.h> #include <functional> +#include <type_traits> -#include "base/template_util.h" +#include "mojo/public/cpp/bindings/enum_traits.h" #include "mojo/public/cpp/bindings/interface_id.h" #include "mojo/public/cpp/bindings/lib/template_util.h" #include "mojo/public/cpp/system/core.h" @@ -34,8 +35,6 @@ class InterfaceRequestDataView; template <typename K, typename V> class MapDataView; -class NativeStructDataView; - class StringDataView; namespace internal { @@ -54,8 +53,6 @@ class Array_Data; template <typename K, typename V> class Map_Data; -class NativeStruct_Data; - using String_Data = Array_Data<char>; inline size_t Align(size_t size) { @@ -299,14 +296,6 @@ struct MojomTypeTraits<MapDataView<K, V>, false> { }; template <> -struct MojomTypeTraits<NativeStructDataView, false> { - using Data = internal::NativeStruct_Data; - using DataAsArrayElement = Pointer<Data>; - - static const MojomTypeCategory category = MojomTypeCategory::STRUCT; -}; - -template <> struct MojomTypeTraits<StringDataView, false> { using Data = String_Data; using DataAsArrayElement = Pointer<Data>; @@ -325,11 +314,19 @@ struct EnumHashImpl { static_assert(std::is_enum<T>::value, "Incorrect hash function."); size_t operator()(T input) const { - using UnderlyingType = typename base::underlying_type<T>::type; + using UnderlyingType = typename std::underlying_type<T>::type; return std::hash<UnderlyingType>()(static_cast<UnderlyingType>(input)); } }; +template <typename MojomType, typename T> +T ConvertEnumValue(MojomType input) { + T output; + bool result = EnumTraits<MojomType, T>::FromMojom(input, &output); + DCHECK(result); + return output; +} + } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/buffer.cc b/mojo/public/cpp/bindings/lib/buffer.cc new file mode 100644 index 0000000000..2444cf4e54 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/buffer.cc @@ -0,0 +1,136 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/buffer.h" + +#include "base/logging.h" +#include "base/numerics/safe_math.h" +#include "mojo/public/c/system/message_pipe.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" + +namespace mojo { +namespace internal { + +Buffer::Buffer() = default; + +Buffer::Buffer(void* data, size_t size, size_t cursor) + : data_(data), size_(size), cursor_(cursor) { + DCHECK(IsAligned(data_)); +} + +Buffer::Buffer(MessageHandle message, + size_t message_payload_size, + void* data, + size_t size) + : message_(message), + message_payload_size_(message_payload_size), + data_(data), + size_(size), + cursor_(0) { + DCHECK(IsAligned(data_)); +} + +Buffer::Buffer(Buffer&& other) { + *this = std::move(other); +} + +Buffer::~Buffer() = default; + +Buffer& Buffer::operator=(Buffer&& other) { + message_ = other.message_; + message_payload_size_ = other.message_payload_size_; + data_ = other.data_; + size_ = other.size_; + cursor_ = other.cursor_; + other.Reset(); + return *this; +} + +size_t Buffer::Allocate(size_t num_bytes) { + const size_t aligned_num_bytes = Align(num_bytes); + const size_t new_cursor = cursor_ + aligned_num_bytes; + if (new_cursor < cursor_ || (new_cursor > size_ && !message_.is_valid())) { + // Either we've overflowed or exceeded a fixed capacity. + NOTREACHED(); + return 0; + } + + if (new_cursor > size_) { + // If we have an underlying message object we can extend its payload to + // obtain more storage capacity. + DCHECK_LE(message_payload_size_, new_cursor); + size_t additional_bytes = new_cursor - message_payload_size_; + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(additional_bytes)); + uint32_t new_size; + MojoResult rv = MojoAppendMessageData( + message_.value(), static_cast<uint32_t>(additional_bytes), nullptr, 0, + nullptr, &data_, &new_size); + DCHECK_EQ(MOJO_RESULT_OK, rv); + message_payload_size_ = new_cursor; + size_ = new_size; + } + + DCHECK_LE(new_cursor, size_); + size_t block_start = cursor_; + cursor_ = new_cursor; + + // Ensure that all the allocated space is zeroed to avoid uninitialized bits + // leaking into messages. + // + // TODO(rockot): We should consider only clearing the alignment padding. This + // means being careful about generated bindings zeroing padding explicitly, + // which itself gets particularly messy with e.g. packed bool bitfields. + memset(static_cast<uint8_t*>(data_) + block_start, 0, aligned_num_bytes); + + return block_start; +} + +void Buffer::AttachHandles(std::vector<ScopedHandle>* handles) { + DCHECK(message_.is_valid()); + + uint32_t new_size = 0; + MojoResult rv = MojoAppendMessageData( + message_.value(), 0, reinterpret_cast<MojoHandle*>(handles->data()), + static_cast<uint32_t>(handles->size()), nullptr, &data_, &new_size); + if (rv != MOJO_RESULT_OK) + return; + + size_ = new_size; + for (auto& handle : *handles) + ignore_result(handle.release()); +} + +void Buffer::Seal() { + if (!message_.is_valid()) + return; + + // Ensure that the backing message has the final accumulated payload size. + DCHECK_LE(message_payload_size_, cursor_); + size_t additional_bytes = cursor_ - message_payload_size_; + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(additional_bytes)); + + MojoAppendMessageDataOptions options; + options.struct_size = sizeof(options); + options.flags = MOJO_APPEND_MESSAGE_DATA_FLAG_COMMIT_SIZE; + void* data; + uint32_t size; + MojoResult rv = MojoAppendMessageData(message_.value(), + static_cast<uint32_t>(additional_bytes), + nullptr, 0, &options, &data, &size); + DCHECK_EQ(MOJO_RESULT_OK, rv); + message_ = MessageHandle(); + message_payload_size_ = cursor_; + data_ = data; + size_ = size; +} + +void Buffer::Reset() { + message_ = MessageHandle(); + data_ = nullptr; + size_ = 0; + cursor_ = 0; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/buffer.h b/mojo/public/cpp/bindings/lib/buffer.h index 213a44590f..9f2a768490 100644 --- a/mojo/public/cpp/bindings/lib/buffer.h +++ b/mojo/public/cpp/bindings/lib/buffer.h @@ -6,60 +6,121 @@ #define MOJO_PUBLIC_CPP_BINDINGS_LIB_BUFFER_H_ #include <stddef.h> +#include <stdint.h> -#include "base/logging.h" +#include <vector> + +#include "base/component_export.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/system/handle.h" +#include "mojo/public/cpp/system/message.h" namespace mojo { namespace internal { // Buffer provides an interface to allocate memory blocks which are 8-byte -// aligned and zero-initialized. It doesn't own the underlying memory. Users -// must ensure that the memory stays valid while using the allocated blocks from -// Buffer. -class Buffer { +// aligned. It doesn't own the underlying memory. Users must ensure that the +// memory stays valid while using the allocated blocks from Buffer. +// +// A Buffer may be moved around. A moved-from Buffer is reset and may no longer +// be used to Allocate memory unless re-Initialized. +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) Buffer { public: - Buffer() {} + // Constructs an invalid Buffer. May not call Allocate(). + Buffer(); + + // Constructs a Buffer which can Allocate() blocks from a buffer of fixed size + // |size| at |data|. Allocations start at |cursor|, so if |cursor| == |size| + // then no allocations are allowed. + // + // |data| is not owned. + Buffer(void* data, size_t size, size_t cursor); + + // Like above, but gives the Buffer an underlying message object which can + // have its payload extended to acquire more storage capacity on Allocate(). + // + // |data| and |size| must correspond to |message|'s data buffer at the time of + // construction. + // + // |payload_size| is the length of the payload as known by |message|, and it + // must be less than or equal to |size|. + // + // |message| is NOT owned and must outlive this Buffer. + Buffer(MessageHandle message, + size_t message_payload_size, + void* data, + size_t size); + + Buffer(Buffer&& other); + ~Buffer(); + + Buffer& operator=(Buffer&& other); - // The memory must have been zero-initialized. |data| must be 8-byte - // aligned. - void Initialize(void* data, size_t size) { - DCHECK(IsAligned(data)); + void* data() const { return data_; } + size_t size() const { return size_; } + size_t cursor() const { return cursor_; } - data_ = data; - size_ = size; - cursor_ = reinterpret_cast<uintptr_t>(data); - data_end_ = cursor_ + size; + bool is_valid() const { + return data_ != nullptr || (size_ == 0 && !message_.is_valid()); } - size_t size() const { return size_; } + // Allocates |num_bytes| from the buffer and returns an index to the start of + // the allocated block. The resulting index is 8-byte aligned and can be + // resolved to an address using Get<T>() below. + size_t Allocate(size_t num_bytes); + + // Returns a typed address within the Buffer corresponding to |index|. Note + // that this address is NOT stable across calls to |Allocate()| and thus must + // not be cached accordingly. + template <typename T> + T* Get(size_t index) { + DCHECK_LT(index, cursor_); + return reinterpret_cast<T*>(static_cast<uint8_t*>(data_) + index); + } - void* data() const { return data_; } + // A template helper combining Allocate() and Get<T>() above to allocate and + // return a block of size |sizeof(T)|. + template <typename T> + T* AllocateAndGet() { + return Get<T>(Allocate(sizeof(T))); + } - // Allocates |num_bytes| from the buffer and returns a pointer to the start of - // the allocated block. - // The resulting address is 8-byte aligned, and the content of the memory is - // zero-filled. - void* Allocate(size_t num_bytes) { - num_bytes = Align(num_bytes); - uintptr_t result = cursor_; - cursor_ += num_bytes; - if (cursor_ > data_end_ || cursor_ < result) { - NOTREACHED(); - cursor_ -= num_bytes; - return nullptr; - } - - return reinterpret_cast<void*>(result); + // A helper which combines Allocate() and Get<void>() for a specified number + // of bytes. + void* AllocateAndGet(size_t num_bytes) { + return Get<void>(Allocate(num_bytes)); } + // Serializes |handles| into the buffer object. Only valid to call when this + // Buffer is backed by a message object. + void AttachHandles(std::vector<ScopedHandle>* handles); + + // Seals this Buffer so it can no longer be used for allocation, and ensures + // the backing message object has a complete accounting of the size of the + // meaningful payload bytes. + void Seal(); + + // Resets the buffer to an invalid state. Can no longer be used to Allocate(). + void Reset(); + private: + MessageHandle message_; + + // The payload size from the message's internal perspective. This differs from + // |size_| as Mojo may intentionally over-allocate space to account for future + // growth. It differs from |cursor_| because we don't push payload size + // updates to the message object as frequently as we update |cursor_|, for + // performance. + size_t message_payload_size_ = 0; + + // The storage location and capacity currently backing |message_|. Owned by + // the message object internally, not by this Buffer. void* data_ = nullptr; size_t size_ = 0; - uintptr_t cursor_ = 0; - uintptr_t data_end_ = 0; + // The current write offset into |data_| if this Buffer is being used for + // message creation. + size_t cursor_ = 0; DISALLOW_COPY_AND_ASSIGN(Buffer); }; diff --git a/mojo/public/cpp/bindings/lib/connector.cc b/mojo/public/cpp/bindings/lib/connector.cc index d93e45ed93..352c51815f 100644 --- a/mojo/public/cpp/bindings/lib/connector.cc +++ b/mojo/public/cpp/bindings/lib/connector.cc @@ -5,7 +5,6 @@ #include "mojo/public/cpp/bindings/connector.h" #include <stdint.h> -#include <utility> #include "base/bind.h" #include "base/lazy_instance.h" @@ -13,23 +12,37 @@ #include "base/logging.h" #include "base/macros.h" #include "base/memory/ptr_util.h" -#include "base/message_loop/message_loop.h" +#include "base/message_loop/message_loop_current.h" +#include "base/run_loop.h" #include "base/synchronization/lock.h" #include "base/threading/thread_local.h" +#include "base/trace_event/trace_event.h" #include "mojo/public/cpp/bindings/lib/may_auto_lock.h" +#include "mojo/public/cpp/bindings/mojo_buildflags.h" #include "mojo/public/cpp/bindings/sync_handle_watcher.h" #include "mojo/public/cpp/system/wait.h" +#if defined(ENABLE_IPC_FUZZER) +#include "mojo/public/cpp/bindings/message_dumper.h" +#endif + namespace mojo { namespace { // The NestingObserver for each thread. Note that this is always a -// Connector::MessageLoopNestingObserver; we use the base type here because that +// Connector::RunLoopNestingObserver; we use the base type here because that // subclass is private to Connector. -base::LazyInstance< - base::ThreadLocalPointer<base::MessageLoop::NestingObserver>>::Leaky - g_tls_nesting_observer = LAZY_INSTANCE_INITIALIZER; +base::LazyInstance<base::ThreadLocalPointer<base::RunLoop::NestingObserver>>:: + Leaky g_tls_nesting_observer = LAZY_INSTANCE_INITIALIZER; + +// The default outgoing serialization mode for new Connectors. +Connector::OutgoingSerializationMode g_default_outgoing_serialization_mode = + Connector::OutgoingSerializationMode::kLazy; + +// The default incoming serialization mode for new Connectors. +Connector::IncomingSerializationMode g_default_incoming_serialization_mode = + Connector::IncomingSerializationMode::kDispatchAsIs; } // namespace @@ -44,7 +57,7 @@ class Connector::ActiveDispatchTracker { private: const base::WeakPtr<Connector> connector_; - MessageLoopNestingObserver* const nesting_observer_; + RunLoopNestingObserver* const nesting_observer_; ActiveDispatchTracker* outer_tracker_ = nullptr; ActiveDispatchTracker* inner_tracker_ = nullptr; @@ -52,41 +65,40 @@ class Connector::ActiveDispatchTracker { }; // Watches the MessageLoop on the current thread. Notifies the current chain of -// ActiveDispatchTrackers when a nested message loop is started. -class Connector::MessageLoopNestingObserver - : public base::MessageLoop::NestingObserver, - public base::MessageLoop::DestructionObserver { +// ActiveDispatchTrackers when a nested run loop is started. +class Connector::RunLoopNestingObserver + : public base::RunLoop::NestingObserver, + public base::MessageLoopCurrent::DestructionObserver { public: - MessageLoopNestingObserver() { - base::MessageLoop::current()->AddNestingObserver(this); - base::MessageLoop::current()->AddDestructionObserver(this); + RunLoopNestingObserver() { + base::RunLoop::AddNestingObserverOnCurrentThread(this); + base::MessageLoopCurrent::Get()->AddDestructionObserver(this); } - ~MessageLoopNestingObserver() override {} + ~RunLoopNestingObserver() override {} - // base::MessageLoop::NestingObserver: - void OnBeginNestedMessageLoop() override { + // base::RunLoop::NestingObserver: + void OnBeginNestedRunLoop() override { if (top_tracker_) top_tracker_->NotifyBeginNesting(); } - // base::MessageLoop::DestructionObserver: + // base::MessageLoopCurrent::DestructionObserver: void WillDestroyCurrentMessageLoop() override { - base::MessageLoop::current()->RemoveNestingObserver(this); - base::MessageLoop::current()->RemoveDestructionObserver(this); + base::RunLoop::RemoveNestingObserverOnCurrentThread(this); + base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this); DCHECK_EQ(this, g_tls_nesting_observer.Get().Get()); g_tls_nesting_observer.Get().Set(nullptr); delete this; } - static MessageLoopNestingObserver* GetForThread() { - if (!base::MessageLoop::current() || - !base::MessageLoop::current()->nesting_allowed()) + static RunLoopNestingObserver* GetForThread() { + if (!base::MessageLoopCurrent::Get()) return nullptr; - auto* observer = static_cast<MessageLoopNestingObserver*>( + auto* observer = static_cast<RunLoopNestingObserver*>( g_tls_nesting_observer.Get().Get()); if (!observer) { - observer = new MessageLoopNestingObserver; + observer = new RunLoopNestingObserver; g_tls_nesting_observer.Get().Set(observer); } return observer; @@ -97,7 +109,7 @@ class Connector::MessageLoopNestingObserver ActiveDispatchTracker* top_tracker_ = nullptr; - DISALLOW_COPY_AND_ASSIGN(MessageLoopNestingObserver); + DISALLOW_COPY_AND_ASSIGN(RunLoopNestingObserver); }; Connector::ActiveDispatchTracker::ActiveDispatchTracker( @@ -129,14 +141,22 @@ void Connector::ActiveDispatchTracker::NotifyBeginNesting() { Connector::Connector(ScopedMessagePipeHandle message_pipe, ConnectorConfig config, - scoped_refptr<base::SingleThreadTaskRunner> runner) + scoped_refptr<base::SequencedTaskRunner> runner) : message_pipe_(std::move(message_pipe)), task_runner_(std::move(runner)), - nesting_observer_(MessageLoopNestingObserver::GetForThread()), + error_(false), + outgoing_serialization_mode_(g_default_outgoing_serialization_mode), + incoming_serialization_mode_(g_default_incoming_serialization_mode), + nesting_observer_(RunLoopNestingObserver::GetForThread()), weak_factory_(this) { if (config == MULTI_THREADED_SEND) lock_.emplace(); +#if defined(ENABLE_IPC_FUZZER) + if (!MessageDumper::GetMessageDumpDirectory().empty()) + message_dumper_ = std::make_unique<MessageDumper>(); +#endif + weak_self_ = weak_factory_.GetWeakPtr(); // Even though we don't have an incoming receiver, we still want to monitor // the message pipe to know if is closed or encounters an error. @@ -145,23 +165,34 @@ Connector::Connector(ScopedMessagePipeHandle message_pipe, Connector::~Connector() { { - // Allow for quick destruction on any thread if the pipe is already closed. + // Allow for quick destruction on any sequence if the pipe is already + // closed. base::AutoLock lock(connected_lock_); if (!connected_) return; } - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); CancelWait(); } +void Connector::SetOutgoingSerializationMode(OutgoingSerializationMode mode) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + outgoing_serialization_mode_ = mode; +} + +void Connector::SetIncomingSerializationMode(IncomingSerializationMode mode) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + incoming_serialization_mode_ = mode; +} + void Connector::CloseMessagePipe() { // Throw away the returned message pipe. PassMessagePipe(); } ScopedMessagePipeHandle Connector::PassMessagePipe() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); CancelWait(); internal::MayAutoLock locker(&lock_); @@ -175,13 +206,13 @@ ScopedMessagePipeHandle Connector::PassMessagePipe() { } void Connector::RaiseError() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); HandleError(true, true); } bool Connector::WaitForIncomingMessage(MojoDeadline deadline) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (error_) return false; @@ -211,7 +242,7 @@ bool Connector::WaitForIncomingMessage(MojoDeadline deadline) { } void Connector::PauseIncomingMethodCallProcessing() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (paused_) return; @@ -221,7 +252,7 @@ void Connector::PauseIncomingMethodCallProcessing() { } void Connector::ResumeIncomingMethodCallProcessing() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!paused_) return; @@ -230,12 +261,18 @@ void Connector::ResumeIncomingMethodCallProcessing() { WaitToReadMore(); } +bool Connector::PrefersSerializedMessages() { + if (outgoing_serialization_mode_ == OutgoingSerializationMode::kEager) + return true; + DCHECK_EQ(OutgoingSerializationMode::kLazy, outgoing_serialization_mode_); + return peer_remoteness_tracker_ && + peer_remoteness_tracker_->last_known_state().peer_remote(); +} + bool Connector::Accept(Message* message) { - DCHECK(lock_ || thread_checker_.CalledOnValidThread()); + if (!lock_) + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); - // It shouldn't hurt even if |error_| may be changed by a different thread at - // the same time. The outcome is that we may write into |message_pipe_| after - // encountering an error, which should be fine. if (error_) return false; @@ -244,6 +281,13 @@ bool Connector::Accept(Message* message) { if (!message_pipe_.is_valid() || drop_writes_) return true; +#if defined(ENABLE_IPC_FUZZER) + if (message_dumper_ && message->is_serialized()) { + bool dump_result = message_dumper_->Accept(message); + DCHECK(dump_result); + } +#endif + MojoResult rv = WriteMessageNew(message_pipe_.get(), message->TakeMojoMessage(), MOJO_WRITE_MESSAGE_FLAG_NONE); @@ -261,10 +305,10 @@ bool Connector::Accept(Message* message) { case MOJO_RESULT_BUSY: // We'd get a "busy" result if one of the message's handles is: // - |message_pipe_|'s own handle; - // - simultaneously being used on another thread; or + // - simultaneously being used on another sequence; or // - in a "busy" state that prohibits it from being transferred (e.g., // a data pipe handle in the middle of a two-phase read/write, - // regardless of which thread that two-phase read/write is happening + // regardless of which sequence that two-phase read/write is happening // on). // TODO(vtl): I wonder if this should be a |DCHECK()|. (But, until // crbug.com/389666, etc. are resolved, this will make tests fail quickly @@ -280,7 +324,7 @@ bool Connector::Accept(Message* message) { } void Connector::AllowWokenUpBySyncWatchOnSameThread() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); allow_woken_up_by_others_ = true; @@ -289,7 +333,7 @@ void Connector::AllowWokenUpBySyncWatchOnSameThread() { } bool Connector::SyncWatch(const bool* should_stop) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (error_) return false; @@ -301,12 +345,21 @@ bool Connector::SyncWatch(const bool* should_stop) { } void Connector::SetWatcherHeapProfilerTag(const char* tag) { - heap_profiler_tag_ = tag; - if (handle_watcher_) { - handle_watcher_->set_heap_profiler_tag(tag); + if (tag) { + heap_profiler_tag_ = tag; + if (handle_watcher_) + handle_watcher_->set_heap_profiler_tag(tag); } } +// static +void Connector::OverrideDefaultSerializationBehaviorForTesting( + OutgoingSerializationMode outgoing_mode, + IncomingSerializationMode incoming_mode) { + g_default_outgoing_serialization_mode = outgoing_mode; + g_default_incoming_serialization_mode = incoming_mode; +} + void Connector::OnWatcherHandleReady(MojoResult result) { OnHandleReadyInternal(result); } @@ -324,7 +377,7 @@ void Connector::OnSyncHandleWatcherHandleReady(MojoResult result) { } void Connector::OnHandleReadyInternal(MojoResult result) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (result != MOJO_RESULT_OK) { HandleError(result != MOJO_RESULT_FAILED_PRECONDITION, false); @@ -341,12 +394,16 @@ void Connector::WaitToReadMore() { handle_watcher_.reset(new SimpleWatcher( FROM_HERE, SimpleWatcher::ArmingPolicy::MANUAL, task_runner_)); - if (heap_profiler_tag_) - handle_watcher_->set_heap_profiler_tag(heap_profiler_tag_); + handle_watcher_->set_heap_profiler_tag(heap_profiler_tag_); MojoResult rv = handle_watcher_->Watch( message_pipe_.get(), MOJO_HANDLE_SIGNAL_READABLE, base::Bind(&Connector::OnWatcherHandleReady, base::Unretained(this))); + if (message_pipe_.is_valid()) { + peer_remoteness_tracker_.emplace(message_pipe_.get(), + MOJO_HANDLE_SIGNAL_PEER_REMOTE); + } + if (rv != MOJO_RESULT_OK) { // If the watch failed because the handle is invalid or its conditions can // no longer be met, we signal the error asynchronously to avoid reentry. @@ -383,6 +440,19 @@ bool Connector::ReadSingleMessage(MojoResult* read_result) { dispatch_tracker.emplace(weak_self); } + if (incoming_serialization_mode_ == + IncomingSerializationMode::kSerializeBeforeDispatchForTesting) { + message.SerializeIfNecessary(); + } else { + DCHECK_EQ(IncomingSerializationMode::kDispatchAsIs, + incoming_serialization_mode_); + } + +#if !BUILDFLAG(MOJO_TRACE_ENABLED) + // This emits just full class name, and is inferior to mojo tracing. + TRACE_EVENT0("mojom", heap_profiler_tag_); +#endif + receiver_result = incoming_receiver_ && incoming_receiver_->Accept(&message); @@ -443,6 +513,7 @@ void Connector::ReadAllAvailableMessages() { } void Connector::CancelWait() { + peer_remoteness_tracker_.reset(); handle_watcher_.reset(); sync_watcher_.reset(); } @@ -476,8 +547,8 @@ void Connector::HandleError(bool force_pipe_reset, bool force_async_handler) { WaitToReadMore(); } else { error_ = true; - if (!connection_error_handler_.is_null()) - connection_error_handler_.Run(); + if (connection_error_handler_) + std::move(connection_error_handler_).Run(); } } diff --git a/mojo/public/cpp/bindings/lib/control_message_handler.cc b/mojo/public/cpp/bindings/lib/control_message_handler.cc index 1b7bb78e5f..b87c11c874 100644 --- a/mojo/public/cpp/bindings/lib/control_message_handler.cc +++ b/mojo/public/cpp/bindings/lib/control_message_handler.cc @@ -10,9 +10,9 @@ #include "base/logging.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/lib/message_builder.h" #include "mojo/public/cpp/bindings/lib/serialization.h" #include "mojo/public/cpp/bindings/lib/validation_util.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/interfaces/bindings/interface_control_messages.mojom.h" namespace mojo { @@ -115,19 +115,15 @@ bool ControlMessageHandler::Run( auto response_params_ptr = interface_control::RunResponseMessageParams::New(); response_params_ptr->output = std::move(output); - size_t size = - PrepareToSerialize<interface_control::RunResponseMessageParamsDataView>( - response_params_ptr, &context_); - MessageBuilder builder(interface_control::kRunMessageId, - Message::kFlagIsResponse, size, 0); - builder.message()->set_request_id(message->request_id()); - - interface_control::internal::RunResponseMessageParams_Data* response_params = - nullptr; + Message response_message(interface_control::kRunMessageId, + Message::kFlagIsResponse, 0, 0, nullptr); + response_message.set_request_id(message->request_id()); + interface_control::internal::RunResponseMessageParams_Data::BufferWriter + response_params; Serialize<interface_control::RunResponseMessageParamsDataView>( - response_params_ptr, builder.buffer(), &response_params, &context_); - ignore_result(responder->Accept(builder.message())); - + response_params_ptr, response_message.payload_buffer(), &response_params, + &context_); + ignore_result(responder->Accept(&response_message)); return true; } diff --git a/mojo/public/cpp/bindings/lib/control_message_handler.h b/mojo/public/cpp/bindings/lib/control_message_handler.h index 5d1f716ea8..daa884bb52 100644 --- a/mojo/public/cpp/bindings/lib/control_message_handler.h +++ b/mojo/public/cpp/bindings/lib/control_message_handler.h @@ -18,7 +18,7 @@ namespace internal { // Handlers for request messages defined in interface_control_messages.mojom. class MOJO_CPP_BINDINGS_EXPORT ControlMessageHandler - : NON_EXPORTED_BASE(public MessageReceiverWithResponderStatus) { + : public MessageReceiverWithResponderStatus { public: static bool IsControlMessage(const Message* message); diff --git a/mojo/public/cpp/bindings/lib/control_message_proxy.cc b/mojo/public/cpp/bindings/lib/control_message_proxy.cc index d082b49fb3..9fd7bf4173 100644 --- a/mojo/public/cpp/bindings/lib/control_message_proxy.cc +++ b/mojo/public/cpp/bindings/lib/control_message_proxy.cc @@ -12,7 +12,6 @@ #include "base/callback_helpers.h" #include "base/macros.h" #include "base/run_loop.h" -#include "mojo/public/cpp/bindings/lib/message_builder.h" #include "mojo/public/cpp/bindings/lib/serialization.h" #include "mojo/public/cpp/bindings/lib/validation_util.h" #include "mojo/public/cpp/bindings/message.h" @@ -73,49 +72,37 @@ bool RunResponseForwardToCallback::Accept(Message* message) { void SendRunMessage(MessageReceiverWithResponder* receiver, interface_control::RunInputPtr input_ptr, const RunCallback& callback) { - SerializationContext context; - auto params_ptr = interface_control::RunMessageParams::New(); params_ptr->input = std::move(input_ptr); - size_t size = PrepareToSerialize<interface_control::RunMessageParamsDataView>( - params_ptr, &context); - MessageBuilder builder(interface_control::kRunMessageId, - Message::kFlagExpectsResponse, size, 0); - - interface_control::internal::RunMessageParams_Data* params = nullptr; + Message message(interface_control::kRunMessageId, + Message::kFlagExpectsResponse, 0, 0, nullptr); + SerializationContext context; + interface_control::internal::RunMessageParams_Data::BufferWriter params; Serialize<interface_control::RunMessageParamsDataView>( - params_ptr, builder.buffer(), ¶ms, &context); + params_ptr, message.payload_buffer(), ¶ms, &context); std::unique_ptr<MessageReceiver> responder = - base::MakeUnique<RunResponseForwardToCallback>(callback); - ignore_result( - receiver->AcceptWithResponder(builder.message(), std::move(responder))); + std::make_unique<RunResponseForwardToCallback>(callback); + ignore_result(receiver->AcceptWithResponder(&message, std::move(responder))); } Message ConstructRunOrClosePipeMessage( interface_control::RunOrClosePipeInputPtr input_ptr) { - SerializationContext context; - auto params_ptr = interface_control::RunOrClosePipeMessageParams::New(); params_ptr->input = std::move(input_ptr); - - size_t size = PrepareToSerialize< - interface_control::RunOrClosePipeMessageParamsDataView>(params_ptr, - &context); - MessageBuilder builder(interface_control::kRunOrClosePipeMessageId, 0, size, - 0); - - interface_control::internal::RunOrClosePipeMessageParams_Data* params = - nullptr; + Message message(interface_control::kRunOrClosePipeMessageId, 0, 0, 0, + nullptr); + SerializationContext context; + interface_control::internal::RunOrClosePipeMessageParams_Data::BufferWriter + params; Serialize<interface_control::RunOrClosePipeMessageParamsDataView>( - params_ptr, builder.buffer(), ¶ms, &context); - return std::move(*builder.message()); + params_ptr, message.payload_buffer(), ¶ms, &context); + return message; } void SendRunOrClosePipeMessage( MessageReceiverWithResponder* receiver, interface_control::RunOrClosePipeInputPtr input_ptr) { Message message(ConstructRunOrClosePipeMessage(std::move(input_ptr))); - ignore_result(receiver->Accept(&message)); } @@ -163,7 +150,7 @@ void ControlMessageProxy::FlushForTesting() { auto input_ptr = interface_control::RunInput::New(); input_ptr->set_flush_for_testing(interface_control::FlushForTesting::New()); - base::RunLoop run_loop; + base::RunLoop run_loop(base::RunLoop::Type::kNestableTasksAllowed); run_loop_quit_closure_ = run_loop.QuitClosure(); SendRunMessage( receiver_, std::move(input_ptr), diff --git a/mojo/public/cpp/bindings/lib/fixed_buffer.cc b/mojo/public/cpp/bindings/lib/fixed_buffer.cc index 725a193cd7..3d595cc063 100644 --- a/mojo/public/cpp/bindings/lib/fixed_buffer.cc +++ b/mojo/public/cpp/bindings/lib/fixed_buffer.cc @@ -6,25 +6,17 @@ #include <stdlib.h> +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" + namespace mojo { namespace internal { -FixedBufferForTesting::FixedBufferForTesting(size_t size) { - size = internal::Align(size); - // Use calloc here to ensure all message memory is zero'd out. - void* ptr = calloc(size, 1); - Initialize(ptr, size); -} +FixedBufferForTesting::FixedBufferForTesting(size_t size) + : Buffer(calloc(Align(size), 1), Align(size), 0) {} FixedBufferForTesting::~FixedBufferForTesting() { free(data()); } -void* FixedBufferForTesting::Leak() { - void* ptr = data(); - Initialize(nullptr, 0); - return ptr; -} - } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/fixed_buffer.h b/mojo/public/cpp/bindings/lib/fixed_buffer.h index 070b0c8cef..147ce7b115 100644 --- a/mojo/public/cpp/bindings/lib/fixed_buffer.h +++ b/mojo/public/cpp/bindings/lib/fixed_buffer.h @@ -5,11 +5,10 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_FIXED_BUFFER_H_ #define MOJO_PUBLIC_CPP_BINDINGS_LIB_FIXED_BUFFER_H_ -#include <stddef.h> +#include <cstddef> -#include "base/compiler_specific.h" +#include "base/component_export.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/buffer.h" namespace mojo { @@ -17,18 +16,12 @@ namespace internal { // FixedBufferForTesting owns its buffer. The Leak method may be used to steal // the underlying memory. -class MOJO_CPP_BINDINGS_EXPORT FixedBufferForTesting - : NON_EXPORTED_BASE(public Buffer) { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) FixedBufferForTesting + : public Buffer { public: explicit FixedBufferForTesting(size_t size); ~FixedBufferForTesting(); - // Returns the internal memory owned by the Buffer to the caller. The Buffer - // relinquishes its pointer, effectively resetting the state of the Buffer - // and leaving the caller responsible for freeing the returned memory address - // when no longer needed. - void* Leak(); - private: DISALLOW_COPY_AND_ASSIGN(FixedBufferForTesting); }; diff --git a/mojo/public/cpp/bindings/lib/handle_serialization.h b/mojo/public/cpp/bindings/lib/handle_serialization.h new file mode 100644 index 0000000000..6e1294e0a2 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/handle_serialization.h @@ -0,0 +1,35 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_SERIALIZATION_H_ + +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_context.h" +#include "mojo/public/cpp/bindings/lib/serialization_forward.h" +#include "mojo/public/cpp/system/handle.h" + +namespace mojo { +namespace internal { + +template <typename T> +struct Serializer<ScopedHandleBase<T>, ScopedHandleBase<T>> { + static void Serialize(ScopedHandleBase<T>& input, + Handle_Data* output, + SerializationContext* context) { + context->AddHandle(ScopedHandle::From(std::move(input)), output); + } + + static bool Deserialize(Handle_Data* input, + ScopedHandleBase<T>* output, + SerializationContext* context) { + *output = context->TakeHandleAs<T>(*input); + return true; + } +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc index 4682e72fad..6f119e4c1d 100644 --- a/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc +++ b/mojo/public/cpp/bindings/lib/interface_endpoint_client.cc @@ -6,18 +6,17 @@ #include <stdint.h> -#include <utility> - #include "base/bind.h" #include "base/location.h" #include "base/logging.h" #include "base/macros.h" #include "base/memory/ptr_util.h" -#include "base/single_thread_task_runner.h" +#include "base/sequenced_task_runner.h" #include "base/stl_util.h" #include "mojo/public/cpp/bindings/associated_group.h" #include "mojo/public/cpp/bindings/associated_group_controller.h" #include "mojo/public/cpp/bindings/interface_endpoint_controller.h" +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" #include "mojo/public/cpp/bindings/lib/validation_util.h" #include "mojo/public/cpp/bindings/sync_call_restrictions.h" @@ -27,10 +26,10 @@ namespace mojo { namespace { -void DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient>& client, - const std::string& message) { - bool is_valid = client && !client->encountered_error(); - DCHECK(!is_valid) << message; +void DetermineIfEndpointIsConnected( + const base::WeakPtr<InterfaceEndpointClient>& client, + base::OnceCallback<void(bool)> callback) { + std::move(callback).Run(client && !client->encountered_error()); } // When receiving an incoming message which expects a repsonse, @@ -41,7 +40,7 @@ class ResponderThunk : public MessageReceiverWithStatus { public: explicit ResponderThunk( const base::WeakPtr<InterfaceEndpointClient>& endpoint_client, - scoped_refptr<base::SingleThreadTaskRunner> runner) + scoped_refptr<base::SequencedTaskRunner> runner) : endpoint_client_(endpoint_client), accept_was_invoked_(false), task_runner_(std::move(runner)) {} @@ -52,7 +51,7 @@ class ResponderThunk : public MessageReceiverWithStatus { // We raise an error to signal the calling application that an error // condition occurred. Without this the calling application would have no // way of knowing it should stop waiting for a response. - if (task_runner_->RunsTasksOnCurrentThread()) { + if (task_runner_->RunsTasksInCurrentSequence()) { // Please note that even if this code is run from a different task // runner on the same thread as |task_runner_|, it is okay to directly // call InterfaceEndpointClient::RaiseError(), because it will raise @@ -69,8 +68,12 @@ class ResponderThunk : public MessageReceiverWithStatus { } // MessageReceiver implementation: + bool PrefersSerializedMessages() override { + return endpoint_client_ && endpoint_client_->PrefersSerializedMessages(); + } + bool Accept(Message* message) override { - DCHECK(task_runner_->RunsTasksOnCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); accept_was_invoked_ = true; DCHECK(message->has_flag(Message::kFlagIsResponse)); @@ -83,24 +86,25 @@ class ResponderThunk : public MessageReceiverWithStatus { } // MessageReceiverWithStatus implementation: - bool IsValid() override { - DCHECK(task_runner_->RunsTasksOnCurrentThread()); + bool IsConnected() override { + DCHECK(task_runner_->RunsTasksInCurrentSequence()); return endpoint_client_ && !endpoint_client_->encountered_error(); } - void DCheckInvalid(const std::string& message) override { - if (task_runner_->RunsTasksOnCurrentThread()) { - DCheckIfInvalid(endpoint_client_, message); + void IsConnectedAsync(base::OnceCallback<void(bool)> callback) override { + if (task_runner_->RunsTasksInCurrentSequence()) { + DetermineIfEndpointIsConnected(endpoint_client_, std::move(callback)); } else { task_runner_->PostTask( - FROM_HERE, base::Bind(&DCheckIfInvalid, endpoint_client_, message)); + FROM_HERE, base::BindOnce(&DetermineIfEndpointIsConnected, + endpoint_client_, std::move(callback))); } - } + } private: base::WeakPtr<InterfaceEndpointClient> endpoint_client_; bool accept_was_invoked_; - scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; DISALLOW_COPY_AND_ASSIGN(ResponderThunk); }; @@ -136,7 +140,7 @@ InterfaceEndpointClient::InterfaceEndpointClient( MessageReceiverWithResponderStatus* receiver, std::unique_ptr<MessageReceiver> payload_validator, bool expect_sync_requests, - scoped_refptr<base::SingleThreadTaskRunner> runner, + scoped_refptr<base::SequencedTaskRunner> runner, uint32_t interface_version) : expect_sync_requests_(expect_sync_requests), handle_(std::move(handle)), @@ -163,20 +167,19 @@ InterfaceEndpointClient::InterfaceEndpointClient( } InterfaceEndpointClient::~InterfaceEndpointClient() { - DCHECK(thread_checker_.CalledOnValidThread()); - + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (controller_) handle_.group_controller()->DetachEndpointClient(handle_); } AssociatedGroup* InterfaceEndpointClient::associated_group() { if (!associated_group_) - associated_group_ = base::MakeUnique<AssociatedGroup>(handle_); + associated_group_ = std::make_unique<AssociatedGroup>(handle_); return associated_group_.get(); } ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(!has_pending_responders()); if (!handle_.is_valid()) @@ -199,7 +202,7 @@ void InterfaceEndpointClient::AddFilter( } void InterfaceEndpointClient::RaiseError() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!handle_.pending_association()) handle_.group_controller()->RaiseError(); @@ -207,14 +210,19 @@ void InterfaceEndpointClient::RaiseError() { void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason, const std::string& description) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); auto handle = PassHandle(); handle.ResetWithReason(custom_reason, description); } +bool InterfaceEndpointClient::PrefersSerializedMessages() { + auto* controller = handle_.group_controller(); + return controller && controller->PrefersSerializedMessages(); +} + bool InterfaceEndpointClient::Accept(Message* message) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(!message->has_flag(Message::kFlagExpectsResponse)); DCHECK(!handle_.pending_association()); @@ -237,7 +245,7 @@ bool InterfaceEndpointClient::Accept(Message* message) { bool InterfaceEndpointClient::AcceptWithResponder( Message* message, std::unique_ptr<MessageReceiver> responder) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(message->has_flag(Message::kFlagExpectsResponse)); DCHECK(!handle_.pending_association()); @@ -270,7 +278,7 @@ bool InterfaceEndpointClient::AcceptWithResponder( bool response_received = false; sync_responses_.insert(std::make_pair( - request_id, base::MakeUnique<SyncResponseInfo>(&response_received))); + request_id, std::make_unique<SyncResponseInfo>(&response_received))); base::WeakPtr<InterfaceEndpointClient> weak_self = weak_ptr_factory_.GetWeakPtr(); @@ -280,8 +288,13 @@ bool InterfaceEndpointClient::AcceptWithResponder( DCHECK(base::ContainsKey(sync_responses_, request_id)); auto iter = sync_responses_.find(request_id); DCHECK_EQ(&response_received, iter->second->response_received); - if (response_received) + if (response_received) { ignore_result(responder->Accept(&iter->second->response)); + } else { + DVLOG(1) << "Mojo sync call returns without receiving a response. " + << "Typcially it is because the interface has been " + << "disconnected."; + } sync_responses_.erase(iter); } @@ -289,13 +302,13 @@ bool InterfaceEndpointClient::AcceptWithResponder( } bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return filters_.Accept(message); } void InterfaceEndpointClient::NotifyError( const base::Optional<DisconnectReason>& reason) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (encountered_error_) return; @@ -309,16 +322,14 @@ void InterfaceEndpointClient::NotifyError( control_message_proxy_.OnConnectionError(); - if (!error_handler_.is_null()) { - base::Closure error_handler = std::move(error_handler_); - error_handler.Run(); - } else if (!error_with_reason_handler_.is_null()) { - ConnectionErrorWithReasonCallback error_with_reason_handler = - std::move(error_with_reason_handler_); + if (error_handler_) { + std::move(error_handler_).Run(); + } else if (error_with_reason_handler_) { if (reason) { - error_with_reason_handler.Run(reason->custom_reason, reason->description); + std::move(error_with_reason_handler_) + .Run(reason->custom_reason, reason->description); } else { - error_with_reason_handler.Run(0, std::string()); + std::move(error_with_reason_handler_).Run(0, std::string()); } } } @@ -374,7 +385,7 @@ bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) { if (message->has_flag(Message::kFlagExpectsResponse)) { std::unique_ptr<MessageReceiverWithStatus> responder = - base::MakeUnique<ResponderThunk>(weak_ptr_factory_.GetWeakPtr(), + std::make_unique<ResponderThunk>(weak_ptr_factory_.GetWeakPtr(), task_runner_); if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) { return control_message_handler_.AcceptWithResponder(message, diff --git a/mojo/public/cpp/bindings/lib/interface_ptr_state.cc b/mojo/public/cpp/bindings/lib/interface_ptr_state.cc new file mode 100644 index 0000000000..8cd23ea067 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/interface_ptr_state.cc @@ -0,0 +1,94 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/interface_ptr_state.h" + +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" + +namespace mojo { +namespace internal { + +InterfacePtrStateBase::InterfacePtrStateBase() = default; + +InterfacePtrStateBase::~InterfacePtrStateBase() { + endpoint_client_.reset(); + if (router_) + router_->CloseMessagePipe(); +} + +void InterfacePtrStateBase::QueryVersion( + const base::Callback<void(uint32_t)>& callback) { + // It is safe to capture |this| because the callback won't be run after this + // object goes away. + endpoint_client_->QueryVersion( + base::Bind(&InterfacePtrStateBase::OnQueryVersion, base::Unretained(this), + callback)); +} + +void InterfacePtrStateBase::RequireVersion(uint32_t version) { + if (version <= version_) + return; + + version_ = version; + endpoint_client_->RequireVersion(version); +} + +void InterfacePtrStateBase::Swap(InterfacePtrStateBase* other) { + using std::swap; + swap(other->router_, router_); + swap(other->endpoint_client_, endpoint_client_); + handle_.swap(other->handle_); + runner_.swap(other->runner_); + swap(other->version_, version_); +} + +void InterfacePtrStateBase::Bind( + ScopedMessagePipeHandle handle, + uint32_t version, + scoped_refptr<base::SequencedTaskRunner> task_runner) { + DCHECK(!router_); + DCHECK(!endpoint_client_); + DCHECK(!handle_.is_valid()); + DCHECK_EQ(0u, version_); + DCHECK(handle.is_valid()); + + handle_ = std::move(handle); + version_ = version; + runner_ = + GetTaskRunnerToUseFromUserProvidedTaskRunner(std::move(task_runner)); +} + +void InterfacePtrStateBase::OnQueryVersion( + const base::Callback<void(uint32_t)>& callback, + uint32_t version) { + version_ = version; + callback.Run(version); +} + +bool InterfacePtrStateBase::InitializeEndpointClient( + bool passes_associated_kinds, + bool has_sync_methods, + std::unique_ptr<MessageReceiver> payload_validator) { + // The object hasn't been bound. + if (!handle_.is_valid()) + return false; + + MultiplexRouter::Config config = + passes_associated_kinds + ? MultiplexRouter::MULTI_INTERFACE + : (has_sync_methods + ? MultiplexRouter::SINGLE_INTERFACE_WITH_SYNC_METHODS + : MultiplexRouter::SINGLE_INTERFACE); + router_ = new MultiplexRouter(std::move(handle_), config, true, runner_); + endpoint_client_.reset(new InterfaceEndpointClient( + router_->CreateLocalEndpointHandle(kMasterInterfaceId), nullptr, + std::move(payload_validator), false, std::move(runner_), + // The version is only queried from the client so the value passed here + // will not be used. + 0u)); + return true; +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/interface_ptr_state.h b/mojo/public/cpp/bindings/lib/interface_ptr_state.h index fa54979795..2e73564a80 100644 --- a/mojo/public/cpp/bindings/lib/interface_ptr_state.h +++ b/mojo/public/cpp/bindings/lib/interface_ptr_state.h @@ -18,8 +18,9 @@ #include "base/macros.h" #include "base/memory/ptr_util.h" #include "base/memory/ref_counted.h" -#include "base/single_thread_task_runner.h" +#include "base/sequenced_task_runner.h" #include "mojo/public/cpp/bindings/associated_group.h" +#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/connection_error_callback.h" #include "mojo/public/cpp/bindings/filter_chain.h" #include "mojo/public/cpp/bindings/interface_endpoint_client.h" @@ -32,191 +33,190 @@ namespace mojo { namespace internal { -template <typename Interface> -class InterfacePtrState { +class MOJO_CPP_BINDINGS_EXPORT InterfacePtrStateBase { public: - InterfacePtrState() : version_(0u) {} + InterfacePtrStateBase(); + ~InterfacePtrStateBase(); + + MessagePipeHandle handle() const { + return router_ ? router_->handle() : handle_.get(); + } + + uint32_t version() const { return version_; } + + bool is_bound() const { return handle_.is_valid() || endpoint_client_; } + + bool encountered_error() const { + return endpoint_client_ ? endpoint_client_->encountered_error() : false; + } + + bool HasAssociatedInterfaces() const { + return router_ ? router_->HasAssociatedEndpoints() : false; + } + + // Returns true if bound and awaiting a response to a message. + bool has_pending_callbacks() const { + return endpoint_client_ && endpoint_client_->has_pending_responders(); + } + + protected: + InterfaceEndpointClient* endpoint_client() const { + return endpoint_client_.get(); + } + MultiplexRouter* router() const { return router_.get(); } - ~InterfacePtrState() { + void QueryVersion(const base::Callback<void(uint32_t)>& callback); + void RequireVersion(uint32_t version); + void Swap(InterfacePtrStateBase* other); + void Bind(ScopedMessagePipeHandle handle, + uint32_t version, + scoped_refptr<base::SequencedTaskRunner> task_runner); + + ScopedMessagePipeHandle PassMessagePipe() { endpoint_client_.reset(); - proxy_.reset(); - if (router_) - router_->CloseMessagePipe(); + return router_ ? router_->PassMessagePipe() : std::move(handle_); } - Interface* instance() { + bool InitializeEndpointClient( + bool passes_associated_kinds, + bool has_sync_methods, + std::unique_ptr<MessageReceiver> payload_validator); + + private: + void OnQueryVersion(const base::Callback<void(uint32_t)>& callback, + uint32_t version); + + scoped_refptr<MultiplexRouter> router_; + + std::unique_ptr<InterfaceEndpointClient> endpoint_client_; + + // |router_| (as well as other members above) is not initialized until + // read/write with the message pipe handle is needed. |handle_| is valid + // between the Bind() call and the initialization of |router_|. + ScopedMessagePipeHandle handle_; + scoped_refptr<base::SequencedTaskRunner> runner_; + + uint32_t version_ = 0; + + DISALLOW_COPY_AND_ASSIGN(InterfacePtrStateBase); +}; + +template <typename Interface> +class InterfacePtrState : public InterfacePtrStateBase { + public: + using Proxy = typename Interface::Proxy_; + + InterfacePtrState() = default; + ~InterfacePtrState() = default; + + Proxy* instance() { ConfigureProxyIfNecessary(); // This will be null if the object is not bound. return proxy_.get(); } - uint32_t version() const { return version_; } - void QueryVersion(const base::Callback<void(uint32_t)>& callback) { ConfigureProxyIfNecessary(); - - // It is safe to capture |this| because the callback won't be run after this - // object goes away. - endpoint_client_->QueryVersion(base::Bind( - &InterfacePtrState::OnQueryVersion, base::Unretained(this), callback)); + InterfacePtrStateBase::QueryVersion(callback); } void RequireVersion(uint32_t version) { ConfigureProxyIfNecessary(); - - if (version <= version_) - return; - - version_ = version; - endpoint_client_->RequireVersion(version); + InterfacePtrStateBase::RequireVersion(version); } void FlushForTesting() { ConfigureProxyIfNecessary(); - endpoint_client_->FlushForTesting(); + endpoint_client()->FlushForTesting(); } void CloseWithReason(uint32_t custom_reason, const std::string& description) { ConfigureProxyIfNecessary(); - endpoint_client_->CloseWithReason(custom_reason, description); + endpoint_client()->CloseWithReason(custom_reason, description); } void Swap(InterfacePtrState* other) { using std::swap; - swap(other->router_, router_); - swap(other->endpoint_client_, endpoint_client_); swap(other->proxy_, proxy_); - handle_.swap(other->handle_); - runner_.swap(other->runner_); - swap(other->version_, version_); + InterfacePtrStateBase::Swap(other); } void Bind(InterfacePtrInfo<Interface> info, - scoped_refptr<base::SingleThreadTaskRunner> runner) { - DCHECK(!router_); - DCHECK(!endpoint_client_); + scoped_refptr<base::SequencedTaskRunner> runner) { DCHECK(!proxy_); - DCHECK(!handle_.is_valid()); - DCHECK_EQ(0u, version_); - DCHECK(info.is_valid()); - - handle_ = info.PassHandle(); - version_ = info.version(); - runner_ = std::move(runner); - } - - bool HasAssociatedInterfaces() const { - return router_ ? router_->HasAssociatedEndpoints() : false; + InterfacePtrStateBase::Bind(info.PassHandle(), info.version(), + std::move(runner)); } // After this method is called, the object is in an invalid state and // shouldn't be reused. InterfacePtrInfo<Interface> PassInterface() { - endpoint_client_.reset(); proxy_.reset(); - return InterfacePtrInfo<Interface>( - router_ ? router_->PassMessagePipe() : std::move(handle_), version_); + return InterfacePtrInfo<Interface>(PassMessagePipe(), version()); } - bool is_bound() const { return handle_.is_valid() || endpoint_client_; } - - bool encountered_error() const { - return endpoint_client_ ? endpoint_client_->encountered_error() : false; - } - - void set_connection_error_handler(const base::Closure& error_handler) { + void set_connection_error_handler(base::OnceClosure error_handler) { ConfigureProxyIfNecessary(); - DCHECK(endpoint_client_); - endpoint_client_->set_connection_error_handler(error_handler); + DCHECK(endpoint_client()); + endpoint_client()->set_connection_error_handler(std::move(error_handler)); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { ConfigureProxyIfNecessary(); - DCHECK(endpoint_client_); - endpoint_client_->set_connection_error_with_reason_handler(error_handler); - } - - // Returns true if bound and awaiting a response to a message. - bool has_pending_callbacks() const { - return endpoint_client_ && endpoint_client_->has_pending_responders(); + DCHECK(endpoint_client()); + endpoint_client()->set_connection_error_with_reason_handler( + std::move(error_handler)); } AssociatedGroup* associated_group() { ConfigureProxyIfNecessary(); - return endpoint_client_->associated_group(); + return endpoint_client()->associated_group(); } void EnableTestingMode() { ConfigureProxyIfNecessary(); - router_->EnableTestingMode(); + router()->EnableTestingMode(); } void ForwardMessage(Message message) { ConfigureProxyIfNecessary(); - endpoint_client_->Accept(&message); + endpoint_client()->Accept(&message); } void ForwardMessageWithResponder(Message message, std::unique_ptr<MessageReceiver> responder) { ConfigureProxyIfNecessary(); - endpoint_client_->AcceptWithResponder(&message, std::move(responder)); + endpoint_client()->AcceptWithResponder(&message, std::move(responder)); } - private: - using Proxy = typename Interface::Proxy_; + void RaiseError() { + ConfigureProxyIfNecessary(); + endpoint_client()->RaiseError(); + } + private: void ConfigureProxyIfNecessary() { // The proxy has been configured. if (proxy_) { - DCHECK(router_); - DCHECK(endpoint_client_); + DCHECK(router()); + DCHECK(endpoint_client()); return; } - // The object hasn't been bound. - if (!handle_.is_valid()) - return; - MultiplexRouter::Config config = - Interface::PassesAssociatedKinds_ - ? MultiplexRouter::MULTI_INTERFACE - : (Interface::HasSyncMethods_ - ? MultiplexRouter::SINGLE_INTERFACE_WITH_SYNC_METHODS - : MultiplexRouter::SINGLE_INTERFACE); - router_ = new MultiplexRouter(std::move(handle_), config, true, runner_); - router_->SetMasterInterfaceName(Interface::Name_); - endpoint_client_.reset(new InterfaceEndpointClient( - router_->CreateLocalEndpointHandle(kMasterInterfaceId), nullptr, - base::WrapUnique(new typename Interface::ResponseValidator_()), false, - std::move(runner_), - // The version is only queried from the client so the value passed here - // will not be used. - 0u)); - proxy_.reset(new Proxy(endpoint_client_.get())); - } - - void OnQueryVersion(const base::Callback<void(uint32_t)>& callback, - uint32_t version) { - version_ = version; - callback.Run(version); + if (InitializeEndpointClient( + Interface::PassesAssociatedKinds_, Interface::HasSyncMethods_, + std::make_unique<typename Interface::ResponseValidator_>())) { + router()->SetMasterInterfaceName(Interface::Name_); + proxy_ = std::make_unique<Proxy>(endpoint_client()); + } } - scoped_refptr<MultiplexRouter> router_; - - std::unique_ptr<InterfaceEndpointClient> endpoint_client_; std::unique_ptr<Proxy> proxy_; - // |router_| (as well as other members above) is not initialized until - // read/write with the message pipe handle is needed. |handle_| is valid - // between the Bind() call and the initialization of |router_|. - ScopedMessagePipeHandle handle_; - scoped_refptr<base::SingleThreadTaskRunner> runner_; - - uint32_t version_; - DISALLOW_COPY_AND_ASSIGN(InterfacePtrState); }; diff --git a/mojo/public/cpp/bindings/lib/handle_interface_serialization.h b/mojo/public/cpp/bindings/lib/interface_serialization.h index 14ed21f0ac..00954de261 100644 --- a/mojo/public/cpp/bindings/lib/handle_interface_serialization.h +++ b/mojo/public/cpp/bindings/lib/interface_serialization.h @@ -1,9 +1,9 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. +// Copyright 2018 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_INTERFACE_SERIALIZATION_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_INTERFACE_SERIALIZATION_H_ +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_INTERFACE_SERIALIZATION_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_INTERFACE_SERIALIZATION_H_ #include <type_traits> @@ -17,6 +17,7 @@ #include "mojo/public/cpp/bindings/lib/serialization_context.h" #include "mojo/public/cpp/bindings/lib/serialization_forward.h" #include "mojo/public/cpp/system/handle.h" +#include "mojo/public/cpp/system/message_pipe.h" namespace mojo { namespace internal { @@ -26,40 +27,24 @@ struct Serializer<AssociatedInterfacePtrInfoDataView<Base>, AssociatedInterfacePtrInfo<T>> { static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); - static size_t PrepareToSerialize(const AssociatedInterfacePtrInfo<T>& input, - SerializationContext* context) { - if (input.handle().is_valid()) - context->associated_endpoint_count++; - return 0; - } - static void Serialize(AssociatedInterfacePtrInfo<T>& input, AssociatedInterface_Data* output, SerializationContext* context) { DCHECK(!input.handle().is_valid() || input.handle().pending_association()); - if (input.handle().is_valid()) { - // Set to the index of the element pushed to the back of the vector. - output->handle.value = - static_cast<uint32_t>(context->associated_endpoint_handles.size()); - context->associated_endpoint_handles.push_back(input.PassHandle()); - } else { - output->handle.value = kEncodedInvalidHandleValue; - } - output->version = input.version(); + context->AddAssociatedInterfaceInfo(input.PassHandle(), input.version(), + output); } static bool Deserialize(AssociatedInterface_Data* input, AssociatedInterfacePtrInfo<T>* output, SerializationContext* context) { - if (input->handle.is_valid()) { - DCHECK_LT(input->handle.value, - context->associated_endpoint_handles.size()); - output->set_handle( - std::move(context->associated_endpoint_handles[input->handle.value])); + auto handle = context->TakeAssociatedEndpointHandle(input->handle); + if (!handle.is_valid()) { + *output = AssociatedInterfacePtrInfo<T>(); } else { - output->set_handle(ScopedInterfaceEndpointHandle()); + output->set_handle(std::move(handle)); + output->set_version(input->version); } - output->set_version(input->version); return true; } }; @@ -69,37 +54,21 @@ struct Serializer<AssociatedInterfaceRequestDataView<Base>, AssociatedInterfaceRequest<T>> { static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); - static size_t PrepareToSerialize(const AssociatedInterfaceRequest<T>& input, - SerializationContext* context) { - if (input.handle().is_valid()) - context->associated_endpoint_count++; - return 0; - } - static void Serialize(AssociatedInterfaceRequest<T>& input, AssociatedEndpointHandle_Data* output, SerializationContext* context) { DCHECK(!input.handle().is_valid() || input.handle().pending_association()); - if (input.handle().is_valid()) { - // Set to the index of the element pushed to the back of the vector. - output->value = - static_cast<uint32_t>(context->associated_endpoint_handles.size()); - context->associated_endpoint_handles.push_back(input.PassHandle()); - } else { - output->value = kEncodedInvalidHandleValue; - } + context->AddAssociatedEndpoint(input.PassHandle(), output); } static bool Deserialize(AssociatedEndpointHandle_Data* input, AssociatedInterfaceRequest<T>* output, SerializationContext* context) { - if (input->is_valid()) { - DCHECK_LT(input->value, context->associated_endpoint_handles.size()); - output->Bind( - std::move(context->associated_endpoint_handles[input->value])); - } else { - output->Bind(ScopedInterfaceEndpointHandle()); - } + auto handle = context->TakeAssociatedEndpointHandle(*input); + if (!handle.is_valid()) + *output = AssociatedInterfaceRequest<T>(); + else + *output = AssociatedInterfaceRequest<T>(std::move(handle)); return true; } }; @@ -108,69 +77,58 @@ template <typename Base, typename T> struct Serializer<InterfacePtrDataView<Base>, InterfacePtr<T>> { static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); - static size_t PrepareToSerialize(const InterfacePtr<T>& input, - SerializationContext* context) { - return 0; - } - static void Serialize(InterfacePtr<T>& input, Interface_Data* output, SerializationContext* context) { InterfacePtrInfo<T> info = input.PassInterface(); - output->handle = context->handles.AddHandle(info.PassHandle().release()); - output->version = info.version(); + context->AddInterfaceInfo(info.PassHandle(), info.version(), output); } static bool Deserialize(Interface_Data* input, InterfacePtr<T>* output, SerializationContext* context) { output->Bind(InterfacePtrInfo<T>( - context->handles.TakeHandleAs<mojo::MessagePipeHandle>(input->handle), + context->TakeHandleAs<mojo::MessagePipeHandle>(input->handle), input->version)); return true; } }; template <typename Base, typename T> -struct Serializer<InterfaceRequestDataView<Base>, InterfaceRequest<T>> { +struct Serializer<InterfacePtrDataView<Base>, InterfacePtrInfo<T>> { static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); - static size_t PrepareToSerialize(const InterfaceRequest<T>& input, - SerializationContext* context) { - return 0; - } - - static void Serialize(InterfaceRequest<T>& input, - Handle_Data* output, + static void Serialize(InterfacePtrInfo<T>& input, + Interface_Data* output, SerializationContext* context) { - *output = context->handles.AddHandle(input.PassMessagePipe().release()); + context->AddInterfaceInfo(input.PassHandle(), input.version(), output); } - static bool Deserialize(Handle_Data* input, - InterfaceRequest<T>* output, + static bool Deserialize(Interface_Data* input, + InterfacePtrInfo<T>* output, SerializationContext* context) { - output->Bind(context->handles.TakeHandleAs<MessagePipeHandle>(*input)); + *output = InterfacePtrInfo<T>( + context->TakeHandleAs<mojo::MessagePipeHandle>(input->handle), + input->version); return true; } }; -template <typename T> -struct Serializer<ScopedHandleBase<T>, ScopedHandleBase<T>> { - static size_t PrepareToSerialize(const ScopedHandleBase<T>& input, - SerializationContext* context) { - return 0; - } +template <typename Base, typename T> +struct Serializer<InterfaceRequestDataView<Base>, InterfaceRequest<T>> { + static_assert(std::is_base_of<Base, T>::value, "Interface type mismatch."); - static void Serialize(ScopedHandleBase<T>& input, + static void Serialize(InterfaceRequest<T>& input, Handle_Data* output, SerializationContext* context) { - *output = context->handles.AddHandle(input.release()); + context->AddHandle(ScopedHandle::From(input.PassMessagePipe()), output); } static bool Deserialize(Handle_Data* input, - ScopedHandleBase<T>* output, + InterfaceRequest<T>* output, SerializationContext* context) { - *output = context->handles.TakeHandleAs<T>(*input); + *output = + InterfaceRequest<T>(context->TakeHandleAs<MessagePipeHandle>(*input)); return true; } }; @@ -178,4 +136,4 @@ struct Serializer<ScopedHandleBase<T>, ScopedHandleBase<T>> { } // namespace internal } // namespace mojo -#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_HANDLE_INTERFACE_SERIALIZATION_H_ +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_INTERFACE_SERIALIZATION_H_ diff --git a/mojo/public/cpp/bindings/lib/map_data_internal.h b/mojo/public/cpp/bindings/lib/map_data_internal.h index f8e3d2918f..217904fd43 100644 --- a/mojo/public/cpp/bindings/lib/map_data_internal.h +++ b/mojo/public/cpp/bindings/lib/map_data_internal.h @@ -5,6 +5,7 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MAP_DATA_INTERNAL_H_ #define MOJO_PUBLIC_CPP_BINDINGS_LIB_MAP_DATA_INTERNAL_H_ +#include "base/macros.h" #include "mojo/public/cpp/bindings/lib/array_internal.h" #include "mojo/public/cpp/bindings/lib/validate_params.h" #include "mojo/public/cpp/bindings/lib/validation_errors.h" @@ -18,9 +19,29 @@ namespace internal { template <typename Key, typename Value> class Map_Data { public: - static Map_Data* New(Buffer* buf) { - return new (buf->Allocate(sizeof(Map_Data))) Map_Data(); - } + class BufferWriter { + public: + BufferWriter() = default; + + void Allocate(Buffer* buffer) { + buffer_ = buffer; + index_ = buffer_->Allocate(sizeof(Map_Data)); + new (data()) Map_Data(); + } + + bool is_null() const { return !buffer_; } + Map_Data* data() { + DCHECK(!is_null()); + return buffer_->Get<Map_Data>(index_); + } + Map_Data* operator->() { return data(); } + + private: + Buffer* buffer_ = nullptr; + size_t index_ = 0; + + DISALLOW_COPY_AND_ASSIGN(BufferWriter); + }; // |validate_params| must have non-null |key_validate_params| and // |element_validate_params| members. @@ -41,16 +62,13 @@ class Map_Data { return false; } - if (!ValidatePointerNonNullable( - object->keys, "null key array in map struct", validation_context) || + if (!ValidatePointerNonNullable(object->keys, 0, validation_context) || !ValidateContainer(object->keys, validation_context, validate_params->key_validate_params)) { return false; } - if (!ValidatePointerNonNullable(object->values, - "null value array in map struct", - validation_context) || + if (!ValidatePointerNonNullable(object->values, 1, validation_context) || !ValidateContainer(object->values, validation_context, validate_params->element_validate_params)) { return false; diff --git a/mojo/public/cpp/bindings/lib/map_serialization.h b/mojo/public/cpp/bindings/lib/map_serialization.h index 718a76307d..b114f4995c 100644 --- a/mojo/public/cpp/bindings/lib/map_serialization.h +++ b/mojo/public/cpp/bindings/lib/map_serialization.h @@ -95,57 +95,34 @@ struct Serializer<MapDataView<Key, Value>, MaybeConstUserType> { std::vector<UserValue>, MapValueReader<MaybeConstUserType>>; - static size_t PrepareToSerialize(MaybeConstUserType& input, - SerializationContext* context) { - if (CallIsNullIfExists<Traits>(input)) - return 0; - - size_t struct_overhead = sizeof(Data); - MapKeyReader<MaybeConstUserType> key_reader(input); - size_t keys_size = - KeyArraySerializer::GetSerializedSize(&key_reader, context); - MapValueReader<MaybeConstUserType> value_reader(input); - size_t values_size = - ValueArraySerializer::GetSerializedSize(&value_reader, context); - - return struct_overhead + keys_size + values_size; - } - static void Serialize(MaybeConstUserType& input, Buffer* buf, - Data** output, + typename Data::BufferWriter* writer, const ContainerValidateParams* validate_params, SerializationContext* context) { DCHECK(validate_params->key_validate_params); DCHECK(validate_params->element_validate_params); - if (CallIsNullIfExists<Traits>(input)) { - *output = nullptr; + if (CallIsNullIfExists<Traits>(input)) return; - } - auto result = Data::New(buf); - if (result) { - auto keys_ptr = MojomTypeTraits<ArrayDataView<Key>>::Data::New( - Traits::GetSize(input), buf); - if (keys_ptr) { - MapKeyReader<MaybeConstUserType> key_reader(input); - KeyArraySerializer::SerializeElements( - &key_reader, buf, keys_ptr, validate_params->key_validate_params, - context); - result->keys.Set(keys_ptr); - } - - auto values_ptr = MojomTypeTraits<ArrayDataView<Value>>::Data::New( - Traits::GetSize(input), buf); - if (values_ptr) { - MapValueReader<MaybeConstUserType> value_reader(input); - ValueArraySerializer::SerializeElements( - &value_reader, buf, values_ptr, - validate_params->element_validate_params, context); - result->values.Set(values_ptr); - } - } - *output = result; + writer->Allocate(buf); + typename MojomTypeTraits<ArrayDataView<Key>>::Data::BufferWriter + keys_writer; + keys_writer.Allocate(Traits::GetSize(input), buf); + MapKeyReader<MaybeConstUserType> key_reader(input); + KeyArraySerializer::SerializeElements(&key_reader, buf, &keys_writer, + validate_params->key_validate_params, + context); + (*writer)->keys.Set(keys_writer.data()); + + typename MojomTypeTraits<ArrayDataView<Value>>::Data::BufferWriter + values_writer; + values_writer.Allocate(Traits::GetSize(input), buf); + MapValueReader<MaybeConstUserType> value_reader(input); + ValueArraySerializer::SerializeElements( + &value_reader, buf, &values_writer, + validate_params->element_validate_params, context); + (*writer)->values.Set(values_writer.data()); } static bool Deserialize(Data* input, diff --git a/mojo/public/cpp/bindings/lib/may_auto_lock.h b/mojo/public/cpp/bindings/lib/may_auto_lock.h index 06091fee90..78cb89fa77 100644 --- a/mojo/public/cpp/bindings/lib/may_auto_lock.h +++ b/mojo/public/cpp/bindings/lib/may_auto_lock.h @@ -5,6 +5,7 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MAY_AUTO_LOCK_H_ #define MOJO_PUBLIC_CPP_BINDINGS_LIB_MAY_AUTO_LOCK_H_ +#include "base/component_export.h" #include "base/macros.h" #include "base/optional.h" #include "base/synchronization/lock.h" @@ -14,7 +15,7 @@ namespace internal { // Similar to base::AutoLock, except that it does nothing if |lock| passed into // the constructor is null. -class MayAutoLock { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) MayAutoLock { public: explicit MayAutoLock(base::Optional<base::Lock>* lock) : lock_(lock->has_value() ? &lock->value() : nullptr) { @@ -36,7 +37,7 @@ class MayAutoLock { // Similar to base::AutoUnlock, except that it does nothing if |lock| passed // into the constructor is null. -class MayAutoUnlock { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) MayAutoUnlock { public: explicit MayAutoUnlock(base::Optional<base::Lock>* lock) : lock_(lock->has_value() ? &lock->value() : nullptr) { diff --git a/mojo/public/cpp/bindings/lib/message.cc b/mojo/public/cpp/bindings/lib/message.cc index e5f3808117..8972d9efd1 100644 --- a/mojo/public/cpp/bindings/lib/message.cc +++ b/mojo/public/cpp/bindings/lib/message.cc @@ -14,72 +14,279 @@ #include "base/bind.h" #include "base/lazy_instance.h" #include "base/logging.h" +#include "base/numerics/safe_math.h" #include "base/strings/stringprintf.h" #include "base/threading/thread_local.h" #include "mojo/public/cpp/bindings/associated_group_controller.h" #include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/bindings/lib/unserialized_message_context.h" namespace mojo { namespace { base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>:: - DestructorAtExit g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER; + Leaky g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER; -base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>:: - DestructorAtExit g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER; +base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::Leaky + g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER; void DoNotifyBadMessage(Message message, const std::string& error) { message.NotifyBadMessage(error); } -} // namespace +template <typename HeaderType> +void AllocateHeaderFromBuffer(internal::Buffer* buffer, HeaderType** header) { + *header = buffer->AllocateAndGet<HeaderType>(); + (*header)->num_bytes = sizeof(HeaderType); +} + +void WriteMessageHeader(uint32_t name, + uint32_t flags, + size_t payload_interface_id_count, + internal::Buffer* payload_buffer) { + if (payload_interface_id_count > 0) { + // Version 2 + internal::MessageHeaderV2* header; + AllocateHeaderFromBuffer(payload_buffer, &header); + header->version = 2; + header->name = name; + header->flags = flags; + // The payload immediately follows the header. + header->payload.Set(header + 1); + } else if (flags & + (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { + // Version 1 + internal::MessageHeaderV1* header; + AllocateHeaderFromBuffer(payload_buffer, &header); + header->version = 1; + header->name = name; + header->flags = flags; + } else { + internal::MessageHeader* header; + AllocateHeaderFromBuffer(payload_buffer, &header); + header->version = 0; + header->name = name; + header->flags = flags; + } +} + +void CreateSerializedMessageObject(uint32_t name, + uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count, + std::vector<ScopedHandle>* handles, + ScopedMessageHandle* out_handle, + internal::Buffer* out_buffer) { + ScopedMessageHandle handle; + MojoResult rv = mojo::CreateMessage(&handle); + DCHECK_EQ(MOJO_RESULT_OK, rv); + DCHECK(handle.is_valid()); + + void* buffer; + uint32_t buffer_size; + size_t total_size = internal::ComputeSerializedMessageSize( + flags, payload_size, payload_interface_id_count); + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(total_size)); + DCHECK(!handles || + base::IsValueInRangeForNumericType<uint32_t>(handles->size())); + rv = MojoAppendMessageData( + handle->value(), static_cast<uint32_t>(total_size), + handles ? reinterpret_cast<MojoHandle*>(handles->data()) : nullptr, + handles ? static_cast<uint32_t>(handles->size()) : 0, nullptr, &buffer, + &buffer_size); + DCHECK_EQ(MOJO_RESULT_OK, rv); + if (handles) { + // Handle ownership has been taken by MojoAppendMessageData. + for (size_t i = 0; i < handles->size(); ++i) + ignore_result(handles->at(i).release()); + } + + internal::Buffer payload_buffer(handle.get(), total_size, buffer, + buffer_size); + + // Make sure we zero the memory first! + memset(payload_buffer.data(), 0, total_size); + WriteMessageHeader(name, flags, payload_interface_id_count, &payload_buffer); + + *out_handle = std::move(handle); + *out_buffer = std::move(payload_buffer); +} + +void SerializeUnserializedContext(MojoMessageHandle message, + uintptr_t context_value) { + auto* context = + reinterpret_cast<internal::UnserializedMessageContext*>(context_value); + void* buffer; + uint32_t buffer_size; + MojoResult attach_result = MojoAppendMessageData( + message, 0, nullptr, 0, nullptr, &buffer, &buffer_size); + if (attach_result != MOJO_RESULT_OK) + return; + + internal::Buffer payload_buffer(MessageHandle(message), 0, buffer, + buffer_size); + WriteMessageHeader(context->message_name(), context->message_flags(), + 0 /* payload_interface_id_count */, &payload_buffer); + + // We need to copy additional header data which may have been set after + // message construction, as this codepath may be reached at some arbitrary + // time between message send and message dispatch. + static_cast<internal::MessageHeader*>(buffer)->interface_id = + context->header()->interface_id; + if (context->header()->flags & + (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { + DCHECK_GE(context->header()->version, 1u); + static_cast<internal::MessageHeaderV1*>(buffer)->request_id = + context->header()->request_id; + } + + internal::SerializationContext serialization_context; + context->Serialize(&serialization_context, &payload_buffer); + + // TODO(crbug.com/753433): Support lazy serialization of associated endpoint + // handles. See corresponding TODO in the bindings generator for proof that + // this DCHECK is indeed valid. + DCHECK(serialization_context.associated_endpoint_handles()->empty()); + if (!serialization_context.handles()->empty()) + payload_buffer.AttachHandles(serialization_context.mutable_handles()); + payload_buffer.Seal(); +} + +void DestroyUnserializedContext(uintptr_t context) { + delete reinterpret_cast<internal::UnserializedMessageContext*>(context); +} -Message::Message() { +ScopedMessageHandle CreateUnserializedMessageObject( + std::unique_ptr<internal::UnserializedMessageContext> context) { + ScopedMessageHandle handle; + MojoResult rv = mojo::CreateMessage(&handle); + DCHECK_EQ(MOJO_RESULT_OK, rv); + DCHECK(handle.is_valid()); + + rv = MojoSetMessageContext( + handle->value(), reinterpret_cast<uintptr_t>(context.release()), + &SerializeUnserializedContext, &DestroyUnserializedContext, nullptr); + DCHECK_EQ(MOJO_RESULT_OK, rv); + return handle; } +} // namespace + +Message::Message() = default; + Message::Message(Message&& other) - : buffer_(std::move(other.buffer_)), + : handle_(std::move(other.handle_)), + payload_buffer_(std::move(other.payload_buffer_)), handles_(std::move(other.handles_)), associated_endpoint_handles_( - std::move(other.associated_endpoint_handles_)) {} + std::move(other.associated_endpoint_handles_)), + transferable_(other.transferable_), + serialized_(other.serialized_) { + other.transferable_ = false; + other.serialized_ = false; +#if defined(ENABLE_IPC_FUZZER) + interface_name_ = other.interface_name_; + method_name_ = other.method_name_; +#endif +} + +Message::Message(std::unique_ptr<internal::UnserializedMessageContext> context) + : Message(CreateUnserializedMessageObject(std::move(context))) {} + +Message::Message(uint32_t name, + uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count, + std::vector<ScopedHandle>* handles) { + CreateSerializedMessageObject(name, flags, payload_size, + payload_interface_id_count, handles, &handle_, + &payload_buffer_); + transferable_ = true; + serialized_ = true; +} + +Message::Message(ScopedMessageHandle handle) { + DCHECK(handle.is_valid()); + + uintptr_t context_value = 0; + MojoResult get_context_result = + MojoGetMessageContext(handle->value(), nullptr, &context_value); + if (get_context_result == MOJO_RESULT_NOT_FOUND) { + // It's a serialized message. Extract handles if possible. + uint32_t num_bytes; + void* buffer; + uint32_t num_handles = 0; + MojoResult rv = MojoGetMessageData(handle->value(), nullptr, &buffer, + &num_bytes, nullptr, &num_handles); + if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) { + handles_.resize(num_handles); + rv = MojoGetMessageData(handle->value(), nullptr, &buffer, &num_bytes, + reinterpret_cast<MojoHandle*>(handles_.data()), + &num_handles); + } else { + // No handles, so it's safe to retransmit this message if the caller + // really wants to. + transferable_ = true; + } -Message::~Message() { - CloseHandles(); + if (rv != MOJO_RESULT_OK) { + // Failed to deserialize handles. Leave the Message uninitialized. + return; + } + + payload_buffer_ = internal::Buffer(buffer, num_bytes, num_bytes); + serialized_ = true; + } else { + DCHECK_EQ(MOJO_RESULT_OK, get_context_result); + auto* context = + reinterpret_cast<internal::UnserializedMessageContext*>(context_value); + // Dummy data address so common header accessors still behave properly. The + // choice is V1 reflects unserialized message capabilities: we may or may + // not need to support request IDs (which require at least V1), but we never + // (for now, anyway) need to support associated interface handles (V2). + payload_buffer_ = + internal::Buffer(context->header(), sizeof(internal::MessageHeaderV1), + sizeof(internal::MessageHeaderV1)); + transferable_ = true; + serialized_ = false; + } + + handle_ = std::move(handle); } +Message::~Message() = default; + Message& Message::operator=(Message&& other) { - Reset(); - std::swap(other.buffer_, buffer_); - std::swap(other.handles_, handles_); - std::swap(other.associated_endpoint_handles_, associated_endpoint_handles_); + handle_ = std::move(other.handle_); + payload_buffer_ = std::move(other.payload_buffer_); + handles_ = std::move(other.handles_); + associated_endpoint_handles_ = std::move(other.associated_endpoint_handles_); + transferable_ = other.transferable_; + other.transferable_ = false; + serialized_ = other.serialized_; + other.serialized_ = false; +#if defined(ENABLE_IPC_FUZZER) + interface_name_ = other.interface_name_; + method_name_ = other.method_name_; +#endif return *this; } void Message::Reset() { - CloseHandles(); + handle_.reset(); + payload_buffer_.Reset(); handles_.clear(); associated_endpoint_handles_.clear(); - buffer_.reset(); -} - -void Message::Initialize(size_t capacity, bool zero_initialized) { - DCHECK(!buffer_); - buffer_.reset(new internal::MessageBuffer(capacity, zero_initialized)); -} - -void Message::InitializeFromMojoMessage(ScopedMessageHandle message, - uint32_t num_bytes, - std::vector<Handle>* handles) { - DCHECK(!buffer_); - buffer_.reset(new internal::MessageBuffer(std::move(message), num_bytes)); - handles_.swap(*handles); + transferable_ = false; + serialized_ = false; } const uint8_t* Message::payload() const { if (version() < 2) return data() + header()->num_bytes; + DCHECK(!header_v2()->payload.is_null()); return static_cast<const uint8_t*>(header_v2()->payload.Get()); } @@ -89,19 +296,16 @@ uint32_t Message::payload_num_bytes() const { if (version() < 2) { num_bytes = data_num_bytes() - header()->num_bytes; } else { - auto payload = reinterpret_cast<uintptr_t>(header_v2()->payload.Get()); - if (!payload) { - num_bytes = 0; - } else { - auto payload_end = - reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get()); - if (!payload_end) - payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes()); - DCHECK_GE(payload_end, payload); - num_bytes = payload_end - payload; - } + auto payload_begin = + reinterpret_cast<uintptr_t>(header_v2()->payload.Get()); + auto payload_end = + reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get()); + if (!payload_end) + payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes()); + DCHECK_GE(payload_end, payload_begin); + num_bytes = payload_end - payload_begin; } - DCHECK_LE(num_bytes, std::numeric_limits<uint32_t>::max()); + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(num_bytes)); return static_cast<uint32_t>(num_bytes); } @@ -117,52 +321,52 @@ const uint32_t* Message::payload_interface_ids() const { return array_pointer ? array_pointer->storage() : nullptr; } -ScopedMessageHandle Message::TakeMojoMessage() { - // If there are associated endpoints transferred, - // SerializeAssociatedEndpointHandles() must be called before this method. - DCHECK(associated_endpoint_handles_.empty()); +void Message::AttachHandlesFromSerializationContext( + internal::SerializationContext* context) { + if (context->handles()->empty() && + context->associated_endpoint_handles()->empty()) { + // No handles attached, so no extra serialization work. + return; + } - if (handles_.empty()) // Fast path for the common case: No handles. - return buffer_->TakeMessage(); + if (context->associated_endpoint_handles()->empty()) { + // Attaching only non-associated handles is easier since we don't have to + // modify the message header. Faster path for that. + payload_buffer_.AttachHandles(context->mutable_handles()); + return; + } - // Allocate a new message with space for the handles, then copy the buffer - // contents into it. + // Allocate a new message with enough space to hold all attached handles. Copy + // this message's contents into the new one and use it to replace ourself. // - // TODO(rockot): We could avoid this copy by extending GetSerializedSize() - // behavior to collect handles. It's unoptimized for now because it's much - // more common to have messages with no handles. - ScopedMessageHandle new_message; - MojoResult rv = AllocMessage( - data_num_bytes(), - handles_.empty() ? nullptr - : reinterpret_cast<const MojoHandle*>(handles_.data()), - handles_.size(), - MOJO_ALLOC_MESSAGE_FLAG_NONE, - &new_message); - CHECK_EQ(rv, MOJO_RESULT_OK); - handles_.clear(); - - void* new_buffer = nullptr; - rv = GetMessageBuffer(new_message.get(), &new_buffer); - CHECK_EQ(rv, MOJO_RESULT_OK); - - memcpy(new_buffer, data(), data_num_bytes()); - buffer_.reset(); - - return new_message; + // TODO(rockot): We could avoid the extra full message allocation by instead + // growing the buffer and carefully moving its contents around. This errs on + // the side of less complexity with probably only marginal performance cost. + uint32_t payload_size = payload_num_bytes(); + mojo::Message new_message(name(), header()->flags, payload_size, + context->associated_endpoint_handles()->size(), + context->mutable_handles()); + std::swap(*context->mutable_associated_endpoint_handles(), + new_message.associated_endpoint_handles_); + memcpy(new_message.payload_buffer()->AllocateAndGet(payload_size), payload(), + payload_size); + *this = std::move(new_message); } -void Message::NotifyBadMessage(const std::string& error) { - DCHECK(buffer_); - buffer_->NotifyBadMessage(error); +ScopedMessageHandle Message::TakeMojoMessage() { + // If there are associated endpoints transferred, + // SerializeAssociatedEndpointHandles() must be called before this method. + DCHECK(associated_endpoint_handles_.empty()); + DCHECK(transferable_); + payload_buffer_.Seal(); + auto handle = std::move(handle_); + Reset(); + return handle; } -void Message::CloseHandles() { - for (std::vector<Handle>::iterator it = handles_.begin(); - it != handles_.end(); ++it) { - if (it->is_valid()) - CloseRaw(*it); - } +void Message::NotifyBadMessage(const std::string& error) { + DCHECK(handle_.is_valid()); + mojo::NotifyBadMessage(handle_.get(), error); } void Message::SerializeAssociatedEndpointHandles( @@ -172,16 +376,20 @@ void Message::SerializeAssociatedEndpointHandles( DCHECK_GE(version(), 2u); DCHECK(header_v2()->payload_interface_ids.is_null()); + DCHECK(payload_buffer_.is_valid()); + DCHECK(handle_.is_valid()); size_t size = associated_endpoint_handles_.size(); - auto* data = internal::Array_Data<uint32_t>::New(size, buffer()); - header_v2()->payload_interface_ids.Set(data); + + internal::Array_Data<uint32_t>::BufferWriter handle_writer; + handle_writer.Allocate(size, &payload_buffer_); + header_v2()->payload_interface_ids.Set(handle_writer.data()); for (size_t i = 0; i < size; ++i) { ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i]; DCHECK(handle.pending_association()); - data->storage()[i] = + handle_writer->storage()[i] = group_controller->AssociateInterface(std::move(handle)); } associated_endpoint_handles_.clear(); @@ -189,6 +397,9 @@ void Message::SerializeAssociatedEndpointHandles( bool Message::DeserializeAssociatedEndpointHandles( AssociatedGroupController* group_controller) { + if (!serialized_) + return true; + associated_endpoint_handles_.clear(); uint32_t num_ids = payload_num_interface_ids(); @@ -213,11 +424,48 @@ bool Message::DeserializeAssociatedEndpointHandles( return result; } +void Message::SerializeIfNecessary() { + MojoResult rv = MojoSerializeMessage(handle_->value(), nullptr); + if (rv == MOJO_RESULT_FAILED_PRECONDITION) + return; + + // Reconstruct this Message instance from the serialized message's handle. + *this = Message(std::move(handle_)); +} + +std::unique_ptr<internal::UnserializedMessageContext> +Message::TakeUnserializedContext( + const internal::UnserializedMessageContext::Tag* tag) { + DCHECK(handle_.is_valid()); + uintptr_t context_value = 0; + MojoResult rv = + MojoGetMessageContext(handle_->value(), nullptr, &context_value); + if (rv == MOJO_RESULT_NOT_FOUND) + return nullptr; + DCHECK_EQ(MOJO_RESULT_OK, rv); + + auto* context = + reinterpret_cast<internal::UnserializedMessageContext*>(context_value); + if (context->tag() != tag) + return nullptr; + + // Detach the context from the message. + rv = MojoSetMessageContext(handle_->value(), 0, nullptr, nullptr, nullptr); + DCHECK_EQ(MOJO_RESULT_OK, rv); + return base::WrapUnique(context); +} + +bool MessageReceiver::PrefersSerializedMessages() { + return false; +} + PassThroughFilter::PassThroughFilter() {} PassThroughFilter::~PassThroughFilter() {} -bool PassThroughFilter::Accept(Message* message) { return true; } +bool PassThroughFilter::Accept(Message* message) { + return true; +} SyncMessageResponseContext::SyncMessageResponseContext() : outer_context_(current()) { @@ -238,43 +486,19 @@ void SyncMessageResponseContext::ReportBadMessage(const std::string& error) { GetBadMessageCallback().Run(error); } -const ReportBadMessageCallback& -SyncMessageResponseContext::GetBadMessageCallback() { - if (bad_message_callback_.is_null()) { - bad_message_callback_ = - base::Bind(&DoNotifyBadMessage, base::Passed(&response_)); - } - return bad_message_callback_; +ReportBadMessageCallback SyncMessageResponseContext::GetBadMessageCallback() { + DCHECK(!response_.IsNull()); + return base::BindOnce(&DoNotifyBadMessage, std::move(response_)); } MojoResult ReadMessage(MessagePipeHandle handle, Message* message) { - MojoResult rv; - - std::vector<Handle> handles; - ScopedMessageHandle mojo_message; - uint32_t num_bytes = 0, num_handles = 0; - rv = ReadMessageNew(handle, - &mojo_message, - &num_bytes, - nullptr, - &num_handles, - MOJO_READ_MESSAGE_FLAG_NONE); - if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) { - DCHECK_GT(num_handles, 0u); - handles.resize(num_handles); - rv = ReadMessageNew(handle, - &mojo_message, - &num_bytes, - reinterpret_cast<MojoHandle*>(handles.data()), - &num_handles, - MOJO_READ_MESSAGE_FLAG_NONE); - } - + ScopedMessageHandle message_handle; + MojoResult rv = + ReadMessageNew(handle, &message_handle, MOJO_READ_MESSAGE_FLAG_NONE); if (rv != MOJO_RESULT_OK) return rv; - message->InitializeFromMojoMessage( - std::move(mojo_message), num_bytes, &handles); + *message = Message(std::move(message_handle)); return MOJO_RESULT_OK; } @@ -311,13 +535,9 @@ MessageDispatchContext* MessageDispatchContext::current() { return g_tls_message_dispatch_context.Get().Get(); } -const ReportBadMessageCallback& -MessageDispatchContext::GetBadMessageCallback() { - if (bad_message_callback_.is_null()) { - bad_message_callback_ = - base::Bind(&DoNotifyBadMessage, base::Passed(message_)); - } - return bad_message_callback_; +ReportBadMessageCallback MessageDispatchContext::GetBadMessageCallback() { + DCHECK(!message_->IsNull()); + return base::BindOnce(&DoNotifyBadMessage, std::move(*message_)); } // static diff --git a/mojo/public/cpp/bindings/lib/message_buffer.cc b/mojo/public/cpp/bindings/lib/message_buffer.cc deleted file mode 100644 index cc12ef6e31..0000000000 --- a/mojo/public/cpp/bindings/lib/message_buffer.cc +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "mojo/public/cpp/bindings/lib/message_buffer.h" - -#include <limits> - -#include "mojo/public/cpp/bindings/lib/serialization_util.h" - -namespace mojo { -namespace internal { - -MessageBuffer::MessageBuffer(size_t capacity, bool zero_initialized) { - DCHECK_LE(capacity, std::numeric_limits<uint32_t>::max()); - - MojoResult rv = AllocMessage(capacity, nullptr, 0, - MOJO_ALLOC_MESSAGE_FLAG_NONE, &message_); - CHECK_EQ(rv, MOJO_RESULT_OK); - - void* buffer = nullptr; - if (capacity != 0) { - rv = GetMessageBuffer(message_.get(), &buffer); - CHECK_EQ(rv, MOJO_RESULT_OK); - - if (zero_initialized) - memset(buffer, 0, capacity); - } - Initialize(buffer, capacity); -} - -MessageBuffer::MessageBuffer(ScopedMessageHandle message, uint32_t num_bytes) { - message_ = std::move(message); - - void* buffer = nullptr; - if (num_bytes != 0) { - MojoResult rv = GetMessageBuffer(message_.get(), &buffer); - CHECK_EQ(rv, MOJO_RESULT_OK); - } - Initialize(buffer, num_bytes); -} - -MessageBuffer::~MessageBuffer() {} - -void MessageBuffer::NotifyBadMessage(const std::string& error) { - DCHECK(message_.is_valid()); - MojoResult result = mojo::NotifyBadMessage(message_.get(), error); - DCHECK_EQ(result, MOJO_RESULT_OK); -} - -} // namespace internal -} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_buffer.h b/mojo/public/cpp/bindings/lib/message_buffer.h deleted file mode 100644 index 96d5140f77..0000000000 --- a/mojo/public/cpp/bindings/lib/message_buffer.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_LIB_MESSAGE_BUFFER_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_LIB_MESSAGE_BUFFER_H_ - -#include <stdint.h> - -#include <utility> - -#include "base/macros.h" -#include "mojo/public/cpp/bindings/lib/buffer.h" -#include "mojo/public/cpp/system/message.h" - -namespace mojo { -namespace internal { - -// A fixed-size Buffer using a Mojo message object for storage. -class MessageBuffer : public Buffer { - public: - // Initializes this buffer to carry a fixed byte capacity and no handles. - MessageBuffer(size_t capacity, bool zero_initialized); - - // Initializes this buffer from an existing Mojo MessageHandle. - MessageBuffer(ScopedMessageHandle message, uint32_t num_bytes); - - ~MessageBuffer(); - - ScopedMessageHandle TakeMessage() { return std::move(message_); } - - void NotifyBadMessage(const std::string& error); - - private: - ScopedMessageHandle message_; - - DISALLOW_COPY_AND_ASSIGN(MessageBuffer); -}; - -} // namespace internal -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_LIB_MESSAGE_BUFFER_H_ diff --git a/mojo/public/cpp/bindings/lib/message_builder.cc b/mojo/public/cpp/bindings/lib/message_builder.cc deleted file mode 100644 index 6806a73213..0000000000 --- a/mojo/public/cpp/bindings/lib/message_builder.cc +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "mojo/public/cpp/bindings/lib/message_builder.h" - -#include "mojo/public/cpp/bindings/lib/array_internal.h" -#include "mojo/public/cpp/bindings/lib/bindings_internal.h" -#include "mojo/public/cpp/bindings/lib/buffer.h" -#include "mojo/public/cpp/bindings/lib/message_internal.h" - -namespace mojo { -namespace internal { - -template <typename Header> -void Allocate(Buffer* buf, Header** header) { - *header = static_cast<Header*>(buf->Allocate(sizeof(Header))); - (*header)->num_bytes = sizeof(Header); -} - -MessageBuilder::MessageBuilder(uint32_t name, - uint32_t flags, - size_t payload_size, - size_t payload_interface_id_count) { - if (payload_interface_id_count > 0) { - // Version 2 - InitializeMessage( - sizeof(MessageHeaderV2) + Align(payload_size) + - ArrayDataTraits<uint32_t>::GetStorageSize( - static_cast<uint32_t>(payload_interface_id_count))); - - MessageHeaderV2* header; - Allocate(message_.buffer(), &header); - header->version = 2; - header->name = name; - header->flags = flags; - // The payload immediately follows the header. - header->payload.Set(header + 1); - } else if (flags & - (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { - // Version 1 - InitializeMessage(sizeof(MessageHeaderV1) + payload_size); - - MessageHeaderV1* header; - Allocate(message_.buffer(), &header); - header->version = 1; - header->name = name; - header->flags = flags; - } else { - InitializeMessage(sizeof(MessageHeader) + payload_size); - - MessageHeader* header; - Allocate(message_.buffer(), &header); - header->version = 0; - header->name = name; - header->flags = flags; - } -} - -MessageBuilder::~MessageBuilder() { -} - -void MessageBuilder::InitializeMessage(size_t size) { - message_.Initialize(static_cast<uint32_t>(Align(size)), - true /* zero_initialized */); -} - -} // namespace internal -} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_builder.h b/mojo/public/cpp/bindings/lib/message_builder.h deleted file mode 100644 index 8a4d5c4690..0000000000 --- a/mojo/public/cpp/bindings/lib/message_builder.h +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_BUILDER_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_BUILDER_H_ - -#include <stddef.h> -#include <stdint.h> - -#include "base/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" -#include "mojo/public/cpp/bindings/message.h" - -namespace mojo { - -class Message; - -namespace internal { - -class Buffer; - -class MOJO_CPP_BINDINGS_EXPORT MessageBuilder { - public: - MessageBuilder(uint32_t name, - uint32_t flags, - size_t payload_size, - size_t payload_interface_id_count); - ~MessageBuilder(); - - Buffer* buffer() { return message_.buffer(); } - Message* message() { return &message_; } - - private: - void InitializeMessage(size_t size); - - Message message_; - - DISALLOW_COPY_AND_ASSIGN(MessageBuilder); -}; - -} // namespace internal -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_MESSAGE_BUILDER_H_ diff --git a/mojo/public/cpp/bindings/lib/message_dumper.cc b/mojo/public/cpp/bindings/lib/message_dumper.cc new file mode 100644 index 0000000000..35696bbcbf --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message_dumper.cc @@ -0,0 +1,96 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/message_dumper.h" + +#include "base/files/file.h" +#include "base/files/file_path.h" +#include "base/files/file_util.h" +#include "base/logging.h" +#include "base/no_destructor.h" +#include "base/process/process.h" +#include "base/rand_util.h" +#include "base/strings/string_number_conversions.h" +#include "base/task_scheduler/post_task.h" +#include "mojo/public/cpp/bindings/message.h" + +namespace { + +base::FilePath& DumpDirectory() { + static base::NoDestructor<base::FilePath> dump_directory; + return *dump_directory; +} + +// void WriteMessage(uint32_t identifier, +// const mojo::MessageDumper::MessageEntry& entry) { +// static uint64_t num = 0; + +// if (!entry.interface_name) +// return; + +// base::FilePath message_directory = +// DumpDirectory() +// .AppendASCII(entry.interface_name) +// .AppendASCII(base::NumberToString(identifier)); + +// if (!base::DirectoryExists(message_directory) && +// !base::CreateDirectory(message_directory)) { +// LOG(ERROR) << "Failed to create" << message_directory.value(); +// return; +// } + +// std::string filename = +// base::NumberToString(num++) + "." + entry.method_name + ".mojomsg"; +// base::FilePath path = message_directory.AppendASCII(filename); +// base::File file(path, +// base::File::FLAG_WRITE | base::File::FLAG_CREATE_ALWAYS); + +// file.WriteAtCurrentPos(reinterpret_cast<const char*>(entry.data_bytes.data()), +// static_cast<int>(entry.data_bytes.size())); +// } + +} // namespace + +namespace mojo { + +MessageDumper::MessageEntry::MessageEntry(const uint8_t* data, + uint32_t data_size, + const char* interface_name, + const char* method_name) + : interface_name(interface_name), + method_name(method_name), + data_bytes(data, data + data_size) {} + +MessageDumper::MessageEntry::MessageEntry(const MessageEntry& entry) = default; + +MessageDumper::MessageEntry::~MessageEntry() {} + +MessageDumper::MessageDumper() : identifier_(base::RandUint64()) {} + +MessageDumper::~MessageDumper() {} + +bool MessageDumper::Accept(mojo::Message* message) { + // MessageEntry entry(message->data(), message->data_num_bytes(), + // "unknown interface", "unknown name"); + + // static base::NoDestructor<scoped_refptr<base::TaskRunner>> task_runner( + // base::CreateSequencedTaskRunnerWithTraits( + // {base::MayBlock(), base::TaskPriority::USER_BLOCKING, + // base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN})); + + // (*task_runner) + // ->PostTask(FROM_HERE, + // base::BindOnce(&WriteMessage, identifier_, std::move(entry))); + return true; +} + +void MessageDumper::SetMessageDumpDirectory(const base::FilePath& directory) { + DumpDirectory() = directory; +} + +const base::FilePath& MessageDumper::GetMessageDumpDirectory() { + return DumpDirectory(); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_header_validator.cc b/mojo/public/cpp/bindings/lib/message_header_validator.cc index 9f8c6278c0..46bc5ed6e3 100644 --- a/mojo/public/cpp/bindings/lib/message_header_validator.cc +++ b/mojo/public/cpp/bindings/lib/message_header_validator.cc @@ -73,9 +73,10 @@ bool IsValidMessageHeader(const internal::MessageHeader* header, // payload size). // - Validation of the payload contents will be done separately based on the // payload type. - if (!header_v2->payload.is_null() && - (!internal::ValidatePointer(header_v2->payload, validation_context) || - !validation_context->ClaimMemory(header_v2->payload.Get(), 1))) { + if (!internal::ValidatePointerNonNullable(header_v2->payload, 5, + validation_context) || + !internal::ValidatePointer(header_v2->payload, validation_context) || + !validation_context->ClaimMemory(header_v2->payload.Get(), 1)) { return false; } @@ -115,6 +116,10 @@ void MessageHeaderValidator::SetDescription(const std::string& description) { } bool MessageHeaderValidator::Accept(Message* message) { + // Don't bother validating unserialized message headers. + if (!message->is_serialized()) + return true; + // Pass 0 as number of handles and associated endpoint handles because we // don't expect any in the header, even if |message| contains handles. internal::ValidationContext validation_context( diff --git a/mojo/public/cpp/bindings/lib/message_internal.cc b/mojo/public/cpp/bindings/lib/message_internal.cc new file mode 100644 index 0000000000..445eb4d891 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/message_internal.cc @@ -0,0 +1,45 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/message_internal.h" + +#include "mojo/public/cpp/bindings/lib/array_internal.h" +#include "mojo/public/cpp/bindings/message.h" + +namespace mojo { +namespace internal { + +namespace { + +size_t ComputeHeaderSize(uint32_t flags, size_t payload_interface_id_count) { + if (payload_interface_id_count > 0) { + // Version 2 + return sizeof(MessageHeaderV2); + } else if (flags & + (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { + // Version 1 + return sizeof(MessageHeaderV1); + } else { + // Version 0 + return sizeof(MessageHeader); + } +} + +} // namespace + +size_t ComputeSerializedMessageSize(uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count) { + const size_t header_size = + ComputeHeaderSize(flags, payload_interface_id_count); + if (payload_interface_id_count > 0) { + return Align(header_size + Align(payload_size) + + ArrayDataTraits<uint32_t>::GetStorageSize( + static_cast<uint32_t>(payload_interface_id_count))); + } + return internal::Align(header_size + payload_size); +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/message_internal.h b/mojo/public/cpp/bindings/lib/message_internal.h index 6693198f81..40539e27aa 100644 --- a/mojo/public/cpp/bindings/lib/message_internal.h +++ b/mojo/public/cpp/bindings/lib/message_internal.h @@ -10,8 +10,8 @@ #include <string> #include "base/callback.h" +#include "base/component_export.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" namespace mojo { @@ -54,28 +54,32 @@ static_assert(sizeof(MessageHeaderV2) == 48, "Bad sizeof(MessageHeaderV2)"); #pragma pack(pop) -class MOJO_CPP_BINDINGS_EXPORT MessageDispatchContext { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) MessageDispatchContext { public: explicit MessageDispatchContext(Message* message); ~MessageDispatchContext(); static MessageDispatchContext* current(); - const base::Callback<void(const std::string&)>& GetBadMessageCallback(); + base::OnceCallback<void(const std::string&)> GetBadMessageCallback(); private: MessageDispatchContext* outer_context_; Message* message_; - base::Callback<void(const std::string&)> bad_message_callback_; DISALLOW_COPY_AND_ASSIGN(MessageDispatchContext); }; -class MOJO_CPP_BINDINGS_EXPORT SyncMessageResponseSetup { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) SyncMessageResponseSetup { public: static void SetCurrentSyncResponseMessage(Message* message); }; +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +size_t ComputeSerializedMessageSize(uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count); + } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/multiplex_router.cc b/mojo/public/cpp/bindings/lib/multiplex_router.cc index ff7c678289..61833097ef 100644 --- a/mojo/public/cpp/bindings/lib/multiplex_router.cc +++ b/mojo/public/cpp/bindings/lib/multiplex_router.cc @@ -12,14 +12,13 @@ #include "base/location.h" #include "base/macros.h" #include "base/memory/ptr_util.h" -#include "base/single_thread_task_runner.h" +#include "base/sequenced_task_runner.h" #include "base/stl_util.h" #include "base/synchronization/waitable_event.h" -#include "base/threading/thread_task_runner_handle.h" #include "mojo/public/cpp/bindings/interface_endpoint_client.h" #include "mojo/public/cpp/bindings/interface_endpoint_controller.h" #include "mojo/public/cpp/bindings/lib/may_auto_lock.h" -#include "mojo/public/cpp/bindings/sync_event_watcher.h" +#include "mojo/public/cpp/bindings/sequence_local_sync_event_watcher.h" namespace mojo { namespace internal { @@ -41,7 +40,7 @@ class MultiplexRouter::InterfaceEndpoint client_(nullptr) {} // --------------------------------------------------------------------------- - // The following public methods are safe to call from any threads without + // The following public methods are safe to call from any sequence without // locking. InterfaceId id() const { return id_; } @@ -76,29 +75,27 @@ class MultiplexRouter::InterfaceEndpoint disconnect_reason_ = disconnect_reason; } - base::SingleThreadTaskRunner* task_runner() const { - return task_runner_.get(); - } + base::SequencedTaskRunner* task_runner() const { return task_runner_.get(); } InterfaceEndpointClient* client() const { return client_; } void AttachClient(InterfaceEndpointClient* client, - scoped_refptr<base::SingleThreadTaskRunner> runner) { + scoped_refptr<base::SequencedTaskRunner> runner) { router_->AssertLockAcquired(); DCHECK(!client_); DCHECK(!closed_); - DCHECK(runner->BelongsToCurrentThread()); + DCHECK(runner->RunsTasksInCurrentSequence()); task_runner_ = std::move(runner); client_ = client; } - // This method must be called on the same thread as the corresponding + // This method must be called on the same sequence as the corresponding // AttachClient() call. void DetachClient() { router_->AssertLockAcquired(); DCHECK(client_); - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); DCHECK(!closed_); task_runner_ = nullptr; @@ -111,8 +108,8 @@ class MultiplexRouter::InterfaceEndpoint if (sync_message_event_signaled_) return; sync_message_event_signaled_ = true; - if (sync_message_event_) - sync_message_event_->Signal(); + if (sync_watcher_) + sync_watcher_->SignalEvent(); } void ResetSyncMessageSignal() { @@ -120,30 +117,30 @@ class MultiplexRouter::InterfaceEndpoint if (!sync_message_event_signaled_) return; sync_message_event_signaled_ = false; - if (sync_message_event_) - sync_message_event_->Reset(); + if (sync_watcher_) + sync_watcher_->ResetEvent(); } // --------------------------------------------------------------------------- // The following public methods (i.e., InterfaceEndpointController - // implementation) are called by the client on the same thread as the + // implementation) are called by the client on the same sequence as the // AttachClient() call. They are called outside of the router's lock. bool SendMessage(Message* message) override { - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); message->set_interface_id(id_); return router_->connector_.Accept(message); } void AllowWokenUpBySyncWatchOnSameThread() override { - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); EnsureSyncWatcherExists(); - sync_watcher_->AllowWokenUpBySyncWatchOnSameThread(); + sync_watcher_->AllowWokenUpBySyncWatchOnSameSequence(); } bool SyncWatch(const bool* should_stop) override { - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); EnsureSyncWatcherExists(); return sync_watcher_->SyncWatch(should_stop); @@ -156,13 +153,10 @@ class MultiplexRouter::InterfaceEndpoint router_->AssertLockAcquired(); DCHECK(!client_); - DCHECK(closed_); - DCHECK(peer_closed_); - DCHECK(!sync_watcher_); } void OnSyncEventSignaled() { - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); scoped_refptr<MultiplexRouter> router_protector(router_); MayAutoLock locker(&router_->lock_); @@ -184,28 +178,20 @@ class MultiplexRouter::InterfaceEndpoint } void EnsureSyncWatcherExists() { - DCHECK(task_runner_->BelongsToCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); if (sync_watcher_) return; - { - MayAutoLock locker(&router_->lock_); - if (!sync_message_event_) { - sync_message_event_.emplace( - base::WaitableEvent::ResetPolicy::MANUAL, - base::WaitableEvent::InitialState::NOT_SIGNALED); - if (sync_message_event_signaled_) - sync_message_event_->Signal(); - } - } - sync_watcher_.reset( - new SyncEventWatcher(&sync_message_event_.value(), - base::Bind(&InterfaceEndpoint::OnSyncEventSignaled, - base::Unretained(this)))); + MayAutoLock locker(&router_->lock_); + sync_watcher_ = + std::make_unique<SequenceLocalSyncEventWatcher>(base::BindRepeating( + &InterfaceEndpoint::OnSyncEventSignaled, base::Unretained(this))); + if (sync_message_event_signaled_) + sync_watcher_->SignalEvent(); } // --------------------------------------------------------------------------- - // The following members are safe to access from any threads. + // The following members are safe to access from any sequence. MultiplexRouter* const router_; const InterfaceId id_; @@ -225,30 +211,22 @@ class MultiplexRouter::InterfaceEndpoint base::Optional<DisconnectReason> disconnect_reason_; // The task runner on which |client_|'s methods can be called. - scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; // Not owned. It is null if no client is attached to this endpoint. InterfaceEndpointClient* client_; - // An event used to signal that sync messages are available. The event is - // initialized under the router's lock and remains unchanged afterwards. It - // may be accessed outside of the router's lock later. - base::Optional<base::WaitableEvent> sync_message_event_; + // Indicates whether the sync watcher should be signaled for this endpoint. bool sync_message_event_signaled_ = false; - // --------------------------------------------------------------------------- - // The following members are only valid while a client is attached. They are - // used exclusively on the client's thread. They may be accessed outside of - // the router's lock. - - std::unique_ptr<SyncEventWatcher> sync_watcher_; + // Guarded by the router's lock. Used to synchronously wait on replies. + std::unique_ptr<SequenceLocalSyncEventWatcher> sync_watcher_; DISALLOW_COPY_AND_ASSIGN(InterfaceEndpoint); }; // MessageWrapper objects are always destroyed under the router's lock. On -// destruction, if the message it wrappers contains -// ScopedInterfaceEndpointHandles (which cannot be destructed under the -// router's lock), the wrapper unlocks to clean them up. +// destruction, if the message it wrappers contains interface IDs, the wrapper +// closes the corresponding endpoints. class MultiplexRouter::MessageWrapper { public: MessageWrapper() = default; @@ -260,14 +238,14 @@ class MultiplexRouter::MessageWrapper { : router_(other.router_), value_(std::move(other.value_)) {} ~MessageWrapper() { - if (value_.associated_endpoint_handles()->empty()) + if (!router_ || value_.IsNull()) return; router_->AssertLockAcquired(); - { - MayAutoUnlock unlocker(&router_->lock_); - value_.mutable_associated_endpoint_handles()->clear(); - } + // Don't try to close the endpoints if at this point the router is already + // half-destructed. + if (!router_->being_destructed_) + router_->CloseEndpointsForMessage(value_); } MessageWrapper& operator=(MessageWrapper&& other) { @@ -276,7 +254,21 @@ class MultiplexRouter::MessageWrapper { return *this; } - Message& value() { return value_; } + const Message& value() const { return value_; } + + // Must be called outside of the router's lock. + // Returns a null message if it fails to deseralize the associated endpoint + // handles. + Message DeserializeEndpointHandlesAndTake() { + if (!value_.DeserializeAssociatedEndpointHandles(router_)) { + // The previous call may have deserialized part of the associated + // interface endpoint handles. They must be destroyed outside of the + // router's lock, so we cannot wait until destruction of MessageWrapper. + value_.Reset(); + return Message(); + } + return std::move(value_); + } private: MultiplexRouter* router_ = nullptr; @@ -322,23 +314,17 @@ MultiplexRouter::MultiplexRouter( ScopedMessagePipeHandle message_pipe, Config config, bool set_interface_id_namesapce_bit, - scoped_refptr<base::SingleThreadTaskRunner> runner) + scoped_refptr<base::SequencedTaskRunner> runner) : set_interface_id_namespace_bit_(set_interface_id_namesapce_bit), task_runner_(runner), - header_validator_(nullptr), filters_(this), connector_(std::move(message_pipe), config == MULTI_INTERFACE ? Connector::MULTI_THREADED_SEND : Connector::SINGLE_THREADED_SEND, std::move(runner)), control_message_handler_(this), - control_message_proxy_(&connector_), - next_interface_id_value_(1), - posted_to_process_tasks_(false), - encountered_error_(false), - paused_(false), - testing_mode_(false) { - DCHECK(task_runner_->BelongsToCurrentThread()); + control_message_proxy_(&connector_) { + DCHECK(task_runner_->RunsTasksInCurrentSequence()); if (config == MULTI_INTERFACE) lock_.emplace(); @@ -348,16 +334,15 @@ MultiplexRouter::MultiplexRouter( // Always participate in sync handle watching in multi-interface mode, // because even if it doesn't expect sync requests during sync handle // watching, it may still need to dispatch messages to associated endpoints - // on a different thread. + // on a different sequence. connector_.AllowWokenUpBySyncWatchOnSameThread(); } connector_.set_incoming_receiver(&filters_); - connector_.set_connection_error_handler( - base::Bind(&MultiplexRouter::OnPipeConnectionError, - base::Unretained(this))); + connector_.set_connection_error_handler(base::Bind( + &MultiplexRouter::OnPipeConnectionError, base::Unretained(this))); std::unique_ptr<MessageHeaderValidator> header_validator = - base::MakeUnique<MessageHeaderValidator>(); + std::make_unique<MessageHeaderValidator>(); header_validator_ = header_validator.get(); filters_.Append(std::move(header_validator)); } @@ -365,33 +350,22 @@ MultiplexRouter::MultiplexRouter( MultiplexRouter::~MultiplexRouter() { MayAutoLock locker(&lock_); + being_destructed_ = true; + sync_message_tasks_.clear(); tasks_.clear(); + endpoints_.clear(); +} - for (auto iter = endpoints_.begin(); iter != endpoints_.end();) { - InterfaceEndpoint* endpoint = iter->second.get(); - // Increment the iterator before calling UpdateEndpointStateMayRemove() - // because it may remove the corresponding value from the map. - ++iter; - - if (!endpoint->closed()) { - // This happens when a NotifyPeerEndpointClosed message been received, but - // the interface ID hasn't been used to create local endpoint handle. - DCHECK(!endpoint->client()); - DCHECK(endpoint->peer_closed()); - UpdateEndpointStateMayRemove(endpoint, ENDPOINT_CLOSED); - } else { - UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); - } - } - - DCHECK(endpoints_.empty()); +void MultiplexRouter::AddIncomingMessageFilter( + std::unique_ptr<MessageReceiver> filter) { + filters_.Append(std::move(filter)); } void MultiplexRouter::SetMasterInterfaceName(const char* name) { - DCHECK(thread_checker_.CalledOnValidThread()); - header_validator_->SetDescription( - std::string(name) + " [master] MessageHeaderValidator"); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + header_validator_->SetDescription(std::string(name) + + " [master] MessageHeaderValidator"); control_message_handler_.SetDescription( std::string(name) + " [master] PipeControlMessageHandler"); connector_.SetWatcherHeapProfilerTag(name); @@ -445,17 +419,10 @@ ScopedInterfaceEndpointHandle MultiplexRouter::CreateLocalEndpointHandle( bool inserted = false; InterfaceEndpoint* endpoint = FindOrInsertEndpoint(id, &inserted); if (inserted) { - DCHECK(!endpoint->handle_created()); - if (encountered_error_) UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); } else { - // If the endpoint already exist, it is because we have received a - // notification that the peer endpoint has closed. - CHECK(!endpoint->closed()); - CHECK(endpoint->peer_closed()); - - if (endpoint->handle_created()) + if (endpoint->handle_created() || endpoint->closed()) return ScopedInterfaceEndpointHandle(); } @@ -487,7 +454,7 @@ void MultiplexRouter::CloseEndpointHandle( InterfaceEndpointController* MultiplexRouter::AttachEndpointClient( const ScopedInterfaceEndpointHandle& handle, InterfaceEndpointClient* client, - scoped_refptr<base::SingleThreadTaskRunner> runner) { + scoped_refptr<base::SequencedTaskRunner> runner) { const InterfaceId id = handle.id(); DCHECK(IsValidInterfaceId(id)); @@ -520,7 +487,7 @@ void MultiplexRouter::DetachEndpointClient( } void MultiplexRouter::RaiseError() { - if (task_runner_->BelongsToCurrentThread()) { + if (task_runner_->RunsTasksInCurrentSequence()) { connector_.RaiseError(); } else { task_runner_->PostTask(FROM_HERE, @@ -528,8 +495,13 @@ void MultiplexRouter::RaiseError() { } } +bool MultiplexRouter::PrefersSerializedMessages() { + MayAutoLock locker(&lock_); + return connector_.PrefersSerializedMessages(); +} + void MultiplexRouter::CloseMessagePipe() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); connector_.CloseMessagePipe(); // CloseMessagePipe() above won't trigger connection error handler. // Explicitly call OnPipeConnectionError() so that associated endpoints will @@ -538,7 +510,7 @@ void MultiplexRouter::CloseMessagePipe() { } void MultiplexRouter::PauseIncomingMethodCallProcessing() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); connector_.PauseIncomingMethodCallProcessing(); MayAutoLock locker(&lock_); @@ -549,7 +521,7 @@ void MultiplexRouter::PauseIncomingMethodCallProcessing() { } void MultiplexRouter::ResumeIncomingMethodCallProcessing() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); connector_.ResumeIncomingMethodCallProcessing(); MayAutoLock locker(&lock_); @@ -568,7 +540,7 @@ void MultiplexRouter::ResumeIncomingMethodCallProcessing() { } bool MultiplexRouter::HasAssociatedEndpoints() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); MayAutoLock locker(&lock_); if (endpoints_.size() > 1) @@ -580,7 +552,7 @@ bool MultiplexRouter::HasAssociatedEndpoints() const { } void MultiplexRouter::EnableTestingMode() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); MayAutoLock locker(&lock_); testing_mode_ = true; @@ -588,9 +560,19 @@ void MultiplexRouter::EnableTestingMode() { } bool MultiplexRouter::Accept(Message* message) { - DCHECK(thread_checker_.CalledOnValidThread()); - - if (!message->DeserializeAssociatedEndpointHandles(this)) + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + // Insert endpoints for the payload interface IDs as soon as the message + // arrives, instead of waiting till the message is dispatched. Consider the + // following sequence: + // 1) Async message msg1 arrives, containing interface ID x. Msg1 is not + // dispatched because a sync call is blocking the thread. + // 2) Sync message msg2 arrives targeting interface ID x. + // + // If we don't insert endpoint for interface ID x, when trying to dispatch + // msg2 we don't know whether it is an unexpected message or it is just + // because the message containing x hasn't been dispatched. + if (!InsertEndpointsForMessage(*message)) return false; scoped_refptr<MultiplexRouter> protector(this); @@ -603,15 +585,15 @@ bool MultiplexRouter::Accept(Message* message) { ? ALLOW_DIRECT_CLIENT_CALLS_FOR_SYNC_MESSAGES : ALLOW_DIRECT_CLIENT_CALLS; - bool processed = - tasks_.empty() && ProcessIncomingMessage(message, client_call_behavior, - connector_.task_runner()); + MessageWrapper message_wrapper(this, std::move(*message)); + bool processed = tasks_.empty() && ProcessIncomingMessage( + &message_wrapper, client_call_behavior, + connector_.task_runner()); if (!processed) { // Either the task queue is not empty or we cannot process the message // directly. In both cases, there is no need to call ProcessTasks(). - tasks_.push_back( - Task::CreateMessageTask(MessageWrapper(this, std::move(*message)))); + tasks_.push_back(Task::CreateMessageTask(std::move(message_wrapper))); Task* task = tasks_.back().get(); if (task->message_wrapper.value().has_flag(Message::kFlagIsSync)) { @@ -636,8 +618,6 @@ bool MultiplexRouter::Accept(Message* message) { bool MultiplexRouter::OnPeerAssociatedEndpointClosed( InterfaceId id, const base::Optional<DisconnectReason>& reason) { - DCHECK(!IsMasterInterfaceId(id) || reason); - MayAutoLock locker(&lock_); InterfaceEndpoint* endpoint = FindOrInsertEndpoint(id, nullptr); @@ -662,23 +642,26 @@ bool MultiplexRouter::OnPeerAssociatedEndpointClosed( } void MultiplexRouter::OnPipeConnectionError() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); scoped_refptr<MultiplexRouter> protector(this); MayAutoLock locker(&lock_); encountered_error_ = true; - for (auto iter = endpoints_.begin(); iter != endpoints_.end();) { - InterfaceEndpoint* endpoint = iter->second.get(); - // Increment the iterator before calling UpdateEndpointStateMayRemove() - // because it may remove the corresponding value from the map. - ++iter; + // Calling UpdateEndpointStateMayRemove() may remove the corresponding value + // from |endpoints_| and invalidate any iterator of |endpoints_|. Therefore, + // copy the endpoint pointers to a vector and iterate over it instead. + std::vector<scoped_refptr<InterfaceEndpoint>> endpoint_vector; + endpoint_vector.reserve(endpoints_.size()); + for (const auto& pair : endpoints_) + endpoint_vector.push_back(pair.second); + for (const auto& endpoint : endpoint_vector) { if (endpoint->client()) - tasks_.push_back(Task::CreateNotifyErrorTask(endpoint)); + tasks_.push_back(Task::CreateNotifyErrorTask(endpoint.get())); - UpdateEndpointStateMayRemove(endpoint, PEER_ENDPOINT_CLOSED); + UpdateEndpointStateMayRemove(endpoint.get(), PEER_ENDPOINT_CLOSED); } ProcessTasks(connector_.during_sync_handle_watcher_callback() @@ -689,7 +672,7 @@ void MultiplexRouter::OnPipeConnectionError() { void MultiplexRouter::ProcessTasks( ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner) { + base::SequencedTaskRunner* current_task_runner) { AssertLockAcquired(); if (posted_to_process_tasks_) @@ -714,7 +697,7 @@ void MultiplexRouter::ProcessTasks( task->IsNotifyErrorTask() ? ProcessNotifyErrorTask(task.get(), client_call_behavior, current_task_runner) - : ProcessIncomingMessage(&task->message_wrapper.value(), + : ProcessIncomingMessage(&task->message_wrapper, client_call_behavior, current_task_runner); if (!processed) { @@ -752,8 +735,7 @@ bool MultiplexRouter::ProcessFirstSyncMessageForEndpoint(InterfaceId id) { // Note: after this call, |task| and |iter| may be invalidated. bool processed = ProcessIncomingMessage( - &message_wrapper.value(), ALLOW_DIRECT_CLIENT_CALLS_FOR_SYNC_MESSAGES, - nullptr); + &message_wrapper, ALLOW_DIRECT_CLIENT_CALLS_FOR_SYNC_MESSAGES, nullptr); DCHECK(processed); iter = sync_message_tasks_.find(id); @@ -771,8 +753,9 @@ bool MultiplexRouter::ProcessFirstSyncMessageForEndpoint(InterfaceId id) { bool MultiplexRouter::ProcessNotifyErrorTask( Task* task, ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner) { - DCHECK(!current_task_runner || current_task_runner->BelongsToCurrentThread()); + base::SequencedTaskRunner* current_task_runner) { + DCHECK(!current_task_runner || + current_task_runner->RunsTasksInCurrentSequence()); DCHECK(!paused_); AssertLockAcquired(); @@ -786,7 +769,7 @@ bool MultiplexRouter::ProcessNotifyErrorTask( return false; } - DCHECK(endpoint->task_runner()->BelongsToCurrentThread()); + DCHECK(endpoint->task_runner()->RunsTasksInCurrentSequence()); InterfaceEndpointClient* client = endpoint->client(); base::Optional<DisconnectReason> disconnect_reason( @@ -797,7 +780,7 @@ bool MultiplexRouter::ProcessNotifyErrorTask( // object within NotifyError(). Holding the lock will lead to deadlock. // // It is safe to call into |client| without the lock. Because |client| is - // always accessed on the same thread, including DetachEndpointClient(). + // always accessed on the same sequence, including DetachEndpointClient(). MayAutoUnlock unlocker(&lock_); client->NotifyError(disconnect_reason); } @@ -805,14 +788,16 @@ bool MultiplexRouter::ProcessNotifyErrorTask( } bool MultiplexRouter::ProcessIncomingMessage( - Message* message, + MessageWrapper* message_wrapper, ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner) { - DCHECK(!current_task_runner || current_task_runner->BelongsToCurrentThread()); + base::SequencedTaskRunner* current_task_runner) { + DCHECK(!current_task_runner || + current_task_runner->RunsTasksInCurrentSequence()); DCHECK(!paused_); - DCHECK(message); + DCHECK(message_wrapper); AssertLockAcquired(); + const Message* message = &message_wrapper->value(); if (message->IsNull()) { // This is a sync message and has been processed during sync handle // watching. @@ -824,7 +809,10 @@ bool MultiplexRouter::ProcessIncomingMessage( { MayAutoUnlock unlocker(&lock_); - result = control_message_handler_.Accept(message); + Message tmp_message = + message_wrapper->DeserializeEndpointHandlesAndTake(); + result = !tmp_message.IsNull() && + control_message_handler_.Accept(&tmp_message); } if (!result) @@ -849,7 +837,7 @@ bool MultiplexRouter::ProcessIncomingMessage( bool can_direct_call; if (message->has_flag(Message::kFlagIsSync)) { can_direct_call = client_call_behavior != NO_DIRECT_CLIENT_CALLS && - endpoint->task_runner()->BelongsToCurrentThread(); + endpoint->task_runner()->RunsTasksInCurrentSequence(); } else { can_direct_call = client_call_behavior == ALLOW_DIRECT_CLIENT_CALLS && endpoint->task_runner() == current_task_runner; @@ -860,7 +848,7 @@ bool MultiplexRouter::ProcessIncomingMessage( return false; } - DCHECK(endpoint->task_runner()->BelongsToCurrentThread()); + DCHECK(endpoint->task_runner()->RunsTasksInCurrentSequence()); InterfaceEndpointClient* client = endpoint->client(); bool result = false; @@ -870,9 +858,11 @@ bool MultiplexRouter::ProcessIncomingMessage( // deadlock. // // It is safe to call into |client| without the lock. Because |client| is - // always accessed on the same thread, including DetachEndpointClient(). + // always accessed on the same sequence, including DetachEndpointClient(). MayAutoUnlock unlocker(&lock_); - result = client->HandleIncomingMessage(message); + Message tmp_message = message_wrapper->DeserializeEndpointHandlesAndTake(); + result = + !tmp_message.IsNull() && client->HandleIncomingMessage(&tmp_message); } if (!result) RaiseErrorInNonTestingMode(); @@ -881,7 +871,7 @@ bool MultiplexRouter::ProcessIncomingMessage( } void MultiplexRouter::MaybePostToProcessTasks( - base::SingleThreadTaskRunner* task_runner) { + base::SequencedTaskRunner* task_runner) { AssertLockAcquired(); if (posted_to_process_tasks_) return; @@ -897,7 +887,7 @@ void MultiplexRouter::LockAndCallProcessTasks() { // always called using base::Bind(), which holds a ref. MayAutoLock locker(&lock_); posted_to_process_tasks_ = false; - scoped_refptr<base::SingleThreadTaskRunner> runner( + scoped_refptr<base::SequencedTaskRunner> runner( std::move(posted_to_task_runner_)); ProcessTasks(ALLOW_DIRECT_CLIENT_CALLS, runner.get()); } @@ -956,5 +946,67 @@ void MultiplexRouter::AssertLockAcquired() { #endif } +bool MultiplexRouter::InsertEndpointsForMessage(const Message& message) { + if (!message.is_serialized()) + return true; + + uint32_t num_ids = message.payload_num_interface_ids(); + if (num_ids == 0) + return true; + + const uint32_t* ids = message.payload_interface_ids(); + + MayAutoLock locker(&lock_); + for (uint32_t i = 0; i < num_ids; ++i) { + // Message header validation already ensures that the IDs are valid and not + // the master ID. + // The IDs are from the remote side and therefore their namespace bit is + // supposed to be different than the value that this router would use. + if (set_interface_id_namespace_bit_ == + HasInterfaceIdNamespaceBitSet(ids[i])) { + return false; + } + + // It is possible that the endpoint already exists even when the remote side + // is well-behaved: it might have notified us that the peer endpoint has + // closed. + bool inserted = false; + InterfaceEndpoint* endpoint = FindOrInsertEndpoint(ids[i], &inserted); + if (endpoint->closed() || endpoint->handle_created()) + return false; + } + + return true; +} + +void MultiplexRouter::CloseEndpointsForMessage(const Message& message) { + AssertLockAcquired(); + + if (!message.is_serialized()) + return; + + uint32_t num_ids = message.payload_num_interface_ids(); + if (num_ids == 0) + return; + + const uint32_t* ids = message.payload_interface_ids(); + for (uint32_t i = 0; i < num_ids; ++i) { + InterfaceEndpoint* endpoint = FindEndpoint(ids[i]); + // If the remote side maliciously sends the same interface ID in another + // message which has been dispatched, we could get here with no endpoint + // for the ID, a closed endpoint, or an endpoint with handle created. + if (!endpoint || endpoint->closed() || endpoint->handle_created()) { + RaiseErrorInNonTestingMode(); + continue; + } + + UpdateEndpointStateMayRemove(endpoint, ENDPOINT_CLOSED); + MayAutoUnlock unlocker(&lock_); + control_message_proxy_.NotifyPeerEndpointClosed(ids[i], base::nullopt); + } + + ProcessTasks(NO_DIRECT_CLIENT_CALLS, nullptr); +} + } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/multiplex_router.h b/mojo/public/cpp/bindings/lib/multiplex_router.h index cac138bcb7..8c2e7c8b0f 100644 --- a/mojo/public/cpp/bindings/lib/multiplex_router.h +++ b/mojo/public/cpp/bindings/lib/multiplex_router.h @@ -7,20 +7,21 @@ #include <stdint.h> -#include <deque> #include <map> #include <memory> #include <string> #include "base/compiler_specific.h" +#include "base/containers/queue.h" +#include "base/containers/small_map.h" #include "base/logging.h" #include "base/macros.h" #include "base/memory/ref_counted.h" #include "base/memory/weak_ptr.h" #include "base/optional.h" -#include "base/single_thread_task_runner.h" +#include "base/sequence_checker.h" +#include "base/sequenced_task_runner.h" #include "base/synchronization/lock.h" -#include "base/threading/thread_checker.h" #include "mojo/public/cpp/bindings/associated_group_controller.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/connector.h" @@ -33,7 +34,7 @@ #include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" namespace base { -class SingleThreadTaskRunner; +class SequencedTaskRunner; } namespace mojo { @@ -43,19 +44,19 @@ namespace internal { // MultiplexRouter supports routing messages for multiple interfaces over a // single message pipe. // -// It is created on the thread where the master interface of the message pipe +// It is created on the sequence where the master interface of the message pipe // lives. Although it is ref-counted, it is guarateed to be destructed on the -// same thread. -// Some public methods are only allowed to be called on the creating thread; -// while the others are safe to call from any threads. Please see the method +// same sequence. +// Some public methods are only allowed to be called on the creating sequence; +// while the others are safe to call from any sequence. Please see the method // comments for more details. // // NOTE: CloseMessagePipe() or PassMessagePipe() MUST be called on |runner|'s -// thread before this object is destroyed. +// sequence before this object is destroyed. class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter - : NON_EXPORTED_BASE(public MessageReceiver), + : public MessageReceiver, public AssociatedGroupController, - NON_EXPORTED_BASE(public PipeControlMessageHandlerDelegate) { + public PipeControlMessageHandlerDelegate { public: enum Config { // There is only the master interface running on this router. Please note @@ -76,7 +77,11 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter MultiplexRouter(ScopedMessagePipeHandle message_pipe, Config config, bool set_interface_id_namespace_bit, - scoped_refptr<base::SingleThreadTaskRunner> runner); + scoped_refptr<base::SequencedTaskRunner> runner); + + // Adds a MessageReceiver which can filter a message after validation but + // before dispatch. + void AddIncomingMessageFilter(std::unique_ptr<MessageReceiver> filter); // Sets the master interface name for this router. Only used when reporting // message header or control message validation errors. @@ -84,7 +89,7 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter void SetMasterInterfaceName(const char* name); // --------------------------------------------------------------------------- - // The following public methods are safe to call from any threads. + // The following public methods are safe to call from any sequence. // AssociatedGroupController implementation: InterfaceId AssociateInterface( @@ -97,13 +102,14 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter InterfaceEndpointController* AttachEndpointClient( const ScopedInterfaceEndpointHandle& handle, InterfaceEndpointClient* endpoint_client, - scoped_refptr<base::SingleThreadTaskRunner> runner) override; + scoped_refptr<base::SequencedTaskRunner> runner) override; void DetachEndpointClient( const ScopedInterfaceEndpointHandle& handle) override; void RaiseError() override; + bool PrefersSerializedMessages() override; // --------------------------------------------------------------------------- - // The following public methods are called on the creating thread. + // The following public methods are called on the creating sequence. // Please note that this method shouldn't be called unless it results from an // explicit request of the user of bindings (e.g., the user sets an @@ -112,14 +118,15 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter // Extracts the underlying message pipe. ScopedMessagePipeHandle PassMessagePipe() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(!HasAssociatedEndpoints()); return connector_.PassMessagePipe(); } - // Blocks the current thread until the first incoming message, or |deadline|. + // Blocks the current sequence until the first incoming message, or + // |deadline|. bool WaitForIncomingMessage(MojoDeadline deadline) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return connector_.WaitForIncomingMessage(deadline); } @@ -137,13 +144,13 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter // Is the router bound to a message pipe handle? bool is_valid() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return connector_.is_valid(); } // TODO(yzshen): consider removing this getter. MessagePipeHandle handle() const { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); return connector_.handle(); } @@ -169,7 +176,7 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter void OnPipeConnectionError(); // Specifies whether we are allowed to directly call into - // InterfaceEndpointClient (given that we are already on the same thread as + // InterfaceEndpointClient (given that we are already on the same sequence as // the client). enum ClientCallBehavior { // Don't call any InterfaceEndpointClient methods directly. @@ -191,7 +198,7 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter // of this object, if direct calls are allowed, the caller needs to hold on to // a ref outside of |lock_| before calling this method. void ProcessTasks(ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner); + base::SequencedTaskRunner* current_task_runner); // Processes the first queued sync message for the endpoint corresponding to // |id|; returns whether there are more sync messages for that endpoint in the @@ -202,16 +209,14 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter bool ProcessFirstSyncMessageForEndpoint(InterfaceId id); // Returns true to indicate that |task|/|message| has been processed. - bool ProcessNotifyErrorTask( - Task* task, - ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner); - bool ProcessIncomingMessage( - Message* message, - ClientCallBehavior client_call_behavior, - base::SingleThreadTaskRunner* current_task_runner); - - void MaybePostToProcessTasks(base::SingleThreadTaskRunner* task_runner); + bool ProcessNotifyErrorTask(Task* task, + ClientCallBehavior client_call_behavior, + base::SequencedTaskRunner* current_task_runner); + bool ProcessIncomingMessage(MessageWrapper* message_wrapper, + ClientCallBehavior client_call_behavior, + base::SequencedTaskRunner* current_task_runner); + + void MaybePostToProcessTasks(base::SequencedTaskRunner* task_runner); void LockAndCallProcessTasks(); // Updates the state of |endpoint|. If both the endpoint and its peer have @@ -226,21 +231,25 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter InterfaceEndpoint* FindOrInsertEndpoint(InterfaceId id, bool* inserted); InterfaceEndpoint* FindEndpoint(InterfaceId id); + // Returns false if some interface IDs are invalid or have been used. + bool InsertEndpointsForMessage(const Message& message); + void CloseEndpointsForMessage(const Message& message); + void AssertLockAcquired(); // Whether to set the namespace bit when generating interface IDs. Please see // comments of kInterfaceIdNamespaceMask. const bool set_interface_id_namespace_bit_; - scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; // Owned by |filters_| below. - MessageHeaderValidator* header_validator_; + MessageHeaderValidator* header_validator_ = nullptr; FilterChain filters_; Connector connector_; - base::ThreadChecker thread_checker_; + SEQUENCE_CHECKER(sequence_checker_); // Protects the following members. // Not set in Config::SINGLE_INTERFACE* mode. @@ -250,21 +259,24 @@ class MOJO_CPP_BINDINGS_EXPORT MultiplexRouter // NOTE: It is unsafe to call into this object while holding |lock_|. PipeControlMessageProxy control_message_proxy_; - std::map<InterfaceId, scoped_refptr<InterfaceEndpoint>> endpoints_; - uint32_t next_interface_id_value_; + base::small_map<std::map<InterfaceId, scoped_refptr<InterfaceEndpoint>>, 1> + endpoints_; + uint32_t next_interface_id_value_ = 1; - std::deque<std::unique_ptr<Task>> tasks_; + base::circular_deque<std::unique_ptr<Task>> tasks_; // It refers to tasks in |tasks_| and doesn't own any of them. - std::map<InterfaceId, std::deque<Task*>> sync_message_tasks_; + std::map<InterfaceId, base::circular_deque<Task*>> sync_message_tasks_; + + bool posted_to_process_tasks_ = false; + scoped_refptr<base::SequencedTaskRunner> posted_to_task_runner_; - bool posted_to_process_tasks_; - scoped_refptr<base::SingleThreadTaskRunner> posted_to_task_runner_; + bool encountered_error_ = false; - bool encountered_error_; + bool paused_ = false; - bool paused_; + bool testing_mode_ = false; - bool testing_mode_; + bool being_destructed_ = false; DISALLOW_COPY_AND_ASSIGN(MultiplexRouter); }; diff --git a/mojo/public/cpp/bindings/lib/native_struct.cc b/mojo/public/cpp/bindings/lib/native_struct.cc deleted file mode 100644 index 7b1a1a6c59..0000000000 --- a/mojo/public/cpp/bindings/lib/native_struct.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "mojo/public/cpp/bindings/native_struct.h" - -#include "mojo/public/cpp/bindings/lib/hash_util.h" - -namespace mojo { - -// static -NativeStructPtr NativeStruct::New() { - return NativeStructPtr(base::in_place); -} - -NativeStruct::NativeStruct() {} - -NativeStruct::~NativeStruct() {} - -NativeStructPtr NativeStruct::Clone() const { - NativeStructPtr rv(New()); - rv->data = data; - return rv; -} - -bool NativeStruct::Equals(const NativeStruct& other) const { - return data == other.data; -} - -size_t NativeStruct::Hash(size_t seed) const { - return internal::Hash(seed, data); -} - -} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/native_struct_data.cc b/mojo/public/cpp/bindings/lib/native_struct_data.cc deleted file mode 100644 index 0e5d245692..0000000000 --- a/mojo/public/cpp/bindings/lib/native_struct_data.cc +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "mojo/public/cpp/bindings/lib/native_struct_data.h" - -#include "mojo/public/cpp/bindings/lib/buffer.h" -#include "mojo/public/cpp/bindings/lib/validation_context.h" - -namespace mojo { -namespace internal { - -// static -bool NativeStruct_Data::Validate(const void* data, - ValidationContext* validation_context) { - const ContainerValidateParams data_validate_params(0, false, nullptr); - return Array_Data<uint8_t>::Validate(data, validation_context, - &data_validate_params); -} - -} // namespace internal -} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/native_struct_data.h b/mojo/public/cpp/bindings/lib/native_struct_data.h deleted file mode 100644 index 1c7cd81c77..0000000000 --- a/mojo/public/cpp/bindings/lib/native_struct_data.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_DATA_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_DATA_H_ - -#include <vector> - -#include "mojo/public/cpp/bindings/bindings_export.h" -#include "mojo/public/cpp/bindings/lib/array_internal.h" -#include "mojo/public/cpp/system/handle.h" - -namespace mojo { -namespace internal { - -class ValidationContext; - -class MOJO_CPP_BINDINGS_EXPORT NativeStruct_Data { - public: - static bool Validate(const void* data, ValidationContext* validation_context); - - // Unlike normal structs, the memory layout is exactly the same as an array - // of uint8_t. - Array_Data<uint8_t> data; - - private: - NativeStruct_Data() = delete; - ~NativeStruct_Data() = delete; -}; - -static_assert(sizeof(Array_Data<uint8_t>) == sizeof(NativeStruct_Data), - "Mismatched NativeStruct_Data and Array_Data<uint8_t> size"); - -} // namespace internal -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_NATIVE_STRUCT_DATA_H_ diff --git a/mojo/public/cpp/bindings/lib/native_struct_serialization.cc b/mojo/public/cpp/bindings/lib/native_struct_serialization.cc index fa0dbf3803..283080089f 100644 --- a/mojo/public/cpp/bindings/lib/native_struct_serialization.cc +++ b/mojo/public/cpp/bindings/lib/native_struct_serialization.cc @@ -4,56 +4,120 @@ #include "mojo/public/cpp/bindings/lib/native_struct_serialization.h" +#include "ipc/ipc_message_attachment.h" +#include "ipc/ipc_message_attachment_set.h" +#include "ipc/native_handle_type_converters.h" #include "mojo/public/cpp/bindings/lib/serialization.h" +#include "mojo/public/cpp/bindings/lib/serialization_forward.h" namespace mojo { namespace internal { // static -size_t UnmappedNativeStructSerializerImpl::PrepareToSerialize( - const NativeStructPtr& input, +void UnmappedNativeStructSerializerImpl::Serialize( + const native::NativeStructPtr& input, + Buffer* buffer, + native::internal::NativeStruct_Data::BufferWriter* writer, SerializationContext* context) { if (!input) - return 0; - return internal::PrepareToSerialize<ArrayDataView<uint8_t>>(input->data, - context); + return; + + writer->Allocate(buffer); + + Array_Data<uint8_t>::BufferWriter data_writer; + const mojo::internal::ContainerValidateParams data_validate_params(0, false, + nullptr); + mojo::internal::Serialize<ArrayDataView<uint8_t>>( + input->data, buffer, &data_writer, &data_validate_params, context); + writer->data()->data.Set(data_writer.data()); + + mojo::internal::Array_Data<mojo::internal::Pointer< + native::internal::SerializedHandle_Data>>::BufferWriter handles_writer; + const mojo::internal::ContainerValidateParams handles_validate_params( + 0, false, nullptr); + mojo::internal::Serialize< + mojo::ArrayDataView<::mojo::native::SerializedHandleDataView>>( + input->handles, buffer, &handles_writer, &handles_validate_params, + context); + writer->data()->handles.Set(handles_writer.is_null() ? nullptr + : handles_writer.data()); } // static -void UnmappedNativeStructSerializerImpl::Serialize( - const NativeStructPtr& input, - Buffer* buffer, - NativeStruct_Data** output, +bool UnmappedNativeStructSerializerImpl::Deserialize( + native::internal::NativeStruct_Data* input, + native::NativeStructPtr* output, SerializationContext* context) { if (!input) { - *output = nullptr; - return; + output->reset(); + return true; } - Array_Data<uint8_t>* data = nullptr; - const ContainerValidateParams params(0, false, nullptr); - internal::Serialize<ArrayDataView<uint8_t>>(input->data, buffer, &data, - ¶ms, context); - *output = reinterpret_cast<NativeStruct_Data*>(data); + native::NativeStructDataView data_view(input, context); + return StructTraits<::mojo::native::NativeStructDataView, + native::NativeStructPtr>::Read(data_view, output); } // static -bool UnmappedNativeStructSerializerImpl::Deserialize( - NativeStruct_Data* input, - NativeStructPtr* output, +void UnmappedNativeStructSerializerImpl::SerializeMessageContents( + IPC::Message* message, + Buffer* buffer, + native::internal::NativeStruct_Data::BufferWriter* writer, SerializationContext* context) { - Array_Data<uint8_t>* data = reinterpret_cast<Array_Data<uint8_t>*>(input); + writer->Allocate(buffer); + + // Allocate a uint8 array, initialize its header, and copy the Pickle in. + Array_Data<uint8_t>::BufferWriter data_writer; + data_writer.Allocate(message->payload_size(), buffer); + memcpy(data_writer->storage(), message->payload(), message->payload_size()); + writer->data()->data.Set(data_writer.data()); + + if (message->attachment_set()->empty()) { + writer->data()->handles.Set(nullptr); + return; + } + + mojo::internal::Array_Data<mojo::internal::Pointer< + native::internal::SerializedHandle_Data>>::BufferWriter handles_writer; + auto* attachments = message->attachment_set(); + handles_writer.Allocate(attachments->size(), buffer); + for (unsigned i = 0; i < attachments->size(); ++i) { + native::internal::SerializedHandle_Data::BufferWriter handle_writer; + handle_writer.Allocate(buffer); + + auto attachment = attachments->GetAttachmentAt(i); + ScopedHandle handle = attachment->TakeMojoHandle(); + internal::Serializer<ScopedHandle, ScopedHandle>::Serialize( + handle, &handle_writer->the_handle, context); + handle_writer->type = static_cast<int32_t>( + mojo::ConvertTo<native::SerializedHandle::Type>(attachment->GetType())); + handles_writer.data()->at(i).Set(handle_writer.data()); + } + writer->data()->handles.Set(handles_writer.data()); +} + +// static +bool UnmappedNativeStructSerializerImpl::DeserializeMessageAttachments( + native::internal::NativeStruct_Data* data, + SerializationContext* context, + IPC::Message* message) { + if (data->handles.is_null()) + return true; - NativeStructPtr result(NativeStruct::New()); - if (!internal::Deserialize<ArrayDataView<uint8_t>>(data, &result->data, - context)) { - output = nullptr; - return false; + auto* handles_data = data->handles.Get(); + for (size_t i = 0; i < handles_data->size(); ++i) { + auto* handle_data = handles_data->at(i).Get(); + if (!handle_data) + return false; + ScopedHandle handle; + internal::Serializer<ScopedHandle, ScopedHandle>::Deserialize( + &handle_data->the_handle, &handle, context); + auto attachment = IPC::MessageAttachment::CreateFromMojoHandle( + std::move(handle), + mojo::ConvertTo<IPC::MessageAttachment::Type>( + static_cast<native::SerializedHandle::Type>(handle_data->type))); + message->attachment_set()->AddAttachment(std::move(attachment)); } - if (!result->data) - *output = nullptr; - else - result.Swap(output); return true; } diff --git a/mojo/public/cpp/bindings/lib/native_struct_serialization.h b/mojo/public/cpp/bindings/lib/native_struct_serialization.h index 457435b955..6aa4c3a4a8 100644 --- a/mojo/public/cpp/bindings/lib/native_struct_serialization.h +++ b/mojo/public/cpp/bindings/lib/native_struct_serialization.h @@ -12,59 +12,63 @@ #include "base/logging.h" #include "base/pickle.h" +#include "ipc/ipc_message.h" #include "ipc/ipc_param_traits.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/array_internal.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" -#include "mojo/public/cpp/bindings/lib/native_struct_data.h" #include "mojo/public/cpp/bindings/lib/serialization_forward.h" #include "mojo/public/cpp/bindings/lib/serialization_util.h" -#include "mojo/public/cpp/bindings/native_struct.h" -#include "mojo/public/cpp/bindings/native_struct_data_view.h" +#include "mojo/public/interfaces/bindings/native_struct.mojom.h" namespace mojo { namespace internal { +// Base class for the templated native struct serialization interface below, +// used to consolidated some shared logic and provide a basic +// Serialize/Deserialize for [Native] mojom structs which do not have a +// registered typemap in the current configuration (i.e. structs that are +// represented by a raw native::NativeStruct mojom struct in C++ bindings.) +struct MOJO_CPP_BINDINGS_EXPORT UnmappedNativeStructSerializerImpl { + static void Serialize( + const native::NativeStructPtr& input, + Buffer* buffer, + native::internal::NativeStruct_Data::BufferWriter* writer, + SerializationContext* context); + + static bool Deserialize(native::internal::NativeStruct_Data* input, + native::NativeStructPtr* output, + SerializationContext* context); + + static void SerializeMessageContents( + IPC::Message* message, + Buffer* buffer, + native::internal::NativeStruct_Data::BufferWriter* writer, + SerializationContext* context); + + static bool DeserializeMessageAttachments( + native::internal::NativeStruct_Data* data, + SerializationContext* context, + IPC::Message* message); +}; + template <typename MaybeConstUserType> struct NativeStructSerializerImpl { using UserType = typename std::remove_const<MaybeConstUserType>::type; using Traits = IPC::ParamTraits<UserType>; - static size_t PrepareToSerialize(MaybeConstUserType& value, - SerializationContext* context) { - base::PickleSizer sizer; - Traits::GetSize(&sizer, value); - return Align(sizer.payload_size() + sizeof(ArrayHeader)); - } - - static void Serialize(MaybeConstUserType& value, - Buffer* buffer, - NativeStruct_Data** out, - SerializationContext* context) { - base::Pickle pickle; - Traits::Write(&pickle, value); - -#if DCHECK_IS_ON() - base::PickleSizer sizer; - Traits::GetSize(&sizer, value); - DCHECK_EQ(sizer.payload_size(), pickle.payload_size()); -#endif - - size_t total_size = pickle.payload_size() + sizeof(ArrayHeader); - DCHECK_LT(total_size, std::numeric_limits<uint32_t>::max()); - - // Allocate a uint8 array, initialize its header, and copy the Pickle in. - ArrayHeader* header = - reinterpret_cast<ArrayHeader*>(buffer->Allocate(total_size)); - header->num_bytes = static_cast<uint32_t>(total_size); - header->num_elements = static_cast<uint32_t>(pickle.payload_size()); - memcpy(reinterpret_cast<char*>(header) + sizeof(ArrayHeader), - pickle.payload(), pickle.payload_size()); - - *out = reinterpret_cast<NativeStruct_Data*>(header); + static void Serialize( + MaybeConstUserType& value, + Buffer* buffer, + native::internal::NativeStruct_Data::BufferWriter* writer, + SerializationContext* context) { + IPC::Message message; + Traits::Write(&message, value); + UnmappedNativeStructSerializerImpl::SerializeMessageContents( + &message, buffer, writer, context); } - static bool Deserialize(NativeStruct_Data* data, + static bool Deserialize(native::internal::NativeStruct_Data* data, UserType* out, SerializationContext* context) { if (!data) @@ -82,7 +86,7 @@ struct NativeStructSerializerImpl { // Because ArrayHeader's num_bytes includes the length of the header and // Pickle's payload_size does not, we need to adjust the stored value // momentarily so Pickle can view the data. - ArrayHeader* header = reinterpret_cast<ArrayHeader*>(data); + ArrayHeader* header = reinterpret_cast<ArrayHeader*>(data->data.Get()); DCHECK_GE(header->num_bytes, sizeof(ArrayHeader)); header->num_bytes -= sizeof(ArrayHeader); @@ -90,10 +94,15 @@ struct NativeStructSerializerImpl { // Construct a view over the full Array_Data, including our hacked up // header. Pickle will infer from this that the header is 8 bytes long, // and the payload will contain all of the pickled bytes. - base::Pickle pickle_view(reinterpret_cast<const char*>(header), - header->num_bytes + sizeof(ArrayHeader)); - base::PickleIterator iter(pickle_view); - if (!Traits::Read(&pickle_view, &iter, out)) + IPC::Message message_view(reinterpret_cast<const char*>(header), + header->num_bytes + sizeof(ArrayHeader)); + base::PickleIterator iter(message_view); + if (!UnmappedNativeStructSerializerImpl::DeserializeMessageAttachments( + data, context, &message_view)) { + return false; + } + + if (!Traits::Read(&message_view, &iter, out)) return false; } @@ -104,28 +113,16 @@ struct NativeStructSerializerImpl { } }; -struct MOJO_CPP_BINDINGS_EXPORT UnmappedNativeStructSerializerImpl { - static size_t PrepareToSerialize(const NativeStructPtr& input, - SerializationContext* context); - static void Serialize(const NativeStructPtr& input, - Buffer* buffer, - NativeStruct_Data** output, - SerializationContext* context); - static bool Deserialize(NativeStruct_Data* input, - NativeStructPtr* output, - SerializationContext* context); -}; - template <> -struct NativeStructSerializerImpl<NativeStructPtr> +struct NativeStructSerializerImpl<native::NativeStructPtr> : public UnmappedNativeStructSerializerImpl {}; template <> -struct NativeStructSerializerImpl<const NativeStructPtr> +struct NativeStructSerializerImpl<const native::NativeStructPtr> : public UnmappedNativeStructSerializerImpl {}; template <typename MaybeConstUserType> -struct Serializer<NativeStructDataView, MaybeConstUserType> +struct Serializer<native::NativeStructDataView, MaybeConstUserType> : public NativeStructSerializerImpl<MaybeConstUserType> {}; } // namespace internal diff --git a/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc b/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc index d451c05a5f..d39b991e20 100644 --- a/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc +++ b/mojo/public/cpp/bindings/lib/pipe_control_message_handler.cc @@ -6,7 +6,6 @@ #include "base/logging.h" #include "mojo/public/cpp/bindings/interface_id.h" -#include "mojo/public/cpp/bindings/lib/message_builder.h" #include "mojo/public/cpp/bindings/lib/serialization.h" #include "mojo/public/cpp/bindings/lib/serialization_context.h" #include "mojo/public/cpp/bindings/lib/validation_context.h" diff --git a/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc b/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc index 1029c2c491..f218892db5 100644 --- a/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc +++ b/mojo/public/cpp/bindings/lib/pipe_control_message_proxy.cc @@ -9,8 +9,8 @@ #include "base/logging.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/lib/message_builder.h" #include "mojo/public/cpp/bindings/lib/serialization.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/interfaces/bindings/pipe_control_messages.mojom.h" namespace mojo { @@ -18,21 +18,16 @@ namespace { Message ConstructRunOrClosePipeMessage( pipe_control::RunOrClosePipeInputPtr input_ptr) { - internal::SerializationContext context; - auto params_ptr = pipe_control::RunOrClosePipeMessageParams::New(); params_ptr->input = std::move(input_ptr); - size_t size = internal::PrepareToSerialize< - pipe_control::RunOrClosePipeMessageParamsDataView>(params_ptr, &context); - internal::MessageBuilder builder(pipe_control::kRunOrClosePipeMessageId, 0, - size, 0); - - pipe_control::internal::RunOrClosePipeMessageParams_Data* params = nullptr; + Message message(pipe_control::kRunOrClosePipeMessageId, 0, 0, 0, nullptr); + internal::SerializationContext context; + pipe_control::internal::RunOrClosePipeMessageParams_Data::BufferWriter params; internal::Serialize<pipe_control::RunOrClosePipeMessageParamsDataView>( - params_ptr, builder.buffer(), ¶ms, &context); - builder.message()->set_interface_id(kInvalidInterfaceId); - return std::move(*builder.message()); + params_ptr, message.payload_buffer(), ¶ms, &context); + message.set_interface_id(kInvalidInterfaceId); + return message; } } // namespace diff --git a/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc b/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc index c1345079a5..2e5559ce0d 100644 --- a/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc +++ b/mojo/public/cpp/bindings/lib/scoped_interface_endpoint_handle.cc @@ -7,6 +7,7 @@ #include "base/bind.h" #include "base/logging.h" #include "base/synchronization/lock.h" +#include "base/threading/sequenced_task_runner_handle.h" #include "mojo/public/cpp/bindings/associated_group_controller.h" #include "mojo/public/cpp/bindings/lib/may_auto_lock.h" @@ -14,7 +15,7 @@ namespace mojo { // ScopedInterfaceEndpointHandle::State ---------------------------------------- -// State could be called from multiple threads. +// State could be called from multiple sequences. class ScopedInterfaceEndpointHandle::State : public base::RefCountedThreadSafe<State> { public: @@ -51,7 +52,7 @@ class ScopedInterfaceEndpointHandle::State // Intentionally keep |group_controller_| unchanged. // That is because the callback created by // CreateGroupControllerGetter() could still be used after this point, - // potentially from another thread. We would like it to continue + // potentially from another sequence. We would like it to continue // returning the same group controller. // // Imagine there is a ThreadSafeForwarder A: @@ -103,7 +104,7 @@ class ScopedInterfaceEndpointHandle::State return; } - runner_ = base::ThreadTaskRunnerHandle::Get(); + runner_ = base::SequencedTaskRunnerHandle::Get(); if (!pending_association_) { runner_->PostTask( FROM_HERE, @@ -171,7 +172,7 @@ class ScopedInterfaceEndpointHandle::State DCHECK(!IsValidInterfaceId(id_)); } - // Called by the peer, maybe from a different thread. + // Called by the peer, maybe from a different sequence. void OnAssociated(InterfaceId id, scoped_refptr<AssociatedGroupController> group_controller) { AssociationEventCallback handler; @@ -179,7 +180,7 @@ class ScopedInterfaceEndpointHandle::State internal::MayAutoLock locker(&lock_); // There may be race between Close() of endpoint A and - // NotifyPeerAssociation() of endpoint A_peer on different threads. + // NotifyPeerAssociation() of endpoint A_peer on different sequences. // Therefore, it is possible that endpoint A has been closed but it // still gets OnAssociated() call from its peer. if (!pending_association_) @@ -191,7 +192,7 @@ class ScopedInterfaceEndpointHandle::State group_controller_ = std::move(group_controller); if (!association_event_handler_.is_null()) { - if (runner_->BelongsToCurrentThread()) { + if (runner_->RunsTasksInCurrentSequence()) { handler = std::move(association_event_handler_); runner_ = nullptr; } else { @@ -207,7 +208,7 @@ class ScopedInterfaceEndpointHandle::State std::move(handler).Run(ASSOCIATED); } - // Called by the peer, maybe from a different thread. + // Called by the peer, maybe from a different sequence. void OnPeerClosedBeforeAssociation( const base::Optional<DisconnectReason>& reason) { AssociationEventCallback handler; @@ -215,7 +216,7 @@ class ScopedInterfaceEndpointHandle::State internal::MayAutoLock locker(&lock_); // There may be race between Close()/NotifyPeerAssociation() of endpoint - // A and Close() of endpoint A_peer on different threads. + // A and Close() of endpoint A_peer on different sequences. // Therefore, it is possible that endpoint A is not in pending association // state but still gets OnPeerClosedBeforeAssociation() call from its // peer. @@ -227,7 +228,7 @@ class ScopedInterfaceEndpointHandle::State peer_state_ = nullptr; if (!association_event_handler_.is_null()) { - if (runner_->BelongsToCurrentThread()) { + if (runner_->RunsTasksInCurrentSequence()) { handler = std::move(association_event_handler_); runner_ = nullptr; } else { @@ -245,7 +246,7 @@ class ScopedInterfaceEndpointHandle::State } void RunAssociationEventHandler( - scoped_refptr<base::SingleThreadTaskRunner> posted_to_runner, + scoped_refptr<base::SequencedTaskRunner> posted_to_runner, AssociationEvent event) { AssociationEventCallback handler; @@ -271,7 +272,7 @@ class ScopedInterfaceEndpointHandle::State scoped_refptr<State> peer_state_; AssociationEventCallback association_event_handler_; - scoped_refptr<base::SingleThreadTaskRunner> runner_; + scoped_refptr<base::SequencedTaskRunner> runner_; InterfaceId id_ = kInvalidInterfaceId; scoped_refptr<AssociatedGroupController> group_controller_; @@ -373,7 +374,7 @@ void ScopedInterfaceEndpointHandle::ResetInternal( base::Callback<AssociatedGroupController*()> ScopedInterfaceEndpointHandle::CreateGroupControllerGetter() const { - // We allow this callback to be run on any thread. If this handle is created + // We allow this callback to be run on any sequence. If this handle is created // in non-pending state, we don't have a lock but it should still be safe // because the group controller never changes. return base::Bind(&State::group_controller, state_); diff --git a/mojo/public/cpp/bindings/lib/sequence_local_sync_event_watcher.cc b/mojo/public/cpp/bindings/lib/sequence_local_sync_event_watcher.cc new file mode 100644 index 0000000000..f4618ffbe8 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/sequence_local_sync_event_watcher.cc @@ -0,0 +1,286 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/sequence_local_sync_event_watcher.h" + +#include <map> +#include <memory> +#include <set> + +#include "base/bind.h" +#include "base/containers/flat_set.h" +#include "base/macros.h" +#include "base/memory/ptr_util.h" +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "base/no_destructor.h" +#include "base/synchronization/lock.h" +#include "base/synchronization/waitable_event.h" +#include "base/threading/sequence_local_storage_slot.h" +#include "mojo/public/cpp/bindings/sync_event_watcher.h" + +namespace mojo { + +namespace { + +struct WatcherState; + +using WatcherStateMap = + std::map<const SequenceLocalSyncEventWatcher*, scoped_refptr<WatcherState>>; + +// Ref-counted watcher state which may outlive the watcher to which it pertains. +// This is necessary to store outside of the SequenceLocalSyncEventWatcher +// itself in order to support nested sync operations where an inner operation +// may destroy the watcher. +struct WatcherState : public base::RefCounted<WatcherState> { + WatcherState() = default; + + bool watcher_was_destroyed = false; + + private: + friend class base::RefCounted<WatcherState>; + + ~WatcherState() = default; + + DISALLOW_COPY_AND_ASSIGN(WatcherState); +}; + +} // namespace + +// Owns the WaitableEvent and SyncEventWatcher shared by all +// SequenceLocalSyncEventWatchers on a single sequence, and coordinates the +// multiplexing of those shared objects to support an arbitrary number of +// SequenceLocalSyncEventWatchers waiting and signaling potentially while +// nested. +class SequenceLocalSyncEventWatcher::SequenceLocalState { + public: + SequenceLocalState() + : event_(base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::NOT_SIGNALED), + event_watcher_(&event_, + base::BindRepeating(&SequenceLocalState::OnEventSignaled, + base::Unretained(this))), + weak_ptr_factory_(this) { + // We always allow this event handler to be awoken during any sync event on + // the sequence. Individual watchers still must opt into having such + // wake-ups propagated to them. + event_watcher_.AllowWokenUpBySyncWatchOnSameThread(); + } + + ~SequenceLocalState() {} + + // Initializes a SequenceLocalState instance in sequence-local storage if + // not already initialized. Returns a WeakPtr to the stored state object. + static base::WeakPtr<SequenceLocalState> GetOrCreate() { + auto& state_ptr = GetStorageSlot().Get(); + if (!state_ptr) + state_ptr = std::make_unique<SequenceLocalState>(); + return state_ptr->weak_ptr_factory_.GetWeakPtr(); + } + + // Registers a new watcher and returns an iterator into the WatcherStateMap to + // be used for fast access with other methods. + WatcherStateMap::iterator RegisterWatcher( + const SequenceLocalSyncEventWatcher* watcher) { + auto result = registered_watchers_.emplace( + watcher, base::MakeRefCounted<WatcherState>()); + DCHECK(result.second); + return result.first; + } + + void UnregisterWatcher(WatcherStateMap::iterator iter) { + if (top_watcher_ == iter->first) { + // If the watcher being unregistered is currently blocking in a + // |SyncWatch()| operation, we need to unblock it. Setting this flag does + // that. + top_watcher_state_->watcher_was_destroyed = true; + top_watcher_state_ = nullptr; + top_watcher_ = nullptr; + } + + { + base::AutoLock lock(ready_watchers_lock_); + ready_watchers_.erase(iter->first); + } + + registered_watchers_.erase(iter); + if (registered_watchers_.empty()) { + // If no more watchers are registered, clear our sequence-local storage. + // Deletes |this|. + GetStorageSlot().Get().reset(); + } + } + + void SignalForWatcher(const SequenceLocalSyncEventWatcher* watcher) { + bool must_signal = false; + { + base::AutoLock lock(ready_watchers_lock_); + must_signal = ready_watchers_.empty(); + ready_watchers_.insert(watcher); + } + + // If we didn't have any ready watchers before, the event may not have + // been signaled. Signal it to ensure that |OnEventSignaled()| is run. + if (must_signal) + event_.Signal(); + } + + void ResetForWatcher(const SequenceLocalSyncEventWatcher* watcher) { + base::AutoLock lock(ready_watchers_lock_); + ready_watchers_.erase(watcher); + + // No more watchers are ready, so we can reset the event. The next watcher + // to call |SignalForWatcher()| will re-signal the event. + if (ready_watchers_.empty()) + event_.Reset(); + } + + bool SyncWatch(const SequenceLocalSyncEventWatcher* watcher, + WatcherState* watcher_state, + const bool* should_stop) { + // |SyncWatch()| calls may nest arbitrarily deep on the same sequence. We + // preserve the outer watcher state on the stack and restore it once the + // innermost watch is complete. + const SequenceLocalSyncEventWatcher* outer_watcher = top_watcher_; + WatcherState* outer_watcher_state = top_watcher_state_; + + // Keep a ref on the stack so the state stays alive even if the watcher is + // destroyed. + scoped_refptr<WatcherState> top_watcher_state(watcher_state); + top_watcher_state_ = watcher_state; + top_watcher_ = watcher; + + // In addition to the caller's own stop condition, we need to interrupt the + // SyncEventWatcher if |watcher| is destroyed while we're waiting. + const bool* stop_flags[] = {should_stop, + &top_watcher_state_->watcher_was_destroyed}; + + // |SyncWatch()| may delete |this|. + auto weak_self = weak_ptr_factory_.GetWeakPtr(); + bool result = event_watcher_.SyncWatch(stop_flags, 2); + if (!weak_self) + return false; + + top_watcher_state_ = outer_watcher_state; + top_watcher_ = outer_watcher; + return result; + } + + private: + using StorageSlotType = + base::SequenceLocalStorageSlot<std::unique_ptr<SequenceLocalState>>; + static StorageSlotType& GetStorageSlot() { + static base::NoDestructor<StorageSlotType> storage; + return *storage; + } + + void OnEventSignaled(); + + // The shared event and watcher used for this sequence. + base::WaitableEvent event_; + mojo::SyncEventWatcher event_watcher_; + + // All SequenceLocalSyncEventWatchers on the current sequence have some state + // registered here. + WatcherStateMap registered_watchers_; + + // Tracks state of the top-most |SyncWatch()| invocation on the stack. + const SequenceLocalSyncEventWatcher* top_watcher_ = nullptr; + WatcherState* top_watcher_state_ = nullptr; + + // Set of all SequenceLocalSyncEventWatchers in a signaled state, guarded by + // a lock for sequence-safe signaling. + base::Lock ready_watchers_lock_; + base::flat_set<const SequenceLocalSyncEventWatcher*> ready_watchers_; + + base::WeakPtrFactory<SequenceLocalState> weak_ptr_factory_; + + DISALLOW_COPY_AND_ASSIGN(SequenceLocalState); +}; + +void SequenceLocalSyncEventWatcher::SequenceLocalState::OnEventSignaled() { + for (;;) { + base::flat_set<const SequenceLocalSyncEventWatcher*> ready_watchers; + { + base::AutoLock lock(ready_watchers_lock_); + std::swap(ready_watchers_, ready_watchers); + } + if (ready_watchers.empty()) + return; + + auto weak_self = weak_ptr_factory_.GetWeakPtr(); + for (auto* watcher : ready_watchers) { + if (top_watcher_ == watcher || watcher->can_wake_up_during_any_watch_) { + watcher->callback_.Run(); + + // The callback may have deleted |this|. + if (!weak_self) + return; + } + } + } +} + +// Manages a watcher's reference to the sequence-local state. This hides +// implementation details from the SequenceLocalSyncEventWatcher interface. +class SequenceLocalSyncEventWatcher::Registration { + public: + explicit Registration(const SequenceLocalSyncEventWatcher* watcher) + : weak_shared_state_(SequenceLocalState::GetOrCreate()), + shared_state_(weak_shared_state_.get()), + watcher_state_iterator_(shared_state_->RegisterWatcher(watcher)), + watcher_state_(watcher_state_iterator_->second) {} + + ~Registration() { + if (weak_shared_state_) { + // Because |this| may itself be owned by sequence- or thread-local storage + // (e.g. if an interface binding lives there) we have no guarantee that + // our SequenceLocalState's storage slot will still be alive during our + // own destruction; so we have to guard against any access to it. Note + // that this uncertainty only exists within the destructor and does not + // apply to other methods on SequenceLocalSyncEventWatcher. + // + // May delete |shared_state_|. + shared_state_->UnregisterWatcher(watcher_state_iterator_); + } + } + + SequenceLocalState* shared_state() const { return shared_state_; } + WatcherState* watcher_state() { return watcher_state_.get(); } + + private: + const base::WeakPtr<SequenceLocalState> weak_shared_state_; + SequenceLocalState* const shared_state_; + WatcherStateMap::iterator watcher_state_iterator_; + const scoped_refptr<WatcherState> watcher_state_; + + DISALLOW_COPY_AND_ASSIGN(Registration); +}; + +SequenceLocalSyncEventWatcher::SequenceLocalSyncEventWatcher( + const base::RepeatingClosure& callback) + : registration_(std::make_unique<Registration>(this)), + callback_(callback) {} + +SequenceLocalSyncEventWatcher::~SequenceLocalSyncEventWatcher() = default; + +void SequenceLocalSyncEventWatcher::SignalEvent() { + registration_->shared_state()->SignalForWatcher(this); +} + +void SequenceLocalSyncEventWatcher::ResetEvent() { + registration_->shared_state()->ResetForWatcher(this); +} + +void SequenceLocalSyncEventWatcher::AllowWokenUpBySyncWatchOnSameSequence() { + can_wake_up_during_any_watch_ = true; +} + +bool SequenceLocalSyncEventWatcher::SyncWatch(const bool* should_stop) { + // NOTE: |SyncWatch()| may delete |this|. + return registration_->shared_state()->SyncWatch( + this, registration_->watcher_state(), should_stop); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/serialization.h b/mojo/public/cpp/bindings/lib/serialization.h index 2a7d288d55..8ced91ea53 100644 --- a/mojo/public/cpp/bindings/lib/serialization.h +++ b/mojo/public/cpp/bindings/lib/serialization.h @@ -7,96 +7,122 @@ #include <string.h> -#include "mojo/public/cpp/bindings/array_traits_carray.h" +#include <type_traits> + +#include "base/numerics/safe_math.h" +#include "mojo/public/cpp/bindings/array_traits_span.h" #include "mojo/public/cpp/bindings/array_traits_stl.h" #include "mojo/public/cpp/bindings/lib/array_serialization.h" +#include "mojo/public/cpp/bindings/lib/bindings_internal.h" #include "mojo/public/cpp/bindings/lib/buffer.h" -#include "mojo/public/cpp/bindings/lib/handle_interface_serialization.h" +#include "mojo/public/cpp/bindings/lib/handle_serialization.h" #include "mojo/public/cpp/bindings/lib/map_serialization.h" -#include "mojo/public/cpp/bindings/lib/native_enum_serialization.h" -#include "mojo/public/cpp/bindings/lib/native_struct_serialization.h" #include "mojo/public/cpp/bindings/lib/string_serialization.h" #include "mojo/public/cpp/bindings/lib/template_util.h" +#include "mojo/public/cpp/bindings/map_traits_flat_map.h" #include "mojo/public/cpp/bindings/map_traits_stl.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/cpp/bindings/string_traits_stl.h" -#include "mojo/public/cpp/bindings/string_traits_string16.h" #include "mojo/public/cpp/bindings/string_traits_string_piece.h" namespace mojo { namespace internal { -template <typename MojomType, typename DataArrayType, typename UserType> -DataArrayType StructSerializeImpl(UserType* input) { - static_assert(BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value, - "Unexpected type."); +template <typename MojomType, typename EnableType = void> +struct MojomSerializationImplTraits; + +template <typename MojomType> +struct MojomSerializationImplTraits< + MojomType, + typename std::enable_if< + BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value>::type> { + template <typename MaybeConstUserType, typename WriterType> + static void Serialize(MaybeConstUserType& input, + Buffer* buffer, + WriterType* writer, + SerializationContext* context) { + mojo::internal::Serialize<MojomType>(input, buffer, writer, context); + } +}; + +template <typename MojomType> +struct MojomSerializationImplTraits< + MojomType, + typename std::enable_if< + BelongsTo<MojomType, MojomTypeCategory::UNION>::value>::type> { + template <typename MaybeConstUserType, typename WriterType> + static void Serialize(MaybeConstUserType& input, + Buffer* buffer, + WriterType* writer, + SerializationContext* context) { + mojo::internal::Serialize<MojomType>(input, buffer, writer, + false /* inline */, context); + } +}; +template <typename MojomType, typename UserType> +mojo::Message SerializeAsMessageImpl(UserType* input) { SerializationContext context; - size_t size = PrepareToSerialize<MojomType>(*input, &context); - DCHECK_EQ(size, Align(size)); + mojo::Message message(0, 0, 0, 0, nullptr); + typename MojomTypeTraits<MojomType>::Data::BufferWriter writer; + MojomSerializationImplTraits<MojomType>::Serialize( + *input, message.payload_buffer(), &writer, &context); + message.AttachHandlesFromSerializationContext(&context); + return message; +} +template <typename MojomType, typename DataArrayType, typename UserType> +DataArrayType SerializeImpl(UserType* input) { + static_assert(BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value || + BelongsTo<MojomType, MojomTypeCategory::UNION>::value, + "Unexpected type."); + Message message = SerializeAsMessageImpl<MojomType>(input); + uint32_t size = message.payload_num_bytes(); DataArrayType result(size); - if (size == 0) - return result; - - void* result_buffer = &result.front(); - // The serialization logic requires that the buffer is 8-byte aligned. If the - // result buffer is not properly aligned, we have to do an extra copy. In - // practice, this should never happen for std::vector. - bool need_copy = !IsAligned(result_buffer); - - if (need_copy) { - // calloc sets the memory to all zero. - result_buffer = calloc(size, 1); - DCHECK(IsAligned(result_buffer)); - } - - Buffer buffer; - buffer.Initialize(result_buffer, size); - typename MojomTypeTraits<MojomType>::Data* data = nullptr; - Serialize<MojomType>(*input, &buffer, &data, &context); - - if (need_copy) { - memcpy(&result.front(), result_buffer, size); - free(result_buffer); - } - + if (size) + memcpy(&result.front(), message.payload(), size); return result; } -template <typename MojomType, typename DataArrayType, typename UserType> -bool StructDeserializeImpl(const DataArrayType& input, - UserType* output, - bool (*validate_func)(const void*, - ValidationContext*)) { - static_assert(BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value, +template <typename MojomType, typename UserType> +bool DeserializeImpl(const void* data, + size_t data_num_bytes, + std::vector<mojo::ScopedHandle> handles, + UserType* output, + bool (*validate_func)(const void*, ValidationContext*)) { + static_assert(BelongsTo<MojomType, MojomTypeCategory::STRUCT>::value || + BelongsTo<MojomType, MojomTypeCategory::UNION>::value, "Unexpected type."); using DataType = typename MojomTypeTraits<MojomType>::Data; - // TODO(sammc): Use DataArrayType::empty() once WTF::Vector::empty() exists. - void* input_buffer = - input.size() == 0 - ? nullptr - : const_cast<void*>(reinterpret_cast<const void*>(&input.front())); + const void* input_buffer = data_num_bytes == 0 ? nullptr : data; + void* aligned_input_buffer = nullptr; - // Please see comments in StructSerializeImpl. + // Validation code will insist that the input buffer is aligned, so we ensure + // that here. If the input data is not aligned, we (sadly) copy into an + // aligned buffer. In practice this should happen only rarely if ever. bool need_copy = !IsAligned(input_buffer); - if (need_copy) { - input_buffer = malloc(input.size()); - DCHECK(IsAligned(input_buffer)); - memcpy(input_buffer, &input.front(), input.size()); + aligned_input_buffer = malloc(data_num_bytes); + DCHECK(IsAligned(aligned_input_buffer)); + memcpy(aligned_input_buffer, data, data_num_bytes); + input_buffer = aligned_input_buffer; } - ValidationContext validation_context(input_buffer, input.size(), 0, 0); + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(data_num_bytes)); + ValidationContext validation_context( + input_buffer, static_cast<uint32_t>(data_num_bytes), handles.size(), 0); bool result = false; if (validate_func(input_buffer, &validation_context)) { - auto data = reinterpret_cast<DataType*>(input_buffer); SerializationContext context; - result = Deserialize<MojomType>(data, output, &context); + *context.mutable_handles() = std::move(handles); + result = Deserialize<MojomType>( + reinterpret_cast<DataType*>(const_cast<void*>(input_buffer)), output, + &context); } - if (need_copy) - free(input_buffer); + if (aligned_input_buffer) + free(aligned_input_buffer); return result; } diff --git a/mojo/public/cpp/bindings/lib/serialization_context.cc b/mojo/public/cpp/bindings/lib/serialization_context.cc index e2fd5c6e18..267b54154b 100644 --- a/mojo/public/cpp/bindings/lib/serialization_context.cc +++ b/mojo/public/cpp/bindings/lib/serialization_context.cc @@ -7,50 +7,78 @@ #include <limits> #include "base/logging.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/cpp/system/core.h" namespace mojo { namespace internal { -SerializedHandleVector::SerializedHandleVector() {} +SerializationContext::SerializationContext() = default; -SerializedHandleVector::~SerializedHandleVector() { - for (auto handle : handles_) { - if (handle.is_valid()) { - MojoResult rv = MojoClose(handle.value()); - DCHECK_EQ(rv, MOJO_RESULT_OK); - } +SerializationContext::~SerializationContext() = default; + +void SerializationContext::AddHandle(mojo::ScopedHandle handle, + Handle_Data* out_data) { + if (!handle.is_valid()) { + out_data->value = kEncodedInvalidHandleValue; + } else { + DCHECK_LT(handles_.size(), std::numeric_limits<uint32_t>::max()); + out_data->value = static_cast<uint32_t>(handles_.size()); + handles_.emplace_back(std::move(handle)); } } -Handle_Data SerializedHandleVector::AddHandle(mojo::Handle handle) { - Handle_Data data; +void SerializationContext::AddInterfaceInfo( + mojo::ScopedMessagePipeHandle handle, + uint32_t version, + Interface_Data* out_data) { + AddHandle(ScopedHandle::From(std::move(handle)), &out_data->handle); + out_data->version = version; +} + +void SerializationContext::AddAssociatedEndpoint( + ScopedInterfaceEndpointHandle handle, + AssociatedEndpointHandle_Data* out_data) { if (!handle.is_valid()) { - data.value = kEncodedInvalidHandleValue; + out_data->value = kEncodedInvalidHandleValue; } else { - DCHECK_LT(handles_.size(), std::numeric_limits<uint32_t>::max()); - data.value = static_cast<uint32_t>(handles_.size()); - handles_.push_back(handle); + DCHECK_LT(associated_endpoint_handles_.size(), + std::numeric_limits<uint32_t>::max()); + out_data->value = + static_cast<uint32_t>(associated_endpoint_handles_.size()); + associated_endpoint_handles_.emplace_back(std::move(handle)); } - return data; } -mojo::Handle SerializedHandleVector::TakeHandle( - const Handle_Data& encoded_handle) { - if (!encoded_handle.is_valid()) - return mojo::Handle(); - DCHECK_LT(encoded_handle.value, handles_.size()); - return FetchAndReset(&handles_[encoded_handle.value]); +void SerializationContext::AddAssociatedInterfaceInfo( + ScopedInterfaceEndpointHandle handle, + uint32_t version, + AssociatedInterface_Data* out_data) { + AddAssociatedEndpoint(std::move(handle), &out_data->handle); + out_data->version = version; } -void SerializedHandleVector::Swap(std::vector<mojo::Handle>* other) { - handles_.swap(*other); +void SerializationContext::TakeHandlesFromMessage(Message* message) { + handles_.swap(*message->mutable_handles()); + associated_endpoint_handles_.swap( + *message->mutable_associated_endpoint_handles()); } -SerializationContext::SerializationContext() {} +mojo::ScopedHandle SerializationContext::TakeHandle( + const Handle_Data& encoded_handle) { + if (!encoded_handle.is_valid()) + return mojo::ScopedHandle(); + DCHECK_LT(encoded_handle.value, handles_.size()); + return std::move(handles_[encoded_handle.value]); +} -SerializationContext::~SerializationContext() { - DCHECK(!custom_contexts || custom_contexts->empty()); +mojo::ScopedInterfaceEndpointHandle +SerializationContext::TakeAssociatedEndpointHandle( + const AssociatedEndpointHandle_Data& encoded_handle) { + if (!encoded_handle.is_valid()) + return mojo::ScopedInterfaceEndpointHandle(); + DCHECK_LT(encoded_handle.value, associated_endpoint_handles_.size()); + return std::move(associated_endpoint_handles_[encoded_handle.value]); } } // namespace internal diff --git a/mojo/public/cpp/bindings/lib/serialization_context.h b/mojo/public/cpp/bindings/lib/serialization_context.h index a34fe3d4ed..0e3c0788dc 100644 --- a/mojo/public/cpp/bindings/lib/serialization_context.h +++ b/mojo/public/cpp/bindings/lib/serialization_context.h @@ -8,67 +8,88 @@ #include <stddef.h> #include <memory> -#include <queue> #include <vector> +#include "base/component_export.h" +#include "base/containers/stack_container.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" #include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" #include "mojo/public/cpp/system/handle.h" namespace mojo { + +class Message; + namespace internal { -// A container for handles during serialization/deserialization. -class MOJO_CPP_BINDINGS_EXPORT SerializedHandleVector { +// Context information for serialization/deserialization routines. +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) SerializationContext { public: - SerializedHandleVector(); - ~SerializedHandleVector(); + SerializationContext(); + ~SerializationContext(); - size_t size() const { return handles_.size(); } + // Adds a handle to the handle list and outputs its serialized form in + // |*out_data|. + void AddHandle(mojo::ScopedHandle handle, Handle_Data* out_data); + + // Adds an interface info to the handle list and outputs its serialized form + // in |*out_data|. + void AddInterfaceInfo(mojo::ScopedMessagePipeHandle handle, + uint32_t version, + Interface_Data* out_data); + + // Adds an associated interface endpoint (for e.g. an + // AssociatedInterfaceRequest) to this context and outputs its serialized form + // in |*out_data|. + void AddAssociatedEndpoint(ScopedInterfaceEndpointHandle handle, + AssociatedEndpointHandle_Data* out_data); + + // Adds an associated interface info to associated endpoint handle and version + // data lists and outputs its serialized form in |*out_data|. + void AddAssociatedInterfaceInfo(ScopedInterfaceEndpointHandle handle, + uint32_t version, + AssociatedInterface_Data* out_data); + + const std::vector<mojo::ScopedHandle>* handles() { return &handles_; } + std::vector<mojo::ScopedHandle>* mutable_handles() { return &handles_; } + + const std::vector<ScopedInterfaceEndpointHandle>* + associated_endpoint_handles() const { + return &associated_endpoint_handles_; + } + std::vector<ScopedInterfaceEndpointHandle>* + mutable_associated_endpoint_handles() { + return &associated_endpoint_handles_; + } - // Adds a handle to the handle list and returns its index for encoding. - Handle_Data AddHandle(mojo::Handle handle); + // Takes handles from a received Message object and assumes ownership of them. + // Individual handles can be extracted using Take* methods below. + void TakeHandlesFromMessage(Message* message); // Takes a handle from the list of serialized handle data. - mojo::Handle TakeHandle(const Handle_Data& encoded_handle); + mojo::ScopedHandle TakeHandle(const Handle_Data& encoded_handle); // Takes a handle from the list of serialized handle data and returns it in // |*out_handle| as a specific scoped handle type. template <typename T> ScopedHandleBase<T> TakeHandleAs(const Handle_Data& encoded_handle) { - return MakeScopedHandle(T(TakeHandle(encoded_handle).value())); + return ScopedHandleBase<T>::From(TakeHandle(encoded_handle)); } - // Swaps all owned handles out with another Handle vector. - void Swap(std::vector<mojo::Handle>* other); + mojo::ScopedInterfaceEndpointHandle TakeAssociatedEndpointHandle( + const AssociatedEndpointHandle_Data& encoded_handle); private: - // Handles are owned by this object. - std::vector<mojo::Handle> handles_; - - DISALLOW_COPY_AND_ASSIGN(SerializedHandleVector); -}; - -// Context information for serialization/deserialization routines. -struct MOJO_CPP_BINDINGS_EXPORT SerializationContext { - SerializationContext(); - - ~SerializationContext(); - - // Opaque context pointers returned by StringTraits::SetUpContext(). - std::unique_ptr<std::queue<void*>> custom_contexts; - - // Stashes handles encoded in a message by index. - SerializedHandleVector handles; - - // The number of ScopedInterfaceEndpointHandles that need to be serialized. - // It is calculated by PrepareToSerialize(). - uint32_t associated_endpoint_count = 0; + // Handles owned by this object. Used during serialization to hold onto + // handles accumulated during pre-serialization, and used during + // deserialization to hold onto handles extracted from a message. + std::vector<mojo::ScopedHandle> handles_; // Stashes ScopedInterfaceEndpointHandles encoded in a message by index. - std::vector<ScopedInterfaceEndpointHandle> associated_endpoint_handles; + std::vector<ScopedInterfaceEndpointHandle> associated_endpoint_handles_; + + DISALLOW_COPY_AND_ASSIGN(SerializationContext); }; } // namespace internal diff --git a/mojo/public/cpp/bindings/lib/serialization_forward.h b/mojo/public/cpp/bindings/lib/serialization_forward.h index 55c9982ccc..562951ee4a 100644 --- a/mojo/public/cpp/bindings/lib/serialization_forward.h +++ b/mojo/public/cpp/bindings/lib/serialization_forward.h @@ -33,22 +33,6 @@ struct IsOptionalWrapper { typename std::remove_reference<T>::type>::type>::value; }; -// PrepareToSerialize() must be matched by a Serialize() for the same input -// later. Moreover, within the same SerializationContext if PrepareToSerialize() -// is called for |input_1|, ..., |input_n|, Serialize() must be called for -// those objects in the exact same order. -template <typename MojomType, - typename InputUserType, - typename... Args, - typename std::enable_if< - !IsOptionalWrapper<InputUserType>::value>::type* = nullptr> -size_t PrepareToSerialize(InputUserType&& input, Args&&... args) { - return Serializer<MojomType, - typename std::remove_reference<InputUserType>::type>:: - PrepareToSerialize(std::forward<InputUserType>(input), - std::forward<Args>(args)...); -} - template <typename MojomType, typename InputUserType, typename... Args, @@ -71,33 +55,19 @@ bool Deserialize(DataType&& input, InputUserType* output, Args&&... args) { std::forward<DataType>(input), output, std::forward<Args>(args)...); } -// Specialization that unwraps base::Optional<>. template <typename MojomType, typename InputUserType, - typename... Args, - typename std::enable_if< - IsOptionalWrapper<InputUserType>::value>::type* = nullptr> -size_t PrepareToSerialize(InputUserType&& input, Args&&... args) { - if (!input) - return 0; - return PrepareToSerialize<MojomType>(*input, std::forward<Args>(args)...); -} - -template <typename MojomType, - typename InputUserType, - typename DataType, + typename BufferWriterType, typename... Args, typename std::enable_if< IsOptionalWrapper<InputUserType>::value>::type* = nullptr> void Serialize(InputUserType&& input, Buffer* buffer, - DataType** output, + BufferWriterType* writer, Args&&... args) { - if (!input) { - *output = nullptr; + if (!input) return; - } - Serialize<MojomType>(*input, buffer, output, std::forward<Args>(args)...); + Serialize<MojomType>(*input, buffer, writer, std::forward<Args>(args)...); } template <typename MojomType, diff --git a/mojo/public/cpp/bindings/lib/serialization_util.h b/mojo/public/cpp/bindings/lib/serialization_util.h index 4820a014ec..a7a99b3bb7 100644 --- a/mojo/public/cpp/bindings/lib/serialization_util.h +++ b/mojo/public/cpp/bindings/lib/serialization_util.h @@ -21,7 +21,7 @@ namespace internal { template <typename T> struct HasIsNullMethod { template <typename U> - static char Test(decltype(U::IsNull)*); + static char Test(decltype(U::IsNull) *); template <typename U> static int Test(...); static const bool value = sizeof(Test<T>(0)) == sizeof(char); @@ -48,7 +48,7 @@ bool CallIsNullIfExists(const UserType& input) { template <typename T> struct HasSetToNullMethod { template <typename U> - static char Test(decltype(U::SetToNull)*); + static char Test(decltype(U::SetToNull) *); template <typename U> static int Test(...); static const bool value = sizeof(Test<T>(0)) == sizeof(char); @@ -80,7 +80,7 @@ bool CallSetToNullIfExists(UserType* output) { template <typename T> struct HasSetUpContextMethod { template <typename U> - static char Test(decltype(U::SetUpContext)*); + static char Test(decltype(U::SetUpContext) *); template <typename U> static int Test(...); static const bool value = sizeof(Test<T>(0)) == sizeof(char); @@ -97,17 +97,7 @@ template <typename Traits> struct CustomContextHelper<Traits, true> { template <typename MaybeConstUserType> static void* SetUp(MaybeConstUserType& input, SerializationContext* context) { - void* custom_context = Traits::SetUpContext(input); - if (!context->custom_contexts) - context->custom_contexts.reset(new std::queue<void*>()); - context->custom_contexts->push(custom_context); - return custom_context; - } - - static void* GetNext(SerializationContext* context) { - void* custom_context = context->custom_contexts->front(); - context->custom_contexts->pop(); - return custom_context; + return Traits::SetUpContext(input); } template <typename MaybeConstUserType> @@ -123,8 +113,6 @@ struct CustomContextHelper<Traits, false> { return nullptr; } - static void* GetNext(SerializationContext* context) { return nullptr; } - template <typename MaybeConstUserType> static void TearDown(MaybeConstUserType& input, void* custom_context) { DCHECK(!custom_context); @@ -148,7 +136,8 @@ ReturnType CallWithContext(ReturnType (*f)(ParamType), template <typename T, typename MaybeConstUserType> struct HasGetBeginMethod { template <typename U> - static char Test(decltype(U::GetBegin(std::declval<MaybeConstUserType&>()))*); + static char Test( + decltype(U::GetBegin(std::declval<MaybeConstUserType&>())) *); template <typename U> static int Test(...); static const bool value = sizeof(Test<T>(0)) == sizeof(char); @@ -179,7 +168,7 @@ size_t CallGetBeginIfExists(MaybeConstUserType& input) { template <typename T, typename MaybeConstUserType> struct HasGetDataMethod { template <typename U> - static char Test(decltype(U::GetData(std::declval<MaybeConstUserType&>()))*); + static char Test(decltype(U::GetData(std::declval<MaybeConstUserType&>())) *); template <typename U> static int Test(...); static const bool value = sizeof(Test<T>(0)) == sizeof(char); diff --git a/mojo/public/cpp/bindings/lib/string_serialization.h b/mojo/public/cpp/bindings/lib/string_serialization.h index 6e0c758576..1fe6b87af7 100644 --- a/mojo/public/cpp/bindings/lib/string_serialization.h +++ b/mojo/public/cpp/bindings/lib/string_serialization.h @@ -22,36 +22,18 @@ struct Serializer<StringDataView, MaybeConstUserType> { using UserType = typename std::remove_const<MaybeConstUserType>::type; using Traits = StringTraits<UserType>; - static size_t PrepareToSerialize(MaybeConstUserType& input, - SerializationContext* context) { - if (CallIsNullIfExists<Traits>(input)) - return 0; - - void* custom_context = CustomContextHelper<Traits>::SetUp(input, context); - return Align(sizeof(String_Data) + - CallWithContext(Traits::GetSize, input, custom_context)); - } - static void Serialize(MaybeConstUserType& input, Buffer* buffer, - String_Data** output, + String_Data::BufferWriter* writer, SerializationContext* context) { - if (CallIsNullIfExists<Traits>(input)) { - *output = nullptr; + if (CallIsNullIfExists<Traits>(input)) return; - } - - void* custom_context = CustomContextHelper<Traits>::GetNext(context); - - String_Data* result = String_Data::New( - CallWithContext(Traits::GetSize, input, custom_context), buffer); - if (result) { - memcpy(result->storage(), - CallWithContext(Traits::GetData, input, custom_context), - CallWithContext(Traits::GetSize, input, custom_context)); - } - *output = result; + void* custom_context = CustomContextHelper<Traits>::SetUp(input, context); + const size_t size = CallWithContext(Traits::GetSize, input, custom_context); + writer->Allocate(size, buffer); + memcpy((*writer)->storage(), + CallWithContext(Traits::GetData, input, custom_context), size); CustomContextHelper<Traits>::TearDown(input, custom_context); } diff --git a/mojo/public/cpp/bindings/lib/string_traits_string16.cc b/mojo/public/cpp/bindings/lib/string_traits_string16.cc deleted file mode 100644 index 95ff6ccf25..0000000000 --- a/mojo/public/cpp/bindings/lib/string_traits_string16.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "mojo/public/cpp/bindings/string_traits_string16.h" - -#include <string> - -#include "base/strings/utf_string_conversions.h" - -namespace mojo { - -// static -void* StringTraits<base::string16>::SetUpContext(const base::string16& input) { - return new std::string(base::UTF16ToUTF8(input)); -} - -// static -void StringTraits<base::string16>::TearDownContext(const base::string16& input, - void* context) { - delete static_cast<std::string*>(context); -} - -// static -size_t StringTraits<base::string16>::GetSize(const base::string16& input, - void* context) { - return static_cast<std::string*>(context)->size(); -} - -// static -const char* StringTraits<base::string16>::GetData(const base::string16& input, - void* context) { - return static_cast<std::string*>(context)->data(); -} - -// static -bool StringTraits<base::string16>::Read(StringDataView input, - base::string16* output) { - return base::UTF8ToUTF16(input.storage(), input.size(), output); -} - -} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/string_traits_wtf.cc b/mojo/public/cpp/bindings/lib/string_traits_wtf.cc index 203f6f5903..71b758c49c 100644 --- a/mojo/public/cpp/bindings/lib/string_traits_wtf.cc +++ b/mojo/public/cpp/bindings/lib/string_traits_wtf.cc @@ -8,7 +8,8 @@ #include "base/logging.h" #include "mojo/public/cpp/bindings/lib/array_internal.h" -#include "third_party/WebKit/Source/wtf/text/StringUTF8Adaptor.h" +#include "mojo/public/cpp/bindings/string_data_view.h" +#include "third_party/blink/renderer/platform/wtf/text/string_utf8_adaptor.h" namespace mojo { namespace { @@ -16,7 +17,7 @@ namespace { struct UTF8AdaptorInfo { explicit UTF8AdaptorInfo(const WTF::String& input) : utf8_adaptor(input) { #if DCHECK_IS_ON() - original_size_in_bytes = input.charactersSizeInBytes(); + original_size_in_bytes = input.CharactersSizeInBytes(); #endif } @@ -34,7 +35,7 @@ UTF8AdaptorInfo* ToAdaptor(const WTF::String& input, void* context) { UTF8AdaptorInfo* adaptor = static_cast<UTF8AdaptorInfo*>(context); #if DCHECK_IS_ON() - DCHECK_EQ(adaptor->original_size_in_bytes, input.charactersSizeInBytes()); + DCHECK_EQ(adaptor->original_size_in_bytes, input.CharactersSizeInBytes()); #endif return adaptor; } @@ -43,7 +44,7 @@ UTF8AdaptorInfo* ToAdaptor(const WTF::String& input, void* context) { // static void StringTraits<WTF::String>::SetToNull(WTF::String* output) { - if (output->isNull()) + if (output->IsNull()) return; WTF::String result; @@ -70,13 +71,13 @@ size_t StringTraits<WTF::String>::GetSize(const WTF::String& input, // static const char* StringTraits<WTF::String>::GetData(const WTF::String& input, void* context) { - return ToAdaptor(input, context)->utf8_adaptor.data(); + return ToAdaptor(input, context)->utf8_adaptor.Data(); } // static bool StringTraits<WTF::String>::Read(StringDataView input, WTF::String* output) { - WTF::String result = WTF::String::fromUTF8(input.storage(), input.size()); + WTF::String result = WTF::String::FromUTF8(input.storage(), input.size()); output->swap(result); return true; } diff --git a/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc b/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc index 585a8f094c..2b359861d7 100644 --- a/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc +++ b/mojo/public/cpp/bindings/lib/sync_call_restrictions.cc @@ -7,85 +7,79 @@ #if ENABLE_SYNC_CALL_RESTRICTIONS #include "base/debug/leak_annotations.h" -#include "base/lazy_instance.h" #include "base/logging.h" -#include "base/threading/thread_local.h" +#include "base/macros.h" +#include "base/no_destructor.h" +#include "base/synchronization/lock.h" +#include "base/threading/sequence_local_storage_slot.h" #include "mojo/public/c/system/core.h" namespace mojo { namespace { -class SyncCallSettings { +class GlobalSyncCallSettings { public: - static SyncCallSettings* current(); + GlobalSyncCallSettings() = default; + ~GlobalSyncCallSettings() = default; - bool allowed() const { - return scoped_allow_count_ > 0 || system_defined_value_; + bool sync_call_allowed_by_default() const { + base::AutoLock lock(lock_); + return sync_call_allowed_by_default_; } - void IncreaseScopedAllowCount() { scoped_allow_count_++; } - void DecreaseScopedAllowCount() { - DCHECK_LT(0u, scoped_allow_count_); - scoped_allow_count_--; + void DisallowSyncCallByDefault() { + base::AutoLock lock(lock_); + sync_call_allowed_by_default_ = false; } private: - SyncCallSettings(); - ~SyncCallSettings(); + mutable base::Lock lock_; + bool sync_call_allowed_by_default_ = true; - bool system_defined_value_ = true; - size_t scoped_allow_count_ = 0; + DISALLOW_COPY_AND_ASSIGN(GlobalSyncCallSettings); }; -base::LazyInstance<base::ThreadLocalPointer<SyncCallSettings>>::DestructorAtExit - g_sync_call_settings = LAZY_INSTANCE_INITIALIZER; - -// static -SyncCallSettings* SyncCallSettings::current() { - SyncCallSettings* result = g_sync_call_settings.Pointer()->Get(); - if (!result) { - result = new SyncCallSettings(); - ANNOTATE_LEAKING_OBJECT_PTR(result); - DCHECK_EQ(result, g_sync_call_settings.Pointer()->Get()); - } - return result; -} - -SyncCallSettings::SyncCallSettings() { - MojoResult result = MojoGetProperty(MOJO_PROPERTY_TYPE_SYNC_CALL_ALLOWED, - &system_defined_value_); - DCHECK_EQ(MOJO_RESULT_OK, result); - - DCHECK(!g_sync_call_settings.Pointer()->Get()); - g_sync_call_settings.Pointer()->Set(this); +GlobalSyncCallSettings& GetGlobalSettings() { + static base::NoDestructor<GlobalSyncCallSettings> global_settings; + return *global_settings; } -SyncCallSettings::~SyncCallSettings() { - g_sync_call_settings.Pointer()->Set(nullptr); +size_t& GetSequenceLocalScopedAllowCount() { + static base::NoDestructor<base::SequenceLocalStorageSlot<size_t>> count; + return count->Get(); } } // namespace // static void SyncCallRestrictions::AssertSyncCallAllowed() { - if (!SyncCallSettings::current()->allowed()) { - LOG(FATAL) << "Mojo sync calls are not allowed in this process because " - << "they can lead to jank and deadlock. If you must make an " - << "exception, please see " - << "SyncCallRestrictions::ScopedAllowSyncCall and consult " - << "mojo/OWNERS."; - } + if (GetGlobalSettings().sync_call_allowed_by_default()) + return; + if (GetSequenceLocalScopedAllowCount() > 0) + return; + + LOG(FATAL) << "Mojo sync calls are not allowed in this process because " + << "they can lead to jank and deadlock. If you must make an " + << "exception, please see " + << "SyncCallRestrictions::ScopedAllowSyncCall and consult " + << "mojo/OWNERS."; +} + +// static +void SyncCallRestrictions::DisallowSyncCall() { + GetGlobalSettings().DisallowSyncCallByDefault(); } // static void SyncCallRestrictions::IncreaseScopedAllowCount() { - SyncCallSettings::current()->IncreaseScopedAllowCount(); + ++GetSequenceLocalScopedAllowCount(); } // static void SyncCallRestrictions::DecreaseScopedAllowCount() { - SyncCallSettings::current()->DecreaseScopedAllowCount(); + DCHECK_GT(GetSequenceLocalScopedAllowCount(), 0u); + --GetSequenceLocalScopedAllowCount(); } } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/sync_event_watcher.cc b/mojo/public/cpp/bindings/lib/sync_event_watcher.cc index b1c97e3691..17165912fc 100644 --- a/mojo/public/cpp/bindings/lib/sync_event_watcher.cc +++ b/mojo/public/cpp/bindings/lib/sync_event_watcher.cc @@ -4,6 +4,9 @@ #include "mojo/public/cpp/bindings/sync_event_watcher.h" +#include <algorithm> + +#include "base/containers/stack_container.h" #include "base/logging.h" namespace mojo { @@ -16,19 +19,20 @@ SyncEventWatcher::SyncEventWatcher(base::WaitableEvent* event, destroyed_(new base::RefCountedData<bool>(false)) {} SyncEventWatcher::~SyncEventWatcher() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (registered_) - registry_->UnregisterEvent(event_); + registry_->UnregisterEvent(event_, callback_); destroyed_->data = true; } void SyncEventWatcher::AllowWokenUpBySyncWatchOnSameThread() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); IncrementRegisterCount(); } -bool SyncEventWatcher::SyncWatch(const bool* should_stop) { - DCHECK(thread_checker_.CalledOnValidThread()); +bool SyncEventWatcher::SyncWatch(const bool** stop_flags, + size_t num_stop_flags) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); IncrementRegisterCount(); if (!registered_) { DecrementRegisterCount(); @@ -38,8 +42,14 @@ bool SyncEventWatcher::SyncWatch(const bool* should_stop) { // This object may be destroyed during the Wait() call. So we have to preserve // the boolean that Wait uses. auto destroyed = destroyed_; - const bool* should_stop_array[] = {should_stop, &destroyed->data}; - bool result = registry_->Wait(should_stop_array, 2); + + constexpr size_t kFlagStackCapacity = 4; + base::StackVector<const bool*, kFlagStackCapacity> should_stop_array; + should_stop_array.container().push_back(&destroyed->data); + std::copy(stop_flags, stop_flags + num_stop_flags, + std::back_inserter(should_stop_array.container())); + bool result = registry_->Wait(should_stop_array.container().data(), + should_stop_array.container().size()); // This object has been destroyed. if (destroyed->data) @@ -51,15 +61,17 @@ bool SyncEventWatcher::SyncWatch(const bool* should_stop) { void SyncEventWatcher::IncrementRegisterCount() { register_request_count_++; - if (!registered_) - registered_ = registry_->RegisterEvent(event_, callback_); + if (!registered_) { + registry_->RegisterEvent(event_, callback_); + registered_ = true; + } } void SyncEventWatcher::DecrementRegisterCount() { DCHECK_GT(register_request_count_, 0u); register_request_count_--; if (register_request_count_ == 0 && registered_) { - registry_->UnregisterEvent(event_); + registry_->UnregisterEvent(event_, callback_); registered_ = false; } } diff --git a/mojo/public/cpp/bindings/lib/sync_handle_registry.cc b/mojo/public/cpp/bindings/lib/sync_handle_registry.cc index fd3df396ec..2ac4833445 100644 --- a/mojo/public/cpp/bindings/lib/sync_handle_registry.cc +++ b/mojo/public/cpp/bindings/lib/sync_handle_registry.cc @@ -4,27 +4,36 @@ #include "mojo/public/cpp/bindings/sync_handle_registry.h" +#include <algorithm> + #include "base/lazy_instance.h" #include "base/logging.h" #include "base/stl_util.h" -#include "base/threading/thread_local.h" +#include "base/threading/sequence_local_storage_slot.h" +#include "base/threading/sequenced_task_runner_handle.h" #include "mojo/public/c/system/core.h" namespace mojo { namespace { -base::LazyInstance<base::ThreadLocalPointer<SyncHandleRegistry>>::Leaky +base::LazyInstance< + base::SequenceLocalStorageSlot<scoped_refptr<SyncHandleRegistry>>>::Leaky g_current_sync_handle_watcher = LAZY_INSTANCE_INITIALIZER; } // namespace // static scoped_refptr<SyncHandleRegistry> SyncHandleRegistry::current() { - scoped_refptr<SyncHandleRegistry> result( - g_current_sync_handle_watcher.Pointer()->Get()); + // SyncMessageFilter can be used on threads without sequence-local storage + // being available. Those receive a unique, standalone SyncHandleRegistry. + if (!base::SequencedTaskRunnerHandle::IsSet()) + return new SyncHandleRegistry(); + + scoped_refptr<SyncHandleRegistry> result = + g_current_sync_handle_watcher.Get().Get(); if (!result) { result = new SyncHandleRegistry(); - DCHECK_EQ(result.get(), g_current_sync_handle_watcher.Pointer()->Get()); + g_current_sync_handle_watcher.Get().Set(result); } return result; } @@ -32,7 +41,7 @@ scoped_refptr<SyncHandleRegistry> SyncHandleRegistry::current() { bool SyncHandleRegistry::RegisterHandle(const Handle& handle, MojoHandleSignals handle_signals, const HandleCallback& callback) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (base::ContainsKey(handles_, handle)) return false; @@ -46,7 +55,7 @@ bool SyncHandleRegistry::RegisterHandle(const Handle& handle, } void SyncHandleRegistry::UnregisterHandle(const Handle& handle) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!base::ContainsKey(handles_, handle)) return; @@ -55,27 +64,63 @@ void SyncHandleRegistry::UnregisterHandle(const Handle& handle) { handles_.erase(handle); } -bool SyncHandleRegistry::RegisterEvent(base::WaitableEvent* event, +void SyncHandleRegistry::RegisterEvent(base::WaitableEvent* event, const base::Closure& callback) { - auto result = events_.insert({event, callback}); - DCHECK(result.second); - MojoResult rv = wait_set_.AddEvent(event); - if (rv == MOJO_RESULT_OK) - return true; - DCHECK_EQ(MOJO_RESULT_ALREADY_EXISTS, rv); - return false; + auto it = events_.find(event); + if (it == events_.end()) { + auto result = events_.emplace(event, EventCallbackList{}); + it = result.first; + } + + // The event may already be in the WaitSet, but we don't care. This will be a + // no-op in that case, which is more efficient than scanning the list of + // callbacks to see if any are valid. + wait_set_.AddEvent(event); + + it->second.container().push_back(callback); } -void SyncHandleRegistry::UnregisterEvent(base::WaitableEvent* event) { +void SyncHandleRegistry::UnregisterEvent(base::WaitableEvent* event, + const base::Closure& callback) { auto it = events_.find(event); - DCHECK(it != events_.end()); - events_.erase(it); - MojoResult rv = wait_set_.RemoveEvent(event); - DCHECK_EQ(MOJO_RESULT_OK, rv); + if (it == events_.end()) + return; + + bool has_valid_callbacks = false; + auto& callbacks = it->second.container(); + if (is_dispatching_event_callbacks_) { + // Not safe to remove any elements from |callbacks| here since an outer + // stack frame is currently iterating over it in Wait(). + for (auto& cb : callbacks) { + if (cb.Equals(callback)) + cb.Reset(); + else if (cb) + has_valid_callbacks = true; + } + remove_invalid_event_callbacks_after_dispatch_ = true; + } else { + callbacks.erase(std::remove_if(callbacks.begin(), callbacks.end(), + [&callback](const base::Closure& cb) { + return cb.Equals(callback); + }), + callbacks.end()); + if (callbacks.empty()) + events_.erase(it); + else + has_valid_callbacks = true; + } + + if (!has_valid_callbacks) { + // Regardless of whether or not we're nested within a Wait(), we need to + // ensure that |event| is removed from the WaitSet before returning if this + // was the last callback registered for it. + MojoResult rv = wait_set_.RemoveEvent(event); + DCHECK_EQ(MOJO_RESULT_OK, rv); + } } bool SyncHandleRegistry::Wait(const bool* should_stop[], size_t count) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); size_t num_ready_handles; Handle ready_handle; @@ -83,9 +128,10 @@ bool SyncHandleRegistry::Wait(const bool* should_stop[], size_t count) { scoped_refptr<SyncHandleRegistry> preserver(this); while (true) { - for (size_t i = 0; i < count; ++i) + for (size_t i = 0; i < count; ++i) { if (*should_stop[i]) return true; + } // TODO(yzshen): Theoretically it can reduce sync call re-entrancy if we // give priority to the handle that is waiting for sync response. @@ -102,34 +148,51 @@ bool SyncHandleRegistry::Wait(const bool* should_stop[], size_t count) { if (ready_event) { const auto iter = events_.find(ready_event); DCHECK(iter != events_.end()); - iter->second.Run(); + bool was_dispatching_event_callbacks = is_dispatching_event_callbacks_; + is_dispatching_event_callbacks_ = true; + + // NOTE: It's possible for the container to be extended by any of these + // callbacks if they call RegisterEvent, so we are careful to iterate by + // index. Also note that conversely, elements cannot be *removed* from the + // container, by any of these callbacks, so it is safe to assume the size + // only stays the same or increases, with no elements changing position. + auto& callbacks = iter->second.container(); + for (size_t i = 0; i < callbacks.size(); ++i) { + auto& callback = callbacks[i]; + if (callback) + callback.Run(); + } + + is_dispatching_event_callbacks_ = was_dispatching_event_callbacks; + if (!was_dispatching_event_callbacks && + remove_invalid_event_callbacks_after_dispatch_) { + // If we've had events unregistered within any callback dispatch, now is + // a good time to prune them from the map. + RemoveInvalidEventCallbacks(); + remove_invalid_event_callbacks_after_dispatch_ = false; + } } }; return false; } -SyncHandleRegistry::SyncHandleRegistry() { - DCHECK(!g_current_sync_handle_watcher.Pointer()->Get()); - g_current_sync_handle_watcher.Pointer()->Set(this); -} - -SyncHandleRegistry::~SyncHandleRegistry() { - DCHECK(thread_checker_.CalledOnValidThread()); - - // This object may be destructed after the thread local storage slot used by - // |g_current_sync_handle_watcher| is reset during thread shutdown. - // For example, another slot in the thread local storage holds a referrence to - // this object, and that slot is cleaned up after - // |g_current_sync_handle_watcher|. - if (!g_current_sync_handle_watcher.Pointer()->Get()) - return; - - // If this breaks, it is likely that the global variable is bulit into and - // accessed from multiple modules. - DCHECK_EQ(this, g_current_sync_handle_watcher.Pointer()->Get()); - - g_current_sync_handle_watcher.Pointer()->Set(nullptr); +SyncHandleRegistry::SyncHandleRegistry() = default; + +SyncHandleRegistry::~SyncHandleRegistry() = default; + +void SyncHandleRegistry::RemoveInvalidEventCallbacks() { + for (auto it = events_.begin(); it != events_.end();) { + auto& callbacks = it->second.container(); + callbacks.erase( + std::remove_if(callbacks.begin(), callbacks.end(), + [](const base::Closure& callback) { return !callback; }), + callbacks.end()); + if (callbacks.empty()) + events_.erase(it++); + else + ++it; + } } } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc b/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc index f20af56b20..294b8a1a4b 100644 --- a/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc +++ b/mojo/public/cpp/bindings/lib/sync_handle_watcher.cc @@ -21,7 +21,7 @@ SyncHandleWatcher::SyncHandleWatcher( destroyed_(new base::RefCountedData<bool>(false)) {} SyncHandleWatcher::~SyncHandleWatcher() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (registered_) registry_->UnregisterHandle(handle_); @@ -29,12 +29,12 @@ SyncHandleWatcher::~SyncHandleWatcher() { } void SyncHandleWatcher::AllowWokenUpBySyncWatchOnSameThread() { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); IncrementRegisterCount(); } bool SyncHandleWatcher::SyncWatch(const bool* should_stop) { - DCHECK(thread_checker_.CalledOnValidThread()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); IncrementRegisterCount(); if (!registered_) { DecrementRegisterCount(); diff --git a/mojo/public/cpp/bindings/lib/task_runner_helper.cc b/mojo/public/cpp/bindings/lib/task_runner_helper.cc new file mode 100644 index 0000000000..6104a9740e --- /dev/null +++ b/mojo/public/cpp/bindings/lib/task_runner_helper.cc @@ -0,0 +1,24 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/task_runner_helper.h" + +#include "base/sequenced_task_runner.h" +#include "base/threading/sequenced_task_runner_handle.h" + +namespace mojo { +namespace internal { + +scoped_refptr<base::SequencedTaskRunner> +GetTaskRunnerToUseFromUserProvidedTaskRunner( + scoped_refptr<base::SequencedTaskRunner> runner) { + if (runner) { + DCHECK(runner->RunsTasksInCurrentSequence()); + return runner; + } + return base::SequencedTaskRunnerHandle::Get(); +} + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/task_runner_helper.h b/mojo/public/cpp/bindings/lib/task_runner_helper.h new file mode 100644 index 0000000000..d34d179675 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/task_runner_helper.h @@ -0,0 +1,28 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_TASK_RUNNER_HELPER_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_TASK_RUNNER_HELPER_H_ + +#include "base/memory/ref_counted.h" + +namespace base { +class SequencedTaskRunner; +} // namespace base + +namespace mojo { +namespace internal { + +// Returns the SequencedTaskRunner to use from the optional user-provided +// SequencedTaskRunner. If |runner| is provided non-null, it is returned. +// Otherwise, SequencedTaskRunnerHandle::Get() is returned. If |runner| is non- +// null, it must run tasks on the current sequence. +scoped_refptr<base::SequencedTaskRunner> +GetTaskRunnerToUseFromUserProvidedTaskRunner( + scoped_refptr<base::SequencedTaskRunner> runner); + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_TASK_RUNNER_HELPER_H_ diff --git a/mojo/public/cpp/bindings/lib/template_util.h b/mojo/public/cpp/bindings/lib/template_util.h index 5151123ac0..383eb91593 100644 --- a/mojo/public/cpp/bindings/lib/template_util.h +++ b/mojo/public/cpp/bindings/lib/template_util.h @@ -114,6 +114,11 @@ struct Conditional<false, T, F> { typedef F type; }; +template <typename T> +struct AlwaysFalse { + static const bool value = false; +}; + } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/union_accessor.h b/mojo/public/cpp/bindings/lib/union_accessor.h deleted file mode 100644 index 821aede595..0000000000 --- a/mojo/public/cpp/bindings/lib/union_accessor.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2015 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_UNION_ACCESSOR_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_LIB_UNION_ACCESSOR_H_ - -namespace mojo { -namespace internal { - -// When serializing and deserializing Unions, it is necessary to access -// the private fields and methods of the Union. This allows us to do that -// without leaking those same fields and methods in the Union interface. -// All Union wrappers are friends of this class allowing such access. -template <typename U> -class UnionAccessor { - public: - explicit UnionAccessor(U* u) : u_(u) {} - - typename U::Union_* data() { return &(u_->data_); } - - typename U::Tag* tag() { return &(u_->tag_); } - - void SwitchActive(typename U::Tag new_tag) { u_->SwitchActive(new_tag); } - - private: - U* u_; -}; - -} // namespace internal -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_UNION_ACCESSOR_H_ diff --git a/mojo/public/cpp/bindings/lib/unserialized_message_context.cc b/mojo/public/cpp/bindings/lib/unserialized_message_context.cc new file mode 100644 index 0000000000..b029f4ef00 --- /dev/null +++ b/mojo/public/cpp/bindings/lib/unserialized_message_context.cc @@ -0,0 +1,24 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/lib/unserialized_message_context.h" + +namespace mojo { +namespace internal { + +UnserializedMessageContext::UnserializedMessageContext(const Tag* tag, + uint32_t message_name, + uint32_t message_flags) + : tag_(tag) { + header_.interface_id = 0; + header_.version = 1; + header_.name = message_name; + header_.flags = message_flags; + header_.num_bytes = sizeof(header_); +} + +UnserializedMessageContext::~UnserializedMessageContext() = default; + +} // namespace internal +} // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/unserialized_message_context.h b/mojo/public/cpp/bindings/lib/unserialized_message_context.h new file mode 100644 index 0000000000..4886a981dc --- /dev/null +++ b/mojo/public/cpp/bindings/lib/unserialized_message_context.h @@ -0,0 +1,63 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_LIB_UNSERIALIZED_MESSAGE_CONTEXT_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_LIB_UNSERIALIZED_MESSAGE_CONTEXT_H_ + +#include <stdint.h> + +#include "base/component_export.h" +#include "base/macros.h" +#include "base/optional.h" +#include "mojo/public/c/system/types.h" +#include "mojo/public/cpp/bindings/lib/buffer.h" +#include "mojo/public/cpp/bindings/lib/message_internal.h" +#include "mojo/public/cpp/bindings/lib/serialization_context.h" + +namespace mojo { +namespace internal { + +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) UnserializedMessageContext { + public: + struct Tag {}; + + UnserializedMessageContext(const Tag* tag, + uint32_t message_name, + uint32_t message_flags); + virtual ~UnserializedMessageContext(); + + template <typename MessageType> + MessageType* SafeCast() { + if (&MessageType::kMessageTag != tag_) + return nullptr; + return static_cast<MessageType*>(this); + } + + const Tag* tag() const { return tag_; } + uint32_t message_name() const { return header_.name; } + uint32_t message_flags() const { return header_.flags; } + + MessageHeaderV1* header() { return &header_; } + + virtual void Serialize(SerializationContext* serialization_context, + Buffer* buffer) = 0; + + private: + // The |tag_| is used for run-time type identification of specific + // unserialized message types, e.g. messages generated by mojom bindings. This + // allows opaque message objects to be safely downcast once pulled off a pipe. + const Tag* const tag_; + + // We store message metadata in a serialized header structure to simplify + // Message implementation which needs to query such metadata for both + // serialized and unserialized message objects. + MessageHeaderV1 header_; + + DISALLOW_COPY_AND_ASSIGN(UnserializedMessageContext); +}; + +} // namespace internal +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_UNSERIALIZED_MESSAGE_CONTEXT_H_ diff --git a/mojo/public/cpp/bindings/lib/validation_context.h b/mojo/public/cpp/bindings/lib/validation_context.h index ed6c6542e7..7c4de47327 100644 --- a/mojo/public/cpp/bindings/lib/validation_context.h +++ b/mojo/public/cpp/bindings/lib/validation_context.h @@ -9,9 +9,9 @@ #include <stdint.h> #include "base/compiler_specific.h" +#include "base/component_export.h" #include "base/macros.h" #include "base/strings/string_piece.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" static const int kMaxRecursionDepth = 100; @@ -24,7 +24,7 @@ namespace internal { // ValidationContext is used when validating object sizes, pointers and handle // indices in the payload of incoming messages. -class MOJO_CPP_BINDINGS_EXPORT ValidationContext { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) ValidationContext { public: // [data, data + data_num_bytes) specifies the initial valid memory range. // [0, num_handles) specifies the initial valid range of handle indices. diff --git a/mojo/public/cpp/bindings/lib/validation_errors.h b/mojo/public/cpp/bindings/lib/validation_errors.h index 122418d9e3..e48e37c6b6 100644 --- a/mojo/public/cpp/bindings/lib/validation_errors.h +++ b/mojo/public/cpp/bindings/lib/validation_errors.h @@ -6,9 +6,9 @@ #define MOJO_PUBLIC_CPP_BINDINGS_LIB_VALIDATION_ERRORS_H_ #include "base/callback.h" +#include "base/component_export.h" #include "base/logging.h" #include "base/macros.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/lib/validation_context.h" namespace mojo { @@ -76,23 +76,24 @@ enum ValidationError { VALIDATION_ERROR_MAX_RECURSION_DEPTH, }; -MOJO_CPP_BINDINGS_EXPORT const char* ValidationErrorToString( - ValidationError error); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +const char* ValidationErrorToString(ValidationError error); -MOJO_CPP_BINDINGS_EXPORT void ReportValidationError( - ValidationContext* context, - ValidationError error, - const char* description = nullptr); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +void ReportValidationError(ValidationContext* context, + ValidationError error, + const char* description = nullptr); -MOJO_CPP_BINDINGS_EXPORT void ReportValidationErrorForMessage( - mojo::Message* message, - ValidationError error, - const char* description = nullptr); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +void ReportValidationErrorForMessage(mojo::Message* message, + ValidationError error, + const char* description = nullptr); // This class may be used by tests to suppress validation error logging. This is // not thread-safe and must only be instantiated on the main thread with no // other threads using Mojo bindings at the time of construction or destruction. -class MOJO_CPP_BINDINGS_EXPORT ScopedSuppressValidationErrorLoggingForTests { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) + ScopedSuppressValidationErrorLoggingForTests { public: ScopedSuppressValidationErrorLoggingForTests(); ~ScopedSuppressValidationErrorLoggingForTests(); @@ -105,7 +106,8 @@ class MOJO_CPP_BINDINGS_EXPORT ScopedSuppressValidationErrorLoggingForTests { // Only used by validation tests and when there is only one thread doing message // validation. -class MOJO_CPP_BINDINGS_EXPORT ValidationErrorObserverForTesting { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) + ValidationErrorObserverForTesting { public: explicit ValidationErrorObserverForTesting(const base::Closure& callback); ~ValidationErrorObserverForTesting(); @@ -127,11 +129,13 @@ class MOJO_CPP_BINDINGS_EXPORT ValidationErrorObserverForTesting { // // The function returns true if the error is recorded (by a // SerializationWarningObserverForTesting object), false otherwise. -MOJO_CPP_BINDINGS_EXPORT bool ReportSerializationWarning(ValidationError error); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ReportSerializationWarning(ValidationError error); // Only used by serialization tests and when there is only one thread doing // message serialization. -class MOJO_CPP_BINDINGS_EXPORT SerializationWarningObserverForTesting { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) + SerializationWarningObserverForTesting { public: SerializationWarningObserverForTesting(); ~SerializationWarningObserverForTesting(); diff --git a/mojo/public/cpp/bindings/lib/validation_util.cc b/mojo/public/cpp/bindings/lib/validation_util.cc index 7614df5cbc..4b414c4e3b 100644 --- a/mojo/public/cpp/bindings/lib/validation_util.cc +++ b/mojo/public/cpp/bindings/lib/validation_util.cc @@ -8,14 +8,25 @@ #include <limits> +#include "base/strings/stringprintf.h" #include "mojo/public/cpp/bindings/lib/message_internal.h" #include "mojo/public/cpp/bindings/lib/serialization_util.h" #include "mojo/public/cpp/bindings/lib/validation_errors.h" -#include "mojo/public/interfaces/bindings/interface_control_messages.mojom.h" namespace mojo { namespace internal { +void ReportNonNullableValidationError(ValidationContext* validation_context, + ValidationError error, + int field_index) { + const char* null_or_invalid = + error == VALIDATION_ERROR_UNEXPECTED_NULL_POINTER ? "null" : "invalid"; + + std::string error_message = + base::StringPrintf("%s field %d", null_or_invalid, field_index); + ReportValidationError(validation_context, error, error_message.c_str()); +} + bool ValidateStructHeaderAndClaimMemory(const void* data, ValidationContext* validation_context) { if (!IsAligned(data)) { @@ -118,53 +129,53 @@ bool IsHandleOrInterfaceValid(const Handle_Data& input) { bool ValidateHandleOrInterfaceNonNullable( const AssociatedInterface_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (IsHandleOrInterfaceValid(input)) return true; - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, - error_message); + ReportNonNullableValidationError( + validation_context, VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, + field_index); return false; } bool ValidateHandleOrInterfaceNonNullable( const AssociatedEndpointHandle_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (IsHandleOrInterfaceValid(input)) return true; - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, - error_message); + ReportNonNullableValidationError( + validation_context, VALIDATION_ERROR_UNEXPECTED_INVALID_INTERFACE_ID, + field_index); return false; } bool ValidateHandleOrInterfaceNonNullable( const Interface_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (IsHandleOrInterfaceValid(input)) return true; - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, - error_message); + ReportNonNullableValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, + field_index); return false; } bool ValidateHandleOrInterfaceNonNullable( const Handle_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (IsHandleOrInterfaceValid(input)) return true; - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, - error_message); + ReportNonNullableValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_INVALID_HANDLE, + field_index); return false; } diff --git a/mojo/public/cpp/bindings/lib/validation_util.h b/mojo/public/cpp/bindings/lib/validation_util.h index ea5a991668..3b88956f7a 100644 --- a/mojo/public/cpp/bindings/lib/validation_util.h +++ b/mojo/public/cpp/bindings/lib/validation_util.h @@ -7,7 +7,7 @@ #include <stdint.h> -#include "mojo/public/cpp/bindings/bindings_export.h" +#include "base/component_export.h" #include "mojo/public/cpp/bindings/lib/bindings_internal.h" #include "mojo/public/cpp/bindings/lib/serialization_util.h" #include "mojo/public/cpp/bindings/lib/validate_params.h" @@ -18,6 +18,12 @@ namespace mojo { namespace internal { +// Calls ReportValidationError() with a constructed error string. +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +void ReportNonNullableValidationError(ValidationContext* validation_context, + ValidationError error, + int field_index); + // Checks whether decoding the pointer will overflow and produce a pointer // smaller than |offset|. inline bool ValidateEncodedPointer(const uint64_t* offset) { @@ -47,32 +53,35 @@ bool ValidatePointer(const Pointer<T>& input, // |validation_context|. On success, the memory range is marked as occupied. // Note: Does not verify |version| or that |num_bytes| is correct for the // claimed version. -MOJO_CPP_BINDINGS_EXPORT bool ValidateStructHeaderAndClaimMemory( - const void* data, - ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateStructHeaderAndClaimMemory(const void* data, + ValidationContext* validation_context); // Validates that |data| contains a valid union header, in terms of alignment // and size. It checks that the memory range [data, data + kUnionDataSize) is // not marked as occupied by other objects in |validation_context|. On success, // the memory range is marked as occupied. -MOJO_CPP_BINDINGS_EXPORT bool ValidateNonInlinedUnionHeaderAndClaimMemory( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateNonInlinedUnionHeaderAndClaimMemory( const void* data, ValidationContext* validation_context); // Validates that the message is a request which doesn't expect a response. -MOJO_CPP_BINDINGS_EXPORT bool ValidateMessageIsRequestWithoutResponse( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateMessageIsRequestWithoutResponse( const Message* message, ValidationContext* validation_context); // Validates that the message is a request expecting a response. -MOJO_CPP_BINDINGS_EXPORT bool ValidateMessageIsRequestExpectingResponse( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateMessageIsRequestExpectingResponse( const Message* message, ValidationContext* validation_context); // Validates that the message is a response. -MOJO_CPP_BINDINGS_EXPORT bool ValidateMessageIsResponse( - const Message* message, - ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateMessageIsResponse(const Message* message, + ValidationContext* validation_context); // Validates that the message payload is a valid struct of type ParamsType. template <typename ParamsType> @@ -85,54 +94,56 @@ bool ValidateMessagePayload(const Message* message, // |input| is not null/invalid. template <typename T> bool ValidatePointerNonNullable(const T& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (input.offset) return true; - - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, - error_message); + ReportNonNullableValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + field_index); return false; } template <typename T> bool ValidateInlinedUnionNonNullable(const T& input, - const char* error_message, + int field_index, ValidationContext* validation_context) { if (!input.is_null()) return true; - - ReportValidationError(validation_context, - VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, - error_message); + ReportNonNullableValidationError(validation_context, + VALIDATION_ERROR_UNEXPECTED_NULL_POINTER, + field_index); return false; } -MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( - const AssociatedInterface_Data& input); -MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( - const AssociatedEndpointHandle_Data& input); -MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( - const Interface_Data& input); -MOJO_CPP_BINDINGS_EXPORT bool IsHandleOrInterfaceValid( - const Handle_Data& input); - -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool IsHandleOrInterfaceValid(const AssociatedInterface_Data& input); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool IsHandleOrInterfaceValid(const AssociatedEndpointHandle_Data& input); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool IsHandleOrInterfaceValid(const Interface_Data& input); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool IsHandleOrInterfaceValid(const Handle_Data& input); + +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterfaceNonNullable( const AssociatedInterface_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterfaceNonNullable( const AssociatedEndpointHandle_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterfaceNonNullable( const Interface_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterfaceNonNullable( +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterfaceNonNullable( const Handle_Data& input, - const char* error_message, + int field_index, ValidationContext* validation_context); template <typename T> @@ -187,18 +198,18 @@ bool ValidateNonInlinedUnion(const Pointer<T>& input, T::Validate(input.Get(), validation_context, false); } -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( - const AssociatedInterface_Data& input, - ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( - const AssociatedEndpointHandle_Data& input, - ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( - const Interface_Data& input, - ValidationContext* validation_context); -MOJO_CPP_BINDINGS_EXPORT bool ValidateHandleOrInterface( - const Handle_Data& input, - ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterface(const AssociatedInterface_Data& input, + ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterface(const AssociatedEndpointHandle_Data& input, + ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterface(const Interface_Data& input, + ValidationContext* validation_context); +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) +bool ValidateHandleOrInterface(const Handle_Data& input, + ValidationContext* validation_context); } // namespace internal } // namespace mojo diff --git a/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h b/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h index cb24bc46ee..bb0ee531f5 100644 --- a/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h +++ b/mojo/public/cpp/bindings/lib/wtf_clone_equals_util.h @@ -8,11 +8,10 @@ #include <type_traits> #include "mojo/public/cpp/bindings/clone_traits.h" -#include "mojo/public/cpp/bindings/lib/equals_traits.h" -#include "third_party/WebKit/Source/wtf/HashMap.h" -#include "third_party/WebKit/Source/wtf/Optional.h" -#include "third_party/WebKit/Source/wtf/Vector.h" -#include "third_party/WebKit/Source/wtf/text/WTFString.h" +#include "mojo/public/cpp/bindings/equals_traits.h" +#include "third_party/blink/renderer/platform/wtf/hash_map.h" +#include "third_party/blink/renderer/platform/wtf/text/wtf_string.h" +#include "third_party/blink/renderer/platform/wtf/vector.h" namespace mojo { @@ -20,7 +19,7 @@ template <typename T> struct CloneTraits<WTF::Vector<T>, false> { static WTF::Vector<T> Clone(const WTF::Vector<T>& input) { WTF::Vector<T> result; - result.reserveCapacity(input.size()); + result.ReserveCapacity(input.size()); for (const auto& element : input) result.push_back(mojo::Clone(element)); @@ -32,22 +31,20 @@ template <typename K, typename V> struct CloneTraits<WTF::HashMap<K, V>, false> { static WTF::HashMap<K, V> Clone(const WTF::HashMap<K, V>& input) { WTF::HashMap<K, V> result; - auto input_end = input.end(); - for (auto it = input.begin(); it != input_end; ++it) - result.add(mojo::Clone(it->key), mojo::Clone(it->value)); + for (const auto& element : input) + result.insert(mojo::Clone(element.key), mojo::Clone(element.value)); + return result; } }; -namespace internal { - template <typename T> struct EqualsTraits<WTF::Vector<T>, false> { static bool Equals(const WTF::Vector<T>& a, const WTF::Vector<T>& b) { if (a.size() != b.size()) return false; for (size_t i = 0; i < a.size(); ++i) { - if (!internal::Equals(a[i], b[i])) + if (!mojo::Equals(a[i], b[i])) return false; } return true; @@ -65,14 +62,13 @@ struct EqualsTraits<WTF::HashMap<K, V>, false> { for (auto iter = a.begin(); iter != a_end; ++iter) { auto b_iter = b.find(iter->key); - if (b_iter == b_end || !internal::Equals(iter->value, b_iter->value)) + if (b_iter == b_end || !mojo::Equals(iter->value, b_iter->value)) return false; } return true; } }; -} // namespace internal } // namespace mojo #endif // MOJO_PUBLIC_CPP_BINDINGS_LIB_WTF_CLONE_EQUALS_UTIL_H_ diff --git a/mojo/public/cpp/bindings/lib/wtf_hash_util.h b/mojo/public/cpp/bindings/lib/wtf_hash_util.h index cc590da67a..fa02262e8e 100644 --- a/mojo/public/cpp/bindings/lib/wtf_hash_util.h +++ b/mojo/public/cpp/bindings/lib/wtf_hash_util.h @@ -9,9 +9,9 @@ #include "mojo/public/cpp/bindings/lib/hash_util.h" #include "mojo/public/cpp/bindings/struct_ptr.h" -#include "third_party/WebKit/Source/wtf/HashFunctions.h" -#include "third_party/WebKit/Source/wtf/text/StringHash.h" -#include "third_party/WebKit/Source/wtf/text/WTFString.h" +#include "third_party/blink/renderer/platform/wtf/hash_functions.h" +#include "third_party/blink/renderer/platform/wtf/text/string_hash.h" +#include "third_party/blink/renderer/platform/wtf/text/wtf_string.h" namespace mojo { namespace internal { @@ -48,7 +48,7 @@ struct WTFHashTraits<T, false> { template <> struct WTFHashTraits<WTF::String, false> { static size_t Hash(size_t seed, const WTF::String& value) { - return HashCombine(seed, WTF::StringHash::hash(value)); + return HashCombine(seed, WTF::StringHash::GetHash(value)); } }; @@ -59,25 +59,25 @@ size_t WTFHash(size_t seed, const T& value) { template <typename T> struct StructPtrHashFn { - static unsigned hash(const StructPtr<T>& value) { + static unsigned GetHash(const StructPtr<T>& value) { return value.Hash(kHashSeed); } - static bool equal(const StructPtr<T>& left, const StructPtr<T>& right) { + static bool Equal(const StructPtr<T>& left, const StructPtr<T>& right) { return left.Equals(right); } - static const bool safeToCompareToEmptyOrDeleted = false; + static const bool safe_to_compare_to_empty_or_deleted = false; }; template <typename T> struct InlinedStructPtrHashFn { - static unsigned hash(const InlinedStructPtr<T>& value) { + static unsigned GetHash(const InlinedStructPtr<T>& value) { return value.Hash(kHashSeed); } - static bool equal(const InlinedStructPtr<T>& left, + static bool Equal(const InlinedStructPtr<T>& left, const InlinedStructPtr<T>& right) { return left.Equals(right); } - static const bool safeToCompareToEmptyOrDeleted = false; + static const bool safe_to_compare_to_empty_or_deleted = false; }; } // namespace internal @@ -93,14 +93,14 @@ struct DefaultHash<mojo::StructPtr<T>> { template <typename T> struct HashTraits<mojo::StructPtr<T>> : public GenericHashTraits<mojo::StructPtr<T>> { - static const bool hasIsEmptyValueFunction = true; - static bool isEmptyValue(const mojo::StructPtr<T>& value) { + static const bool kHasIsEmptyValueFunction = true; + static bool IsEmptyValue(const mojo::StructPtr<T>& value) { return value.is_null(); } - static void constructDeletedValue(mojo::StructPtr<T>& slot, bool) { + static void ConstructDeletedValue(mojo::StructPtr<T>& slot, bool) { mojo::internal::StructPtrWTFHelper<T>::ConstructDeletedValue(slot); } - static bool isDeletedValue(const mojo::StructPtr<T>& value) { + static bool IsDeletedValue(const mojo::StructPtr<T>& value) { return mojo::internal::StructPtrWTFHelper<T>::IsHashTableDeletedValue( value); } @@ -114,14 +114,14 @@ struct DefaultHash<mojo::InlinedStructPtr<T>> { template <typename T> struct HashTraits<mojo::InlinedStructPtr<T>> : public GenericHashTraits<mojo::InlinedStructPtr<T>> { - static const bool hasIsEmptyValueFunction = true; - static bool isEmptyValue(const mojo::InlinedStructPtr<T>& value) { + static const bool kHasIsEmptyValueFunction = true; + static bool IsEmptyValue(const mojo::InlinedStructPtr<T>& value) { return value.is_null(); } - static void constructDeletedValue(mojo::InlinedStructPtr<T>& slot, bool) { + static void ConstructDeletedValue(mojo::InlinedStructPtr<T>& slot, bool) { mojo::internal::InlinedStructPtrWTFHelper<T>::ConstructDeletedValue(slot); } - static bool isDeletedValue(const mojo::InlinedStructPtr<T>& value) { + static bool IsDeletedValue(const mojo::InlinedStructPtr<T>& value) { return mojo::internal::InlinedStructPtrWTFHelper< T>::IsHashTableDeletedValue(value); } diff --git a/mojo/public/cpp/bindings/map.h b/mojo/public/cpp/bindings/map.h index c1ba0756a3..350bfad76a 100644 --- a/mojo/public/cpp/bindings/map.h +++ b/mojo/public/cpp/bindings/map.h @@ -6,33 +6,32 @@ #define MOJO_PUBLIC_CPP_BINDINGS_MAP_H_ #include <map> -#include <unordered_map> #include <utility> +#include "base/containers/flat_map.h" + namespace mojo { // TODO(yzshen): These conversion functions should be removed and callsites // should be revisited and changed to use the same map type. template <typename Key, typename Value> -std::unordered_map<Key, Value> MapToUnorderedMap( - const std::map<Key, Value>& input) { - return std::unordered_map<Key, Value>(input.begin(), input.end()); +base::flat_map<Key, Value> MapToFlatMap(const std::map<Key, Value>& input) { + return base::flat_map<Key, Value>(input.begin(), input.end()); } template <typename Key, typename Value> -std::unordered_map<Key, Value> MapToUnorderedMap(std::map<Key, Value>&& input) { - return std::unordered_map<Key, Value>(std::make_move_iterator(input.begin()), - std::make_move_iterator(input.end())); +base::flat_map<Key, Value> MapToFlatMap(std::map<Key, Value>&& input) { + return base::flat_map<Key, Value>(std::make_move_iterator(input.begin()), + std::make_move_iterator(input.end())); } template <typename Key, typename Value> -std::map<Key, Value> UnorderedMapToMap( - const std::unordered_map<Key, Value>& input) { +std::map<Key, Value> FlatMapToMap(const base::flat_map<Key, Value>& input) { return std::map<Key, Value>(input.begin(), input.end()); } template <typename Key, typename Value> -std::map<Key, Value> UnorderedMapToMap(std::unordered_map<Key, Value>&& input) { +std::map<Key, Value> FlatMapToMap(base::flat_map<Key, Value>&& input) { return std::map<Key, Value>(std::make_move_iterator(input.begin()), std::make_move_iterator(input.end())); } diff --git a/mojo/public/cpp/bindings/map_traits.h b/mojo/public/cpp/bindings/map_traits.h index 5c0d8b2846..60bcb59255 100644 --- a/mojo/public/cpp/bindings/map_traits.h +++ b/mojo/public/cpp/bindings/map_traits.h @@ -5,6 +5,8 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_MAP_TRAITS_H_ #define MOJO_PUBLIC_CPP_BINDINGS_MAP_TRAITS_H_ +#include "mojo/public/cpp/bindings/lib/template_util.h" + namespace mojo { // This must be specialized for any type |T| to be serialized/deserialized as @@ -19,6 +21,8 @@ namespace mojo { // using Value = V; // // // These two methods are optional. Please see comments in struct_traits.h +// // Note that unlike with StructTraits, IsNull() is called *twice* during +// // serialization for MapTraits. // static bool IsNull(const CustomMap<K, V>& input); // static void SetToNull(CustomMap<K, V>* output); // @@ -49,7 +53,11 @@ namespace mojo { // }; // template <typename T> -struct MapTraits; +struct MapTraits { + static_assert(internal::AlwaysFalse<T>::value, + "Cannot find the mojo::MapTraits specialization. Did you " + "forget to include the corresponding header file?"); +}; } // namespace mojo diff --git a/mojo/public/cpp/bindings/map_traits_flat_map.h b/mojo/public/cpp/bindings/map_traits_flat_map.h new file mode 100644 index 0000000000..9efbabea14 --- /dev/null +++ b/mojo/public/cpp/bindings/map_traits_flat_map.h @@ -0,0 +1,56 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_MAP_TRAITS_FLAT_MAP_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_MAP_TRAITS_FLAT_MAP_H_ + +#include "base/containers/flat_map.h" +#include "mojo/public/cpp/bindings/map_traits.h" + +namespace mojo { + +template <typename K, typename V, typename Compare> +struct MapTraits<base::flat_map<K, V, Compare>> { + using Key = K; + using Value = V; + using Iterator = typename base::flat_map<K, V, Compare>::iterator; + using ConstIterator = typename base::flat_map<K, V, Compare>::const_iterator; + + static size_t GetSize(const base::flat_map<K, V, Compare>& input) { + return input.size(); + } + + static ConstIterator GetBegin(const base::flat_map<K, V, Compare>& input) { + return input.begin(); + } + static Iterator GetBegin(base::flat_map<K, V, Compare>& input) { + return input.begin(); + } + + static void AdvanceIterator(ConstIterator& iterator) { iterator++; } + static void AdvanceIterator(Iterator& iterator) { iterator++; } + + static const K& GetKey(Iterator& iterator) { return iterator->first; } + static const K& GetKey(ConstIterator& iterator) { return iterator->first; } + + static V& GetValue(Iterator& iterator) { return iterator->second; } + static const V& GetValue(ConstIterator& iterator) { return iterator->second; } + + template <typename MaybeConstKeyType, typename MaybeConstValueType> + static bool Insert(base::flat_map<K, V, Compare>& input, + MaybeConstKeyType&& key, + MaybeConstValueType&& value) { + input.emplace(std::forward<MaybeConstKeyType>(key), + std::forward<MaybeConstValueType>(value)); + return true; + } + + static void SetToEmpty(base::flat_map<K, V, Compare>* output) { + output->clear(); + } +}; + +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_MAP_TRAITS_FLAT_MAP_H_ diff --git a/mojo/public/cpp/bindings/map_traits_stl.h b/mojo/public/cpp/bindings/map_traits_stl.h index 83a4399ce0..1c4a13e554 100644 --- a/mojo/public/cpp/bindings/map_traits_stl.h +++ b/mojo/public/cpp/bindings/map_traits_stl.h @@ -12,29 +12,33 @@ namespace mojo { -template <typename K, typename V> -struct MapTraits<std::map<K, V>> { +template <typename K, typename V, typename Compare> +struct MapTraits<std::map<K, V, Compare>> { using Key = K; using Value = V; - using Iterator = typename std::map<K, V>::iterator; - using ConstIterator = typename std::map<K, V>::const_iterator; + using Iterator = typename std::map<K, V, Compare>::iterator; + using ConstIterator = typename std::map<K, V, Compare>::const_iterator; - static bool IsNull(const std::map<K, V>& input) { + static bool IsNull(const std::map<K, V, Compare>& input) { // std::map<> is always converted to non-null mojom map. return false; } - static void SetToNull(std::map<K, V>* output) { + static void SetToNull(std::map<K, V, Compare>* output) { // std::map<> doesn't support null state. Set it to empty instead. output->clear(); } - static size_t GetSize(const std::map<K, V>& input) { return input.size(); } + static size_t GetSize(const std::map<K, V, Compare>& input) { + return input.size(); + } - static ConstIterator GetBegin(const std::map<K, V>& input) { + static ConstIterator GetBegin(const std::map<K, V, Compare>& input) { + return input.begin(); + } + static Iterator GetBegin(std::map<K, V, Compare>& input) { return input.begin(); } - static Iterator GetBegin(std::map<K, V>& input) { return input.begin(); } static void AdvanceIterator(ConstIterator& iterator) { iterator++; } static void AdvanceIterator(Iterator& iterator) { iterator++; } @@ -45,16 +49,18 @@ struct MapTraits<std::map<K, V>> { static V& GetValue(Iterator& iterator) { return iterator->second; } static const V& GetValue(ConstIterator& iterator) { return iterator->second; } - static bool Insert(std::map<K, V>& input, const K& key, V&& value) { + static bool Insert(std::map<K, V, Compare>& input, const K& key, V&& value) { input.insert(std::make_pair(key, std::forward<V>(value))); return true; } - static bool Insert(std::map<K, V>& input, const K& key, const V& value) { + static bool Insert(std::map<K, V, Compare>& input, + const K& key, + const V& value) { input.insert(std::make_pair(key, value)); return true; } - static void SetToEmpty(std::map<K, V>* output) { output->clear(); } + static void SetToEmpty(std::map<K, V, Compare>* output) { output->clear(); } }; template <typename K, typename V> diff --git a/mojo/public/cpp/bindings/map_traits_wtf_hash_map.h b/mojo/public/cpp/bindings/map_traits_wtf_hash_map.h index dd68b3686a..32deab7aae 100644 --- a/mojo/public/cpp/bindings/map_traits_wtf_hash_map.h +++ b/mojo/public/cpp/bindings/map_traits_wtf_hash_map.h @@ -7,7 +7,7 @@ #include "base/logging.h" #include "mojo/public/cpp/bindings/map_traits.h" -#include "third_party/WebKit/Source/wtf/HashMap.h" +#include "third_party/blink/renderer/platform/wtf/hash_map.h" namespace mojo { @@ -48,7 +48,7 @@ struct MapTraits<WTF::HashMap<K, V>> { template <typename IK, typename IV> static bool Insert(WTF::HashMap<K, V>& input, IK&& key, IV&& value) { - if (!WTF::HashMap<K, V>::isValidKey(key)) { + if (!WTF::HashMap<K, V>::IsValidKey(key)) { LOG(ERROR) << "The key value is disallowed by WTF::HashMap"; return false; } diff --git a/mojo/public/cpp/bindings/message.h b/mojo/public/cpp/bindings/message.h index 48e6900306..7f6e3ea436 100644 --- a/mojo/public/cpp/bindings/message.h +++ b/mojo/public/cpp/bindings/message.h @@ -15,10 +15,12 @@ #include "base/callback.h" #include "base/compiler_specific.h" +#include "base/component_export.h" #include "base/logging.h" -#include "mojo/public/cpp/bindings/bindings_export.h" -#include "mojo/public/cpp/bindings/lib/message_buffer.h" +#include "base/memory/ptr_util.h" +#include "mojo/public/cpp/bindings/lib/buffer.h" #include "mojo/public/cpp/bindings/lib/message_internal.h" +#include "mojo/public/cpp/bindings/lib/unserialized_message_context.h" #include "mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h" #include "mojo/public/cpp/system/message.h" @@ -26,23 +28,58 @@ namespace mojo { class AssociatedGroupController; -using ReportBadMessageCallback = base::Callback<void(const std::string& error)>; +using ReportBadMessageCallback = + base::OnceCallback<void(const std::string& error)>; // Message is a holder for the data and handles to be sent over a MessagePipe. // Message owns its data and handles, but a consumer of Message is free to // mutate the data and handles. The message's data is comprised of a header // followed by payload. -class MOJO_CPP_BINDINGS_EXPORT Message { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) Message { public: static const uint32_t kFlagExpectsResponse = 1 << 0; static const uint32_t kFlagIsResponse = 1 << 1; static const uint32_t kFlagIsSync = 1 << 2; + // Constructs an uninitialized Message object. Message(); + + // See the move-assignment operator below. Message(Message&& other); + // Constructs a new message with an unserialized context attached. This + // message may be serialized later if necessary. + explicit Message( + std::unique_ptr<internal::UnserializedMessageContext> context); + + // Constructs a new serialized Message object with optional handles attached. + // This message is fully functional and may be exchanged for a + // ScopedMessageHandle for transit over a message pipe. See TakeMojoMessage(). + // + // If |handles| is non-null, any handles in |*handles| are attached to the + // newly constructed message. + // + // Note that |payload_size| is only the initially known size of the message + // payload, if any. The payload can be expanded after construction using the + // interface returned by |payload_buffer()|. + Message(uint32_t name, + uint32_t flags, + size_t payload_size, + size_t payload_interface_id_count, + std::vector<ScopedHandle>* handles); + + // Constructs a new serialized Message object from an existing + // ScopedMessageHandle; e.g., one read from a message pipe. + // + // If the message had any handles attached, they will be extracted and + // retrievable via |handles()|. Such messages may NOT be sent back over + // another message pipe, but are otherwise safe to inspect and pass around. + Message(ScopedMessageHandle handle); + ~Message(); + // Moves |other| into a new Message object. The moved-from Message becomes + // invalid and is effectively in a default-constructed state after this call. Message& operator=(Message&& other); // Resets the Message to an uninitialized state. Upon reset, the Message @@ -51,51 +88,47 @@ class MOJO_CPP_BINDINGS_EXPORT Message { void Reset(); // Indicates whether this Message is uninitialized. - bool IsNull() const { return !buffer_; } - - // Initializes a Message with enough space for |capacity| bytes. - void Initialize(size_t capacity, bool zero_initialized); + bool IsNull() const { return !handle_.is_valid(); } - // Initializes a Message from an existing Mojo MessageHandle. - void InitializeFromMojoMessage(ScopedMessageHandle message, - uint32_t num_bytes, - std::vector<Handle>* handles); - - uint32_t data_num_bytes() const { - return static_cast<uint32_t>(buffer_->size()); - } + // Indicates whether this Message is serialized. + bool is_serialized() const { return serialized_; } // Access the raw bytes of the message. const uint8_t* data() const { - return static_cast<const uint8_t*>(buffer_->data()); + DCHECK(payload_buffer_.is_valid()); + return static_cast<const uint8_t*>(payload_buffer_.data()); } + uint8_t* mutable_data() { return const_cast<uint8_t*>(data()); } - uint8_t* mutable_data() { return static_cast<uint8_t*>(buffer_->data()); } + size_t data_num_bytes() const { + DCHECK(payload_buffer_.is_valid()); + return payload_buffer_.cursor(); + } // Access the header. const internal::MessageHeader* header() const { - return static_cast<const internal::MessageHeader*>(buffer_->data()); + return reinterpret_cast<const internal::MessageHeader*>(data()); } internal::MessageHeader* header() { - return static_cast<internal::MessageHeader*>(buffer_->data()); + return reinterpret_cast<internal::MessageHeader*>(mutable_data()); } const internal::MessageHeaderV1* header_v1() const { DCHECK_GE(version(), 1u); - return static_cast<const internal::MessageHeaderV1*>(buffer_->data()); + return reinterpret_cast<const internal::MessageHeaderV1*>(data()); } internal::MessageHeaderV1* header_v1() { DCHECK_GE(version(), 1u); - return static_cast<internal::MessageHeaderV1*>(buffer_->data()); + return reinterpret_cast<internal::MessageHeaderV1*>(mutable_data()); } const internal::MessageHeaderV2* header_v2() const { DCHECK_GE(version(), 2u); - return static_cast<const internal::MessageHeaderV2*>(buffer_->data()); + return reinterpret_cast<const internal::MessageHeaderV2*>(data()); } internal::MessageHeaderV2* header_v2() { DCHECK_GE(version(), 2u); - return static_cast<internal::MessageHeaderV2*>(buffer_->data()); + return reinterpret_cast<internal::MessageHeaderV2*>(mutable_data()); } uint32_t version() const { return header()->version; } @@ -120,9 +153,12 @@ class MOJO_CPP_BINDINGS_EXPORT Message { uint32_t payload_num_interface_ids() const; const uint32_t* payload_interface_ids() const; - // Access the handles. - const std::vector<Handle>* handles() const { return &handles_; } - std::vector<Handle>* mutable_handles() { return &handles_; } + internal::Buffer* payload_buffer() { return &payload_buffer_; } + + // Access the handles of a received message. Note that these are unused on + // outgoing messages. + const std::vector<ScopedHandle>* handles() const { return &handles_; } + std::vector<ScopedHandle>* mutable_handles() { return &handles_; } const std::vector<ScopedInterfaceEndpointHandle>* associated_endpoint_handles() const { @@ -133,8 +169,10 @@ class MOJO_CPP_BINDINGS_EXPORT Message { return &associated_endpoint_handles_; } - // Access the underlying Buffer interface. - internal::Buffer* buffer() { return buffer_.get(); } + // Takes ownership of any handles within |*context| and attaches them to this + // Message. + void AttachHandlesFromSerializationContext( + internal::SerializationContext* context); // Takes a scoped MessageHandle which may be passed to |WriteMessageNew()| for // transmission. Note that this invalidates this Message object, taking @@ -155,20 +193,68 @@ class MOJO_CPP_BINDINGS_EXPORT Message { bool DeserializeAssociatedEndpointHandles( AssociatedGroupController* group_controller); + // If this Message has an unserialized message context attached, force it to + // be serialized immediately. Otherwise this does nothing. + void SerializeIfNecessary(); + + // Takes the unserialized message context from this Message if its tag matches + // |tag|. + std::unique_ptr<internal::UnserializedMessageContext> TakeUnserializedContext( + const internal::UnserializedMessageContext::Tag* tag); + + template <typename MessageType> + std::unique_ptr<MessageType> TakeUnserializedContext() { + auto generic_context = TakeUnserializedContext(&MessageType::kMessageTag); + if (!generic_context) + return nullptr; + return base::WrapUnique( + generic_context.release()->template SafeCast<MessageType>()); + } + +#if defined(ENABLE_IPC_FUZZER) + const char* interface_name() const { return interface_name_; } + void set_interface_name(const char* interface_name) { + interface_name_ = interface_name; + } + + const char* method_name() const { return method_name_; } + void set_method_name(const char* method_name) { method_name_ = method_name; } +#endif + private: - void CloseHandles(); + ScopedMessageHandle handle_; + + // A Buffer which may be used to allocate blocks of data within the message + // payload for reading or writing. + internal::Buffer payload_buffer_; - std::unique_ptr<internal::MessageBuffer> buffer_; - std::vector<Handle> handles_; + std::vector<ScopedHandle> handles_; std::vector<ScopedInterfaceEndpointHandle> associated_endpoint_handles_; + // Indicates whether this Message object is transferable, i.e. can be sent + // elsewhere. In general this is true unless |handle_| is invalid or + // serialized handles have been extracted from the serialized message object + // identified by |handle_|. + bool transferable_ = false; + + // Indicates whether this Message object is serialized. + bool serialized_ = false; + +#if defined(ENABLE_IPC_FUZZER) + const char* interface_name_ = nullptr; + const char* method_name_ = nullptr; +#endif + DISALLOW_COPY_AND_ASSIGN(Message); }; -class MessageReceiver { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) MessageReceiver { public: virtual ~MessageReceiver() {} + // Indicates whether the receiver prefers to receive serialized messages. + virtual bool PrefersSerializedMessages(); + // The receiver may mutate the given message. Returns true if the message // was accepted and false otherwise, indicating that the message was invalid // or malformed. @@ -197,12 +283,13 @@ class MessageReceiverWithStatus : public MessageReceiver { // Returns |true| if this MessageReceiver is currently bound to a MessagePipe, // the pipe has not been closed, and the pipe has not encountered an error. - virtual bool IsValid() = 0; + virtual bool IsConnected() = 0; - // DCHECKs if this MessageReceiver is currently bound to a MessagePipe, the - // pipe has not been closed, and the pipe has not encountered an error. - // This function may be called on any thread. - virtual void DCheckInvalid(const std::string& message) = 0; + // Determines if this MessageReceiver is still bound to a message pipe and has + // not encountered any errors. This is asynchronous but may be called from any + // sequence. |callback| is eventually invoked from an arbitrary sequence with + // the result of the query. + virtual void IsConnectedAsync(base::OnceCallback<void(bool)> callback) = 0; }; // An alternative to MessageReceiverWithResponder for cases in which it @@ -221,8 +308,8 @@ class MessageReceiverWithResponderStatus : public MessageReceiver { responder) WARN_UNUSED_RESULT = 0; }; -class MOJO_CPP_BINDINGS_EXPORT PassThroughFilter - : NON_EXPORTED_BASE(public MessageReceiver) { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) PassThroughFilter + : public MessageReceiver { public: PassThroughFilter(); ~PassThroughFilter() override; @@ -251,7 +338,7 @@ class SyncMessageResponseSetup; // if (response_value.IsBad()) // response_context.ReportBadMessage("Bad response_value!"); // -class MOJO_CPP_BINDINGS_EXPORT SyncMessageResponseContext { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) SyncMessageResponseContext { public: SyncMessageResponseContext(); ~SyncMessageResponseContext(); @@ -260,14 +347,13 @@ class MOJO_CPP_BINDINGS_EXPORT SyncMessageResponseContext { void ReportBadMessage(const std::string& error); - const ReportBadMessageCallback& GetBadMessageCallback(); + ReportBadMessageCallback GetBadMessageCallback(); private: friend class internal::SyncMessageResponseSetup; SyncMessageResponseContext* outer_context_; Message response_; - ReportBadMessageCallback bad_message_callback_; DISALLOW_COPY_AND_ASSIGN(SyncMessageResponseContext); }; @@ -279,14 +365,15 @@ class MOJO_CPP_BINDINGS_EXPORT SyncMessageResponseContext { // dispatched, otherwise returns an error code if something went wrong. // // NOTE: The message hasn't been validated and may be malformed! +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) MojoResult ReadMessage(MessagePipeHandle handle, Message* message); // Reports the currently dispatching Message as bad. Note that this is only // legal to call from directly within the stack frame of a message dispatch. If // you need to do asynchronous work before you can determine the legitimacy of -// a message, use TakeBadMessageCallback() and retain its result until you're +// a message, use GetBadMessageCallback() and retain its result until you're // ready to invoke or discard it. -MOJO_CPP_BINDINGS_EXPORT +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) void ReportBadMessage(const std::string& error); // Acquires a callback which may be run to report the currently dispatching @@ -294,7 +381,7 @@ void ReportBadMessage(const std::string& error); // stack frame of a message dispatch, but the returned callback may be called // exactly once any time thereafter to report the message as bad. This may only // be called once per message. -MOJO_CPP_BINDINGS_EXPORT +COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) ReportBadMessageCallback GetBadMessageCallback(); } // namespace mojo diff --git a/mojo/public/cpp/bindings/message_dumper.h b/mojo/public/cpp/bindings/message_dumper.h new file mode 100644 index 0000000000..44cf384ab0 --- /dev/null +++ b/mojo/public/cpp/bindings/message_dumper.h @@ -0,0 +1,43 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_DUMPER_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_DUMPER_H_ + +#include "base/files/file_path.h" +#include "mojo/public/cpp/bindings/message.h" +#include "mojo/public/cpp/bindings/message_dumper.h" + +namespace mojo { + +class MessageDumper : public mojo::MessageReceiver { + public: + MessageDumper(); + ~MessageDumper() override; + + bool Accept(mojo::Message* message) override; + + struct MessageEntry { + MessageEntry(const uint8_t* data, + uint32_t data_size, + const char* interface_name, + const char* method_name); + MessageEntry(const MessageEntry& entry); + ~MessageEntry(); + + const char* interface_name; + const char* method_name; + std::vector<uint8_t> data_bytes; + }; + + static void SetMessageDumpDirectory(const base::FilePath& directory); + static const base::FilePath& GetMessageDumpDirectory(); + + private: + uint32_t identifier_; +}; + +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_DUMPER_H_ diff --git a/mojo/public/cpp/bindings/message_header_validator.h b/mojo/public/cpp/bindings/message_header_validator.h index 50c19dbe04..621d14fdec 100644 --- a/mojo/public/cpp/bindings/message_header_validator.h +++ b/mojo/public/cpp/bindings/message_header_validator.h @@ -6,13 +6,13 @@ #define MOJO_PUBLIC_CPP_BINDINGS_MESSAGE_HEADER_VALIDATOR_H_ #include "base/compiler_specific.h" -#include "mojo/public/cpp/bindings/bindings_export.h" +#include "base/component_export.h" #include "mojo/public/cpp/bindings/message.h" namespace mojo { -class MOJO_CPP_BINDINGS_EXPORT MessageHeaderValidator - : NON_EXPORTED_BASE(public MessageReceiver) { +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) MessageHeaderValidator + : public MessageReceiver { public: MessageHeaderValidator(); explicit MessageHeaderValidator(const std::string& description); diff --git a/mojo/public/cpp/bindings/mojo_buildflags.h b/mojo/public/cpp/bindings/mojo_buildflags.h new file mode 100644 index 0000000000..fb646fc92e --- /dev/null +++ b/mojo/public/cpp/bindings/mojo_buildflags.h @@ -0,0 +1,6 @@ +#ifndef CPP_MOJO_BUILD_FLAGS_H_ +#define CPP_MOJO_BUILD_FLAGS_H_ + +#include <build/buildflag.h> +#define BUILDFLAG_INTERNAL_MOJO_TRACE_ENABLED() (0) +#endif // CPP_MOJO_BUILD_FLAGS_H_ diff --git a/mojo/public/cpp/bindings/native_struct.h b/mojo/public/cpp/bindings/native_struct.h deleted file mode 100644 index ac27250bcc..0000000000 --- a/mojo/public/cpp/bindings/native_struct.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_NATIVE_STRUCT_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_NATIVE_STRUCT_H_ - -#include <vector> - -#include "base/optional.h" -#include "mojo/public/cpp/bindings/bindings_export.h" -#include "mojo/public/cpp/bindings/lib/native_struct_data.h" -#include "mojo/public/cpp/bindings/struct_ptr.h" -#include "mojo/public/cpp/bindings/type_converter.h" - -namespace mojo { - -class NativeStruct; -using NativeStructPtr = StructPtr<NativeStruct>; - -// Native-only structs correspond to "[Native] struct Foo;" definitions in -// mojom. -class MOJO_CPP_BINDINGS_EXPORT NativeStruct { - public: - using Data_ = internal::NativeStruct_Data; - - static NativeStructPtr New(); - - template <typename U> - static NativeStructPtr From(const U& u) { - return TypeConverter<NativeStructPtr, U>::Convert(u); - } - - template <typename U> - U To() const { - return TypeConverter<U, NativeStruct>::Convert(*this); - } - - NativeStruct(); - ~NativeStruct(); - - NativeStructPtr Clone() const; - bool Equals(const NativeStruct& other) const; - size_t Hash(size_t seed) const; - - base::Optional<std::vector<uint8_t>> data; -}; - -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_NATIVE_STRUCT_H_ diff --git a/mojo/public/cpp/bindings/native_struct_data_view.h b/mojo/public/cpp/bindings/native_struct_data_view.h deleted file mode 100644 index 613bd7a0b0..0000000000 --- a/mojo/public/cpp/bindings/native_struct_data_view.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_NATIVE_STRUCT_DATA_VIEW_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_NATIVE_STRUCT_DATA_VIEW_H_ - -#include "mojo/public/cpp/bindings/lib/native_struct_data.h" -#include "mojo/public/cpp/bindings/lib/serialization_context.h" - -namespace mojo { - -class NativeStructDataView { - public: - using Data_ = internal::NativeStruct_Data; - - NativeStructDataView() {} - - NativeStructDataView(Data_* data, internal::SerializationContext* context) - : data_(data) {} - - bool is_null() const { return !data_; } - - size_t size() const { return data_->data.size(); } - - uint8_t operator[](size_t index) const { return data_->data.at(index); } - - const uint8_t* data() const { return data_->data.storage(); } - - private: - Data_* data_ = nullptr; -}; - -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_NATIVE_STRUCT_DATA_VIEW_H_ diff --git a/mojo/public/cpp/bindings/pipe_control_message_handler.h b/mojo/public/cpp/bindings/pipe_control_message_handler.h index a5c04da627..071d9fe46d 100644 --- a/mojo/public/cpp/bindings/pipe_control_message_handler.h +++ b/mojo/public/cpp/bindings/pipe_control_message_handler.h @@ -18,7 +18,7 @@ class PipeControlMessageHandlerDelegate; // Handler for messages defined in pipe_control_messages.mojom. class MOJO_CPP_BINDINGS_EXPORT PipeControlMessageHandler - : NON_EXPORTED_BASE(public MessageReceiver) { + : public MessageReceiver { public: explicit PipeControlMessageHandler( PipeControlMessageHandlerDelegate* delegate); diff --git a/mojo/public/cpp/bindings/pipe_control_message_proxy.h b/mojo/public/cpp/bindings/pipe_control_message_proxy.h index 52c408f827..f57f039a4e 100644 --- a/mojo/public/cpp/bindings/pipe_control_message_proxy.h +++ b/mojo/public/cpp/bindings/pipe_control_message_proxy.h @@ -19,11 +19,11 @@ class MessageReceiver; // Proxy for request messages defined in pipe_control_messages.mojom. // -// NOTE: This object may be used from multiple threads. +// NOTE: This object may be used from multiple sequences. class MOJO_CPP_BINDINGS_EXPORT PipeControlMessageProxy { public: // Doesn't take ownership of |receiver|. If This PipeControlMessageProxy will - // be used from multiple threads, |receiver| must be thread-safe. + // be used from multiple sequences, |receiver| must be thread-safe. explicit PipeControlMessageProxy(MessageReceiver* receiver); void NotifyPeerEndpointClosed(InterfaceId id, diff --git a/mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h b/mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h index 16527cf747..0637d49755 100644 --- a/mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h +++ b/mojo/public/cpp/bindings/scoped_interface_endpoint_handle.h @@ -8,12 +8,12 @@ #include <string> #include "base/callback.h" +#include "base/component_export.h" #include "base/macros.h" #include "base/memory/ref_counted.h" #include "base/optional.h" #include "base/single_thread_task_runner.h" #include "base/threading/thread_task_runner_handle.h" -#include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/disconnect_reason.h" #include "mojo/public/cpp/bindings/interface_id.h" @@ -24,8 +24,8 @@ class AssociatedGroupController; // ScopedInterfaceEndpointHandle refers to one end of an interface, either the // implementation side or the client side. // Threading: At any given time, a ScopedInterfaceEndpointHandle should only -// be accessed from a single thread. -class MOJO_CPP_BINDINGS_EXPORT ScopedInterfaceEndpointHandle { +// be accessed from a single sequence. +class COMPONENT_EXPORT(MOJO_CPP_BINDINGS_BASE) ScopedInterfaceEndpointHandle { public: // Creates a pair of handles representing the two endpoints of an interface, // which are not yet associated with a message pipe. @@ -69,7 +69,7 @@ class MOJO_CPP_BINDINGS_EXPORT ScopedInterfaceEndpointHandle { using AssociationEventCallback = base::OnceCallback<void(AssociationEvent)>; // Note: // - |handler| won't run if the handle is invalid. Otherwise, |handler| is run - // on the calling thread asynchronously, even if the interface has already + // on the calling sequence asynchronously, even if the interface has already // been associated or the peer has been closed before association. // - |handler| won't be called after this object is destroyed or reset. // - A null |handler| can be used to cancel the previous callback. @@ -98,8 +98,8 @@ class MOJO_CPP_BINDINGS_EXPORT ScopedInterfaceEndpointHandle { void ResetInternal(const base::Optional<DisconnectReason>& reason); // Used by AssociatedGroup. - // It is safe to run the returned callback on any thread, or after this handle - // is destroyed. + // It is safe to run the returned callback on any sequence, or after this + // handle is destroyed. // The return value of the getter: // - If the getter is retrieved when the handle is invalid, the return value // of the getter will always be null. diff --git a/mojo/public/cpp/bindings/sequence_local_sync_event_watcher.h b/mojo/public/cpp/bindings/sequence_local_sync_event_watcher.h new file mode 100644 index 0000000000..ad50bde436 --- /dev/null +++ b/mojo/public/cpp/bindings/sequence_local_sync_event_watcher.h @@ -0,0 +1,69 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_SEQUENCE_LOCAL_SYNC_EVENT_WATCHER_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_SEQUENCE_LOCAL_SYNC_EVENT_WATCHER_H_ + +#include "base/callback.h" +#include "base/macros.h" +#include "base/memory/weak_ptr.h" +#include "mojo/public/cpp/bindings/bindings_export.h" + +namespace mojo { + +// This encapsulates a SyncEventWatcher watching an event shared by all +// |SequenceLocalSyncEventWatcher| on the same sequence. This class is NOT +// sequence-safe in general, but |SignalEvent()| is safe to call from any +// sequence. +// +// Interfaces which support sync messages use a WaitableEvent to block and +// be signaled when messages are available, but having a WaitableEvent for every +// such interface endpoint would cause the number of WaitableEvents to grow +// arbitrarily large. +// +// Some platform constraints may limit the number of WaitableEvents the bindings +// layer can wait upon concurrently, so this type is used to keep the number +// of such events fixed at a small constant value per sequence regardless of the +// number of active interface endpoints supporting sync messages on that +// sequence. +class MOJO_CPP_BINDINGS_EXPORT SequenceLocalSyncEventWatcher { + public: + explicit SequenceLocalSyncEventWatcher( + const base::RepeatingClosure& callback); + ~SequenceLocalSyncEventWatcher(); + + // Signals the shared event on behalf of this specific watcher. Safe to call + // from any sequence. + void SignalEvent(); + + // Resets the shared event on behalf of this specific watcher. + void ResetEvent(); + + // Allows this watcher to be notified during sync wait operations invoked by + // other watchers (for example, other SequenceLocalSyncEventWatchers calling + // |SyncWatch()|) on the same sequence. + void AllowWokenUpBySyncWatchOnSameSequence(); + + // Blocks the calling sequence until the shared event is signaled on behalf of + // this specific watcher (i.e. until someone calls |SignalEvent()| on |this|). + // Behaves similarly to SyncEventWatcher and SyncHandleWatcher, returning + // |true| when |*should_stop| is set to |true|, or |false| if some other + // (e.g. error) event interrupts the wait. + bool SyncWatch(const bool* should_stop); + + private: + class Registration; + class SequenceLocalState; + friend class SequenceLocalState; + + const std::unique_ptr<Registration> registration_; + const base::RepeatingClosure callback_; + bool can_wake_up_during_any_watch_ = false; + + DISALLOW_COPY_AND_ASSIGN(SequenceLocalSyncEventWatcher); +}; + +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_SEQUENCE_LOCAL_SYNC_EVENT_WATCHER_H_ diff --git a/mojo/public/cpp/bindings/string_traits.h b/mojo/public/cpp/bindings/string_traits.h index 7d3075a579..165c9fa8cd 100644 --- a/mojo/public/cpp/bindings/string_traits.h +++ b/mojo/public/cpp/bindings/string_traits.h @@ -5,7 +5,7 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_STRING_TRAITS_H_ #define MOJO_PUBLIC_CPP_BINDINGS_STRING_TRAITS_H_ -#include "mojo/public/cpp/bindings/string_data_view.h" +#include "mojo/public/cpp/bindings/lib/template_util.h" namespace mojo { @@ -47,7 +47,11 @@ namespace mojo { // so that you can do any necessary cleanup. // template <typename T> -struct StringTraits; +struct StringTraits { + static_assert(internal::AlwaysFalse<T>::value, + "Cannot find the mojo::StringTraits specialization. Did you " + "forget to include the corresponding header file?"); +}; } // namespace mojo diff --git a/mojo/public/cpp/bindings/string_traits_string16.h b/mojo/public/cpp/bindings/string_traits_string16.h deleted file mode 100644 index f96973ad91..0000000000 --- a/mojo/public/cpp/bindings/string_traits_string16.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2016 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef MOJO_PUBLIC_CPP_BINDINGS_STRING_TRAITS_STRING16_H_ -#define MOJO_PUBLIC_CPP_BINDINGS_STRING_TRAITS_STRING16_H_ - -#include "base/strings/string16.h" -#include "mojo/public/cpp/bindings/bindings_export.h" -#include "mojo/public/cpp/bindings/string_traits.h" - -namespace mojo { - -template <> -struct MOJO_CPP_BINDINGS_EXPORT StringTraits<base::string16> { - static bool IsNull(const base::string16& input) { - // base::string16 is always converted to non-null mojom string. - return false; - } - - static void SetToNull(base::string16* output) { - // Convert null to an "empty" base::string16. - output->clear(); - } - - static void* SetUpContext(const base::string16& input); - static void TearDownContext(const base::string16& input, void* context); - - static size_t GetSize(const base::string16& input, void* context); - static const char* GetData(const base::string16& input, void* context); - - static bool Read(StringDataView input, base::string16* output); -}; - -} // namespace mojo - -#endif // MOJO_PUBLIC_CPP_BINDINGS_STRING_TRAITS_STRING16_H_ diff --git a/mojo/public/cpp/bindings/string_traits_wtf.h b/mojo/public/cpp/bindings/string_traits_wtf.h index 238c2eb119..51601d1ab8 100644 --- a/mojo/public/cpp/bindings/string_traits_wtf.h +++ b/mojo/public/cpp/bindings/string_traits_wtf.h @@ -7,13 +7,13 @@ #include "mojo/public/cpp/bindings/lib/bindings_internal.h" #include "mojo/public/cpp/bindings/string_traits.h" -#include "third_party/WebKit/Source/wtf/text/WTFString.h" +#include "third_party/blink/renderer/platform/wtf/text/wtf_string.h" namespace mojo { template <> struct StringTraits<WTF::String> { - static bool IsNull(const WTF::String& input) { return input.isNull(); } + static bool IsNull(const WTF::String& input) { return input.IsNull(); } static void SetToNull(WTF::String* output); static void* SetUpContext(const WTF::String& input); diff --git a/mojo/public/cpp/bindings/strong_associated_binding.h b/mojo/public/cpp/bindings/strong_associated_binding.h index a1e299bb2d..fffdeb4d62 100644 --- a/mojo/public/cpp/bindings/strong_associated_binding.h +++ b/mojo/public/cpp/bindings/strong_associated_binding.h @@ -37,12 +37,15 @@ using StrongAssociatedBindingPtr = // To use, call StrongAssociatedBinding<T>::Create() (see below) or the helper // MakeStrongAssociatedBinding function: // -// mojo::MakeStrongAssociatedBinding(base::MakeUnique<FooImpl>(), +// mojo::MakeStrongAssociatedBinding(std::make_unique<FooImpl>(), // std::move(foo_request)); // template <typename Interface> class StrongAssociatedBinding { public: + using ImplPointerType = + typename AssociatedBinding<Interface>::ImplPointerType; + // Create a new StrongAssociatedBinding instance. The instance owns itself, // cleaning up only in the event of a pipe connection error. Returns a WeakPtr // to the new StrongAssociatedBinding instance. @@ -58,16 +61,16 @@ class StrongAssociatedBinding { // // This method may only be called after this StrongAssociatedBinding has been // bound to a message pipe. - void set_connection_error_handler(const base::Closure& error_handler) { + void set_connection_error_handler(base::OnceClosure error_handler) { DCHECK(binding_.is_bound()); - connection_error_handler_ = error_handler; + connection_error_handler_ = std::move(error_handler); connection_error_with_reason_handler_.Reset(); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { DCHECK(binding_.is_bound()); - connection_error_with_reason_handler_ = error_handler; + connection_error_with_reason_handler_ = std::move(error_handler); connection_error_handler_.Reset(); } @@ -82,6 +85,11 @@ class StrongAssociatedBinding { // stimulus. void FlushForTesting() { binding_.FlushForTesting(); } + // Allows test code to swap the interface implementation. + ImplPointerType SwapImplForTesting(ImplPointerType new_impl) { + return binding_.SwapImplForTesting(new_impl); + } + private: StrongAssociatedBinding(std::unique_ptr<Interface> impl, AssociatedInterfaceRequest<Interface> request) @@ -96,15 +104,17 @@ class StrongAssociatedBinding { void OnConnectionError(uint32_t custom_reason, const std::string& description) { - if (!connection_error_handler_.is_null()) - connection_error_handler_.Run(); - else if (!connection_error_with_reason_handler_.is_null()) - connection_error_with_reason_handler_.Run(custom_reason, description); + if (connection_error_handler_) { + std::move(connection_error_handler_).Run(); + } else if (connection_error_with_reason_handler_) { + std::move(connection_error_with_reason_handler_) + .Run(custom_reason, description); + } Close(); } std::unique_ptr<Interface> impl_; - base::Closure connection_error_handler_; + base::OnceClosure connection_error_handler_; ConnectionErrorWithReasonCallback connection_error_with_reason_handler_; AssociatedBinding<Interface> binding_; base::WeakPtrFactory<StrongAssociatedBinding> weak_factory_; diff --git a/mojo/public/cpp/bindings/strong_associated_binding_set.h b/mojo/public/cpp/bindings/strong_associated_binding_set.h new file mode 100644 index 0000000000..8c769698ba --- /dev/null +++ b/mojo/public/cpp/bindings/strong_associated_binding_set.h @@ -0,0 +1,25 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_STRONG_ASSOCIATED_BINDING_SET_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_STRONG_ASSOCIATED_BINDING_SET_H_ + +#include "mojo/public/cpp/bindings/associated_binding.h" +#include "mojo/public/cpp/bindings/associated_binding_set.h" +#include "mojo/public/cpp/bindings/associated_interface_ptr.h" +#include "mojo/public/cpp/bindings/associated_interface_request.h" +#include "mojo/public/cpp/bindings/binding_set.h" +#include "mojo/public/cpp/bindings/unique_ptr_impl_ref_traits.h" + +namespace mojo { + +template <typename Interface, typename ContextType = void> +using StrongAssociatedBindingSet = BindingSetBase< + Interface, + AssociatedBinding<Interface, UniquePtrImplRefTraits<Interface>>, + ContextType>; + +} // namespace mojo + +#endif // MOJO_PUBLIC_CPP_BINDINGS_STRONG_ASSOCIATED_BINDING_SET_H_ diff --git a/mojo/public/cpp/bindings/strong_binding.h b/mojo/public/cpp/bindings/strong_binding.h index f4b4a061cd..40c67e7834 100644 --- a/mojo/public/cpp/bindings/strong_binding.h +++ b/mojo/public/cpp/bindings/strong_binding.h @@ -38,7 +38,7 @@ using StrongBindingPtr = base::WeakPtr<StrongBinding<Interface>>; // To use, call StrongBinding<T>::Create() (see below) or the helper // MakeStrongBinding function: // -// mojo::MakeStrongBinding(base::MakeUnique<FooImpl>(), +// mojo::MakeStrongBinding(std::make_unique<FooImpl>(), // std::move(foo_request)); // template <typename Interface> @@ -59,19 +59,34 @@ class StrongBinding { // // This method may only be called after this StrongBinding has been bound to a // message pipe. - void set_connection_error_handler(const base::Closure& error_handler) { + void set_connection_error_handler(base::OnceClosure error_handler) { DCHECK(binding_.is_bound()); - connection_error_handler_ = error_handler; + connection_error_handler_ = std::move(error_handler); connection_error_with_reason_handler_.Reset(); } void set_connection_error_with_reason_handler( - const ConnectionErrorWithReasonCallback& error_handler) { + ConnectionErrorWithReasonCallback error_handler) { DCHECK(binding_.is_bound()); - connection_error_with_reason_handler_ = error_handler; + connection_error_with_reason_handler_ = std::move(error_handler); connection_error_handler_.Reset(); } + // Stops processing incoming messages until + // ResumeIncomingMethodCallProcessing(). + // Outgoing messages are still sent. + // + // No errors are detected on the message pipe while paused. + // + // This method may only be called if the object has been bound to a message + // pipe and there are no associated interfaces running. + void PauseIncomingMethodCallProcessing() { + binding_.PauseIncomingMethodCallProcessing(); + } + void ResumeIncomingMethodCallProcessing() { + binding_.ResumeIncomingMethodCallProcessing(); + } + // Forces the binding to close. This destroys the StrongBinding instance. void Close() { delete this; } @@ -97,15 +112,17 @@ class StrongBinding { void OnConnectionError(uint32_t custom_reason, const std::string& description) { - if (!connection_error_handler_.is_null()) - connection_error_handler_.Run(); - else if (!connection_error_with_reason_handler_.is_null()) - connection_error_with_reason_handler_.Run(custom_reason, description); + if (connection_error_handler_) { + std::move(connection_error_handler_).Run(); + } else if (connection_error_with_reason_handler_) { + std::move(connection_error_with_reason_handler_) + .Run(custom_reason, description); + } Close(); } std::unique_ptr<Interface> impl_; - base::Closure connection_error_handler_; + base::OnceClosure connection_error_handler_; ConnectionErrorWithReasonCallback connection_error_with_reason_handler_; Binding<Interface> binding_; base::WeakPtrFactory<StrongBinding> weak_factory_; diff --git a/mojo/public/cpp/bindings/strong_binding_set.h b/mojo/public/cpp/bindings/strong_binding_set.h index f6bcd5259c..3f855f428e 100644 --- a/mojo/public/cpp/bindings/strong_binding_set.h +++ b/mojo/public/cpp/bindings/strong_binding_set.h @@ -15,11 +15,13 @@ namespace mojo { // set, and the interface implementation is deleted. When the StrongBindingSet // is destructed, all outstanding bindings in the set are destroyed and all the // bound interface implementations are automatically deleted. -template <typename Interface, typename ContextType = void> -using StrongBindingSet = - BindingSetBase<Interface, - Binding<Interface, UniquePtrImplRefTraits<Interface>>, - ContextType>; +template <typename Interface, + typename ContextType = void, + typename Deleter = std::default_delete<Interface>> +using StrongBindingSet = BindingSetBase< + Interface, + Binding<Interface, UniquePtrImplRefTraits<Interface, Deleter>>, + ContextType>; } // namespace mojo diff --git a/mojo/public/cpp/bindings/struct_ptr.h b/mojo/public/cpp/bindings/struct_ptr.h index b135312e39..5c88f5a39e 100644 --- a/mojo/public/cpp/bindings/struct_ptr.h +++ b/mojo/public/cpp/bindings/struct_ptr.h @@ -81,7 +81,8 @@ class StructPtr { StructPtr Clone() const { return is_null() ? StructPtr() : ptr_->Clone(); } // Compares the pointees (which might both be null). - // TODO(tibell): Get rid of Equals in favor of the operator. Same for Hash. + // TODO(crbug.com/735302): Get rid of Equals in favor of the operator. Same + // for Hash. bool Equals(const StructPtr& other) const { if (is_null() || other.is_null()) return is_null() && other.is_null(); @@ -97,6 +98,10 @@ class StructPtr { explicit operator bool() const { return !is_null(); } + bool operator<(const StructPtr& other) const { + return Hash(internal::kHashSeed) < other.Hash(internal::kHashSeed); + } + private: friend class internal::StructPtrWTFHelper<Struct>; void Take(StructPtr* other) { @@ -192,6 +197,10 @@ class InlinedStructPtr { explicit operator bool() const { return !is_null(); } + bool operator<(const InlinedStructPtr& other) const { + return Hash(internal::kHashSeed) < other.Hash(internal::kHashSeed); + } + private: friend class internal::InlinedStructPtrWTFHelper<Struct>; void Take(InlinedStructPtr* other) { diff --git a/mojo/public/cpp/bindings/struct_traits.h b/mojo/public/cpp/bindings/struct_traits.h index 6cc070fc48..ff86f4219d 100644 --- a/mojo/public/cpp/bindings/struct_traits.h +++ b/mojo/public/cpp/bindings/struct_traits.h @@ -5,6 +5,8 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_STRUCT_TRAITS_H_ #define MOJO_PUBLIC_CPP_BINDINGS_STRUCT_TRAITS_H_ +#include "mojo/public/cpp/bindings/lib/template_util.h" + namespace mojo { // This must be specialized for any type |T| to be serialized/deserialized as @@ -37,7 +39,7 @@ namespace mojo { // // - map: // Value or reference of any type that has a MapTraits defined. -// Supported by default: std::map, std::unordered_map, +// Supported by default: std::map, std::unordered_map, base::flat_map, // WTF::HashMap (in blink). // // - struct: @@ -47,16 +49,13 @@ namespace mojo { // Value of any type that has an EnumTraits defined. // // For any nullable string/struct/array/map/union field you could also -// return value or reference of base::Optional<T>/WTF::Optional<T>, if T -// has the right *Traits defined. -// -// During serialization, getters for string/struct/array/map/union fields -// are called twice (one for size calculation and one for actual -// serialization). If you want to return a value (as opposed to a -// reference) from these getters, you have to be sure that constructing and -// copying the returned object is really cheap. +// return value or reference of base::Optional<T>, if T has the right +// *Traits defined. // -// Getters for fields of other types are called once. +// During serialization, getters for all fields are called exactly once. It +// is therefore reasonably effecient for a getter to construct and return +// temporary value in the event that it cannot return a readily +// serializable reference to some existing object. // // 2. A static Read() method to set the contents of a |T| instance from a // DataViewType. @@ -76,9 +75,10 @@ namespace mojo { // // static bool IsNull(const T& input); // -// If this method returns true, it is guaranteed that none of the getters -// (described in section 1) will be called for the same |input|. So you -// don't have to check whether |input| is null in those getters. +// This method is called exactly once during serialization, and if it +// returns |true|, it is guaranteed that none of the getters (described in +// section 1) will be called for the same |input|. So you don't have to +// check whether |input| is null in those getters. // // If it is not defined, |T| instances are always considered non-null. // @@ -158,7 +158,11 @@ namespace mojo { // }; // template <typename DataViewType, typename T> -struct StructTraits; +struct StructTraits { + static_assert(internal::AlwaysFalse<T>::value, + "Cannot find the mojo::StructTraits specialization. Did you " + "forget to include the corresponding header file?"); +}; } // namespace mojo diff --git a/mojo/public/cpp/bindings/sync_call_restrictions.h b/mojo/public/cpp/bindings/sync_call_restrictions.h index 5529042784..e72cb1d3cb 100644 --- a/mojo/public/cpp/bindings/sync_call_restrictions.h +++ b/mojo/public/cpp/bindings/sync_call_restrictions.h @@ -15,6 +15,20 @@ #define ENABLE_SYNC_CALL_RESTRICTIONS 0 #endif +class ChromeSelectFileDialogFactory; + +namespace sync_preferences { +class PrefServiceSyncable; +} + +namespace content { +class BlinkTestController; +} + +namespace display { +class ForwardingDisplayDelegate; +} + namespace leveldb { class LevelDBMojoProxy; } @@ -24,14 +38,16 @@ class PersistentPrefStoreClient; } namespace ui { -class Gpu; -} +class ClipboardClient; +class HostContextFactoryPrivate; +} // namespace ui -namespace views { -class ClipboardMus; +namespace viz { +class HostFrameSinkManager; } namespace mojo { +class ScopedAllowSyncCallForTesting; // In some processes, sync calls are disallowed. For example, in the browser // process we don't want any sync calls to child processes for performance, @@ -40,38 +56,62 @@ namespace mojo { // // Before processing a sync call, the bindings call // SyncCallRestrictions::AssertSyncCallAllowed() to check whether sync calls are -// allowed. By default, it is determined by the mojo system property -// MOJO_PROPERTY_SYNC_CALL_ALLOWED. If the default setting says no but you have -// a very compelling reason to disregard that (which should be very very rare), -// you can override it by constructing a ScopedAllowSyncCall object, which -// allows making sync calls on the current thread during its lifetime. +// allowed. By default sync calls are allowed but they may be globally +// disallowed within a process by calling DisallowSyncCall(). +// +// If globally disallowed but you but you have a very compelling reason to +// disregard that (which should be very very rare), you can override it by +// constructing a ScopedAllowSyncCall object which allows making sync calls on +// the current sequence during its lifetime. class MOJO_CPP_BINDINGS_EXPORT SyncCallRestrictions { public: #if ENABLE_SYNC_CALL_RESTRICTIONS - // Checks whether the current thread is allowed to make sync calls, and causes - // a DCHECK if not. + // Checks whether the current sequence is allowed to make sync calls, and + // causes a DCHECK if not. static void AssertSyncCallAllowed(); + + // Disables sync calls within the calling process. Any caller who wishes to + // make sync calls once this has been invoked must do so within the extent of + // a ScopedAllowSyncCall or ScopedAllowSyncCallForTesting. + static void DisallowSyncCall(); + #else // Inline the empty definitions of functions so that they can be compiled out. static void AssertSyncCallAllowed() {} + static void DisallowSyncCall() {} #endif private: // DO NOT ADD ANY OTHER FRIEND STATEMENTS, talk to mojo/OWNERS first. // BEGIN ALLOWED USAGE. - friend class ui::Gpu; // http://crbug.com/620058 + // SynchronousCompositorHost is used for Android webview. + friend class content::SynchronousCompositorHost; // LevelDBMojoProxy makes same-process sync calls from the DB thread. friend class leveldb::LevelDBMojoProxy; // Pref service connection is sync at startup. friend class prefs::PersistentPrefStoreClient; - + // Incognito pref service instances are created synchronously. + friend class sync_preferences::PrefServiceSyncable; + friend class mojo::ScopedAllowSyncCallForTesting; + // For file open and save dialogs created synchronously. + friend class ::ChromeSelectFileDialogFactory; + // For synchronous system clipboard access. + friend class ui::ClipboardClient; + // For destroying the GL context/surface that draw to a platform window before + // the platform window is destroyed. + friend class viz::HostFrameSinkManager; + // Allow for layout test pixel dumps. + friend class content::BlinkTestController; + // For preventing frame swaps of wrong size during resize on Windows. + // (https://crbug.com/811945) + friend class ui::HostContextFactoryPrivate; // END ALLOWED USAGE. // BEGIN USAGE THAT NEEDS TO BE FIXED. - // In the non-mus case, we called blocking OS functions in the ui::Clipboard - // implementation which weren't caught by sync call restrictions. Our blocking - // calls to mus, however, are. - friend class views::ClipboardMus; + // In ash::Shell::Init() it assumes that NativeDisplayDelegate will be + // synchronous at first. In mushrome ForwardingDisplayDelegate uses a + // synchronous call to get the display snapshots as a workaround. + friend class display::ForwardingDisplayDelegate; // END USAGE THAT NEEDS TO BE FIXED. #if ENABLE_SYNC_CALL_RESTRICTIONS @@ -84,7 +124,7 @@ class MOJO_CPP_BINDINGS_EXPORT SyncCallRestrictions { // If a process is configured to disallow sync calls in general, constructing // a ScopedAllowSyncCall object temporarily allows making sync calls on the - // current thread. Doing this is almost always incorrect, which is why we + // current sequence. Doing this is almost always incorrect, which is why we // limit who can use this through friend. If you find yourself needing to use // this, talk to mojo/OWNERS. class ScopedAllowSyncCall { @@ -103,6 +143,17 @@ class MOJO_CPP_BINDINGS_EXPORT SyncCallRestrictions { DISALLOW_IMPLICIT_CONSTRUCTORS(SyncCallRestrictions); }; +class ScopedAllowSyncCallForTesting { + public: + ScopedAllowSyncCallForTesting() {} + ~ScopedAllowSyncCallForTesting() {} + + private: + SyncCallRestrictions::ScopedAllowSyncCall scoped_allow_sync_call_; + + DISALLOW_COPY_AND_ASSIGN(ScopedAllowSyncCallForTesting); +}; + } // namespace mojo #endif // MOJO_PUBLIC_CPP_BINDINGS_SYNC_CALL_RESTRICTIONS_H_ diff --git a/mojo/public/cpp/bindings/sync_event_watcher.h b/mojo/public/cpp/bindings/sync_event_watcher.h index 6e254844e9..9bc9ada594 100644 --- a/mojo/public/cpp/bindings/sync_event_watcher.h +++ b/mojo/public/cpp/bindings/sync_event_watcher.h @@ -10,8 +10,8 @@ #include "base/callback.h" #include "base/macros.h" #include "base/memory/ref_counted.h" +#include "base/sequence_checker.h" #include "base/synchronization/waitable_event.h" -#include "base/threading/thread_checker.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/sync_handle_registry.h" @@ -19,7 +19,7 @@ namespace mojo { // SyncEventWatcher supports waiting on a base::WaitableEvent to signal while // also allowing other SyncEventWatchers and SyncHandleWatchers on the same -// thread to wake up as needed. +// sequence to wake up as needed. // // This class is not thread safe. class MOJO_CPP_BINDINGS_EXPORT SyncEventWatcher { @@ -29,17 +29,22 @@ class MOJO_CPP_BINDINGS_EXPORT SyncEventWatcher { ~SyncEventWatcher(); // Registers |event_| with SyncHandleRegistry, so that when others perform - // sync watching on the same thread, |event_| will be watched along with them. + // sync watching on the same sequence, |event_| will be watched along with + // them. void AllowWokenUpBySyncWatchOnSameThread(); // Waits on |event_| plus all other events and handles registered with this - // thread's SyncHandleRegistry, running callbacks synchronously for any ready - // events and handles. + // sequence's SyncHandleRegistry, running callbacks synchronously for any + // ready events and handles. + // + // |stop_flags| is treated as an array of |const bool*| with |num_stop_flags| + // entries. + // // This method: - // - returns true when |should_stop| is set to true; + // - returns true when any flag in |stop_flags| is set to |true|. // - return false when any error occurs, including this object being // destroyed during a callback. - bool SyncWatch(const bool* should_stop); + bool SyncWatch(const bool** stop_flags, size_t num_stop_flags); private: void IncrementRegisterCount(); @@ -58,7 +63,7 @@ class MOJO_CPP_BINDINGS_EXPORT SyncEventWatcher { scoped_refptr<base::RefCountedData<bool>> destroyed_; - base::ThreadChecker thread_checker_; + SEQUENCE_CHECKER(sequence_checker_); DISALLOW_COPY_AND_ASSIGN(SyncEventWatcher); }; diff --git a/mojo/public/cpp/bindings/sync_handle_registry.h b/mojo/public/cpp/bindings/sync_handle_registry.h index afb3b56bf4..cd535aa69d 100644 --- a/mojo/public/cpp/bindings/sync_handle_registry.h +++ b/mojo/public/cpp/bindings/sync_handle_registry.h @@ -6,30 +6,34 @@ #define MOJO_PUBLIC_CPP_BINDINGS_SYNC_HANDLE_REGISTRY_H_ #include <map> -#include <unordered_map> #include "base/callback.h" +#include "base/containers/stack_container.h" #include "base/macros.h" #include "base/memory/ref_counted.h" +#include "base/sequence_checker.h" #include "base/synchronization/waitable_event.h" -#include "base/threading/thread_checker.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/system/core.h" #include "mojo/public/cpp/system/wait_set.h" namespace mojo { -// SyncHandleRegistry is a thread-local storage to register handles that want to -// be watched together. +// SyncHandleRegistry is a sequence-local storage to register handles that want +// to be watched together. // -// This class is not thread safe. +// This class is thread unsafe. class MOJO_CPP_BINDINGS_EXPORT SyncHandleRegistry : public base::RefCounted<SyncHandleRegistry> { public: - // Returns a thread-local object. + // Returns a sequence-local object. static scoped_refptr<SyncHandleRegistry> current(); using HandleCallback = base::Callback<void(MojoResult)>; + + // Registers a |Handle| to be watched for |handle_signals|. If any such + // signals are satisfied during a Wait(), the Wait() is woken up and + // |callback| is run. bool RegisterHandle(const Handle& handle, MojoHandleSignals handle_signals, const HandleCallback& callback); @@ -38,11 +42,13 @@ class MOJO_CPP_BINDINGS_EXPORT SyncHandleRegistry // Registers a |base::WaitableEvent| which can be used to wake up // Wait() before any handle signals. |event| is not owned, and if it signals - // during Wait(), |callback| is invoked. Returns |true| if registered - // successfully or |false| if |event| was already registered. - bool RegisterEvent(base::WaitableEvent* event, const base::Closure& callback); + // during Wait(), |callback| is invoked. Note that |event| may be registered + // multiple times with different callbacks. + void RegisterEvent(base::WaitableEvent* event, const base::Closure& callback); - void UnregisterEvent(base::WaitableEvent* event); + // Unregisters a specific |event|+|callback| pair. + void UnregisterEvent(base::WaitableEvent* event, + const base::Closure& callback); // Waits on all the registered handles and events and runs callbacks // synchronously for any that become ready. @@ -54,14 +60,28 @@ class MOJO_CPP_BINDINGS_EXPORT SyncHandleRegistry private: friend class base::RefCounted<SyncHandleRegistry>; + using EventCallbackList = base::StackVector<base::Closure, 1>; + using EventMap = std::map<base::WaitableEvent*, EventCallbackList>; + SyncHandleRegistry(); ~SyncHandleRegistry(); + void RemoveInvalidEventCallbacks(); + WaitSet wait_set_; std::map<Handle, HandleCallback> handles_; - std::map<base::WaitableEvent*, base::Closure> events_; + EventMap events_; + + // |true| iff this registry is currently dispatching event callbacks in + // Wait(). Used to allow for safe event registration/unregistration from event + // callbacks. + bool is_dispatching_event_callbacks_ = false; + + // Indicates if one or more event callbacks was unregistered during the most + // recent event callback dispatch. + bool remove_invalid_event_callbacks_after_dispatch_ = false; - base::ThreadChecker thread_checker_; + SEQUENCE_CHECKER(sequence_checker_); DISALLOW_COPY_AND_ASSIGN(SyncHandleRegistry); }; diff --git a/mojo/public/cpp/bindings/sync_handle_watcher.h b/mojo/public/cpp/bindings/sync_handle_watcher.h index eff73dd66e..e680d74aa6 100644 --- a/mojo/public/cpp/bindings/sync_handle_watcher.h +++ b/mojo/public/cpp/bindings/sync_handle_watcher.h @@ -7,7 +7,7 @@ #include "base/macros.h" #include "base/memory/ref_counted.h" -#include "base/threading/thread_checker.h" +#include "base/sequence_checker.h" #include "mojo/public/cpp/bindings/bindings_export.h" #include "mojo/public/cpp/bindings/sync_handle_registry.h" #include "mojo/public/cpp/system/core.h" @@ -15,15 +15,15 @@ namespace mojo { // SyncHandleWatcher supports watching a handle synchronously. It also supports -// registering the handle with a thread-local storage (SyncHandleRegistry), so -// that when other SyncHandleWatcher instances on the same thread perform sync +// registering the handle with a sequence-local storage (SyncHandleRegistry), so +// that when other SyncHandleWatcher instances on the same sequence perform sync // handle watching, this handle will be watched together. // // SyncHandleWatcher is used for sync methods. While a sync call is waiting for -// response, we would like to block the thread. On the other hand, we need -// incoming sync method requests on the same thread to be able to reenter. We +// response, we would like to block the sequence. On the other hand, we need +// incoming sync method requests on the same sequence to be able to reenter. We // also need master interface endpoints to continue dispatching messages for -// associated endpoints on different threads. +// associated endpoints on different sequence. // // This class is not thread safe. class MOJO_CPP_BINDINGS_EXPORT SyncHandleWatcher { @@ -36,7 +36,7 @@ class MOJO_CPP_BINDINGS_EXPORT SyncHandleWatcher { ~SyncHandleWatcher(); // Registers |handle_| with SyncHandleRegistry, so that when others perform - // sync handle watching on the same thread, |handle_| will be watched + // sync handle watching on the same sequence, |handle_| will be watched // together. void AllowWokenUpBySyncWatchOnSameThread(); @@ -65,7 +65,7 @@ class MOJO_CPP_BINDINGS_EXPORT SyncHandleWatcher { scoped_refptr<base::RefCountedData<bool>> destroyed_; - base::ThreadChecker thread_checker_; + SEQUENCE_CHECKER(sequence_checker_); DISALLOW_COPY_AND_ASSIGN(SyncHandleWatcher); }; diff --git a/mojo/public/cpp/bindings/tests/BUILD.gn b/mojo/public/cpp/bindings/tests/BUILD.gn index 668ca6da90..82038a186d 100644 --- a/mojo/public/cpp/bindings/tests/BUILD.gn +++ b/mojo/public/cpp/bindings/tests/BUILD.gn @@ -11,7 +11,10 @@ source_set("tests") { "binding_callback_unittest.cc", "binding_set_unittest.cc", "binding_unittest.cc", + "bindings_test_base.cc", + "bindings_test_base.h", "buffer_unittest.cc", + "callback_helpers_unittest.cc", "connector_unittest.cc", "constant_unittest.cc", "container_test_util.cc", @@ -21,10 +24,12 @@ source_set("tests") { "handle_passing_unittest.cc", "hash_unittest.cc", "interface_ptr_unittest.cc", + "lazy_serialization_unittest.cc", "map_unittest.cc", "message_queue.cc", "message_queue.h", "multiplex_router_unittest.cc", + "native_struct_unittest.cc", "report_bad_message_unittest.cc", "request_response_unittest.cc", "router_test_util.cc", @@ -32,7 +37,9 @@ source_set("tests") { "sample_service_unittest.cc", "serialization_warning_unittest.cc", "struct_unittest.cc", + "sync_handle_registry_unittest.cc", "sync_method_unittest.cc", + "test_helpers_unittest.cc", "type_conversion_unittest.cc", "union_unittest.cc", "validation_context_unittest.cc", @@ -43,7 +50,7 @@ source_set("tests") { deps = [ ":mojo_public_bindings_test_utils", "//base/test:test_support", - "//mojo/edk/system", + "//mojo/core/embedder", "//mojo/public/cpp/bindings", "//mojo/public/cpp/system", "//mojo/public/cpp/test_support:test_utils", @@ -68,7 +75,10 @@ source_set("tests") { "struct_traits_unittest.cc", ] - deps += [ "//mojo/public/interfaces/bindings/tests:test_interfaces_blink" ] + deps += [ + "//mojo/public/cpp/bindings/tests:struct_with_traits_impl", + "//mojo/public/interfaces/bindings/tests:test_interfaces_blink", + ] } } @@ -124,8 +134,8 @@ source_set("perftests") { deps = [ "//base/test:test_support", - "//mojo/edk/system", - "//mojo/edk/test:test_support", + "//mojo/core/embedder", + "//mojo/core/test:test_support", "//mojo/public/cpp/bindings", "//mojo/public/cpp/system", "//mojo/public/cpp/test_support:test_utils", diff --git a/mojo/public/cpp/bindings/tests/associated_interface_unittest.cc b/mojo/public/cpp/bindings/tests/associated_interface_unittest.cc index be225e4761..85cec97ae9 100644 --- a/mojo/public/cpp/bindings/tests/associated_interface_unittest.cc +++ b/mojo/public/cpp/bindings/tests/associated_interface_unittest.cc @@ -11,10 +11,11 @@ #include "base/callback.h" #include "base/callback_helpers.h" #include "base/memory/ptr_util.h" -#include "base/message_loop/message_loop.h" #include "base/run_loop.h" #include "base/single_thread_task_runner.h" #include "base/synchronization/waitable_event.h" +#include "base/task_scheduler/post_task.h" +#include "base/test/scoped_task_environment.h" #include "base/threading/sequenced_task_runner_handle.h" #include "base/threading/thread.h" #include "base/threading/thread_task_runner_handle.h" @@ -95,7 +96,8 @@ class IntegerSenderConnectionImpl : public IntegerSenderConnection { class AssociatedInterfaceTest : public testing::Test { public: - AssociatedInterfaceTest() {} + AssociatedInterfaceTest() + : main_runner_(base::ThreadTaskRunnerHandle::Get()) {} ~AssociatedInterfaceTest() override { base::RunLoop().RunUntilIdle(); } void PumpMessages() { base::RunLoop().RunUntilIdle(); } @@ -117,10 +119,10 @@ class AssociatedInterfaceTest : public testing::Test { MessagePipe pipe; *router0 = new MultiplexRouter(std::move(pipe.handle0), MultiplexRouter::MULTI_INTERFACE, true, - base::ThreadTaskRunnerHandle::Get()); + main_runner_); *router1 = new MultiplexRouter(std::move(pipe.handle1), MultiplexRouter::MULTI_INTERFACE, false, - base::ThreadTaskRunnerHandle::Get()); + main_runner_); } void CreateIntegerSenderWithExistingRouters( @@ -143,10 +145,10 @@ class AssociatedInterfaceTest : public testing::Test { // Okay to call from any thread. void QuitRunLoop(base::RunLoop* run_loop) { - if (loop_.task_runner()->BelongsToCurrentThread()) { + if (main_runner_->RunsTasksInCurrentSequence()) { run_loop->Quit(); } else { - loop_.task_runner()->PostTask( + main_runner_->PostTask( FROM_HERE, base::Bind(&AssociatedInterfaceTest::QuitRunLoop, base::Unretained(this), base::Unretained(run_loop))); @@ -154,7 +156,8 @@ class AssociatedInterfaceTest : public testing::Test { } private: - base::MessageLoop loop_; + base::test::ScopedTaskEnvironment task_environment; + scoped_refptr<base::SequencedTaskRunner> main_runner_; }; void DoSetFlagAndRunClosure(bool* flag, const base::Closure& closure) { @@ -244,17 +247,15 @@ TEST_F(AssociatedInterfaceTest, InterfacesAtBothEnds) { class TestSender { public: TestSender() - : sender_thread_("TestSender"), + : task_runner_(base::CreateSequencedTaskRunnerWithTraits({})), next_sender_(nullptr), - max_value_to_send_(-1) { - sender_thread_.Start(); - } + max_value_to_send_(-1) {} // The following three methods are called on the corresponding sender thread. void SetUp(IntegerSenderAssociatedPtrInfo ptr_info, TestSender* next_sender, int32_t max_value_to_send) { - CHECK(sender_thread_.task_runner()->BelongsToCurrentThread()); + CHECK(task_runner()->RunsTasksInCurrentSequence()); ptr_.Bind(std::move(ptr_info)); next_sender_ = next_sender ? next_sender : this; @@ -262,28 +263,28 @@ class TestSender { } void Send(int32_t value) { - CHECK(sender_thread_.task_runner()->BelongsToCurrentThread()); + CHECK(task_runner()->RunsTasksInCurrentSequence()); if (value > max_value_to_send_) return; ptr_->Send(value); - next_sender_->sender_thread()->task_runner()->PostTask( + next_sender_->task_runner()->PostTask( FROM_HERE, base::Bind(&TestSender::Send, base::Unretained(next_sender_), ++value)); } void TearDown() { - CHECK(sender_thread_.task_runner()->BelongsToCurrentThread()); + CHECK(task_runner()->RunsTasksInCurrentSequence()); ptr_.reset(); } - base::Thread* sender_thread() { return &sender_thread_; } + base::SequencedTaskRunner* task_runner() { return task_runner_.get(); } private: - base::Thread sender_thread_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; TestSender* next_sender_; int32_t max_value_to_send_; @@ -292,15 +293,15 @@ class TestSender { class TestReceiver { public: - TestReceiver() : receiver_thread_("TestReceiver"), expected_calls_(0) { - receiver_thread_.Start(); - } + TestReceiver() + : task_runner_(base::CreateSequencedTaskRunnerWithTraits({})), + expected_calls_(0) {} void SetUp(AssociatedInterfaceRequest<IntegerSender> request0, AssociatedInterfaceRequest<IntegerSender> request1, size_t expected_calls, const base::Closure& notify_finish) { - CHECK(receiver_thread_.task_runner()->BelongsToCurrentThread()); + CHECK(task_runner()->RunsTasksInCurrentSequence()); impl0_.reset(new IntegerSenderImpl(std::move(request0))); impl0_->set_notify_send_method_called( @@ -314,13 +315,13 @@ class TestReceiver { } void TearDown() { - CHECK(receiver_thread_.task_runner()->BelongsToCurrentThread()); + CHECK(task_runner()->RunsTasksInCurrentSequence()); impl0_.reset(); impl1_.reset(); } - base::Thread* receiver_thread() { return &receiver_thread_; } + base::SequencedTaskRunner* task_runner() { return task_runner_.get(); } const std::vector<int32_t>& values() const { return values_; } private: @@ -331,7 +332,7 @@ class TestReceiver { notify_finish_.Run(); } - base::Thread receiver_thread_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; size_t expected_calls_; std::unique_ptr<IntegerSenderImpl> impl0_; @@ -392,7 +393,7 @@ TEST_F(AssociatedInterfaceTest, MultiThreadAccess) { TestSender senders[4]; for (size_t i = 0; i < 4; ++i) { - senders[i].sender_thread()->task_runner()->PostTask( + senders[i].task_runner()->PostTask( FROM_HERE, base::Bind(&TestSender::SetUp, base::Unretained(&senders[i]), base::Passed(&ptr_infos[i]), nullptr, kMaxValue * (i + 1) / 4)); @@ -404,7 +405,7 @@ TEST_F(AssociatedInterfaceTest, MultiThreadAccess) { 2, base::Bind(&AssociatedInterfaceTest::QuitRunLoop, base::Unretained(this), base::Unretained(&run_loop))); for (size_t i = 0; i < 2; ++i) { - receivers[i].receiver_thread()->task_runner()->PostTask( + receivers[i].task_runner()->PostTask( FROM_HERE, base::Bind(&TestReceiver::SetUp, base::Unretained(&receivers[i]), base::Passed(&requests[2 * i]), @@ -415,7 +416,7 @@ TEST_F(AssociatedInterfaceTest, MultiThreadAccess) { } for (size_t i = 0; i < 4; ++i) { - senders[i].sender_thread()->task_runner()->PostTask( + senders[i].task_runner()->PostTask( FROM_HERE, base::Bind(&TestSender::Send, base::Unretained(&senders[i]), kMaxValue * i / 4 + 1)); } @@ -424,7 +425,7 @@ TEST_F(AssociatedInterfaceTest, MultiThreadAccess) { for (size_t i = 0; i < 4; ++i) { base::RunLoop run_loop; - senders[i].sender_thread()->task_runner()->PostTaskAndReply( + senders[i].task_runner()->PostTaskAndReply( FROM_HERE, base::Bind(&TestSender::TearDown, base::Unretained(&senders[i])), base::Bind(&AssociatedInterfaceTest::QuitRunLoop, @@ -434,7 +435,7 @@ TEST_F(AssociatedInterfaceTest, MultiThreadAccess) { for (size_t i = 0; i < 2; ++i) { base::RunLoop run_loop; - receivers[i].receiver_thread()->task_runner()->PostTaskAndReply( + receivers[i].task_runner()->PostTaskAndReply( FROM_HERE, base::Bind(&TestReceiver::TearDown, base::Unretained(&receivers[i])), base::Bind(&AssociatedInterfaceTest::QuitRunLoop, @@ -477,7 +478,7 @@ TEST_F(AssociatedInterfaceTest, FIFO) { TestSender senders[4]; for (size_t i = 0; i < 4; ++i) { - senders[i].sender_thread()->task_runner()->PostTask( + senders[i].task_runner()->PostTask( FROM_HERE, base::Bind(&TestSender::SetUp, base::Unretained(&senders[i]), base::Passed(&ptr_infos[i]), @@ -490,7 +491,7 @@ TEST_F(AssociatedInterfaceTest, FIFO) { 2, base::Bind(&AssociatedInterfaceTest::QuitRunLoop, base::Unretained(this), base::Unretained(&run_loop))); for (size_t i = 0; i < 2; ++i) { - receivers[i].receiver_thread()->task_runner()->PostTask( + receivers[i].task_runner()->PostTask( FROM_HERE, base::Bind(&TestReceiver::SetUp, base::Unretained(&receivers[i]), base::Passed(&requests[2 * i]), @@ -500,7 +501,7 @@ TEST_F(AssociatedInterfaceTest, FIFO) { base::Unretained(&counter)))); } - senders[0].sender_thread()->task_runner()->PostTask( + senders[0].task_runner()->PostTask( FROM_HERE, base::Bind(&TestSender::Send, base::Unretained(&senders[0]), 1)); @@ -508,7 +509,7 @@ TEST_F(AssociatedInterfaceTest, FIFO) { for (size_t i = 0; i < 4; ++i) { base::RunLoop run_loop; - senders[i].sender_thread()->task_runner()->PostTaskAndReply( + senders[i].task_runner()->PostTaskAndReply( FROM_HERE, base::Bind(&TestSender::TearDown, base::Unretained(&senders[i])), base::Bind(&AssociatedInterfaceTest::QuitRunLoop, @@ -518,7 +519,7 @@ TEST_F(AssociatedInterfaceTest, FIFO) { for (size_t i = 0; i < 2; ++i) { base::RunLoop run_loop; - receivers[i].receiver_thread()->task_runner()->PostTaskAndReply( + receivers[i].task_runner()->PostTaskAndReply( FROM_HERE, base::Bind(&TestReceiver::TearDown, base::Unretained(&receivers[i])), base::Bind(&AssociatedInterfaceTest::QuitRunLoop, @@ -669,7 +670,7 @@ class CallbackFilter : public MessageReceiver { ~CallbackFilter() override {} static std::unique_ptr<CallbackFilter> Wrap(const base::Closure& callback) { - return base::MakeUnique<CallbackFilter>(callback); + return std::make_unique<CallbackFilter>(callback); } // MessageReceiver: @@ -760,8 +761,8 @@ TEST_F(AssociatedInterfaceTest, AssociatedPtrFlushForTesting) { ptr0.set_connection_error_handler(base::Bind(&Fail)); bool ptr0_callback_run = false; - ptr0->Echo(123, ExpectValueSetFlagAndRunClosure( - 123, &ptr0_callback_run, base::Bind(&base::DoNothing))); + ptr0->Echo(123, ExpectValueSetFlagAndRunClosure(123, &ptr0_callback_run, + base::DoNothing())); ptr0.FlushForTesting(); EXPECT_TRUE(ptr0_callback_run); } @@ -802,8 +803,8 @@ TEST_F(AssociatedInterfaceTest, AssociatedBindingFlushForTesting) { ptr0.Bind(std::move(ptr_info)); bool ptr0_callback_run = false; - ptr0->Echo(123, ExpectValueSetFlagAndRunClosure( - 123, &ptr0_callback_run, base::Bind(&base::DoNothing))); + ptr0->Echo(123, ExpectValueSetFlagAndRunClosure(123, &ptr0_callback_run, + base::DoNothing())); // Because the flush is sent from the binding, it only guarantees that the // request has been received, not the response. The second flush waits for the // response to be received. @@ -864,7 +865,7 @@ TEST_F(AssociatedInterfaceTest, BindingFlushForTestingWithClosedPeer) { TEST_F(AssociatedInterfaceTest, StrongBindingFlushForTesting) { IntegerSenderConnectionPtr ptr; auto binding = - MakeStrongBinding(base::MakeUnique<IntegerSenderConnectionImpl>( + MakeStrongBinding(std::make_unique<IntegerSenderConnectionImpl>( IntegerSenderConnectionRequest{}), MakeRequest(&ptr)); bool called = false; @@ -886,7 +887,7 @@ TEST_F(AssociatedInterfaceTest, StrongBindingFlushForTestingWithClosedPeer) { IntegerSenderConnectionPtr ptr; bool called = false; auto binding = - MakeStrongBinding(base::MakeUnique<IntegerSenderConnectionImpl>( + MakeStrongBinding(std::make_unique<IntegerSenderConnectionImpl>( IntegerSenderConnectionRequest{}), MakeRequest(&ptr)); binding->set_connection_error_handler(base::Bind(&SetBool, &called)); @@ -1060,8 +1061,6 @@ TEST_F(AssociatedInterfaceTest, ThreadSafeAssociatedInterfacePtr) { // Test the thread safe pointer can be used from another thread. base::RunLoop run_loop; - base::Thread other_thread("service test thread"); - other_thread.Start(); auto run_method = base::Bind( [](const scoped_refptr<base::TaskRunner>& main_task_runner, @@ -1071,21 +1070,24 @@ TEST_F(AssociatedInterfaceTest, ThreadSafeAssociatedInterfacePtr) { auto done_callback = base::Bind( [](const scoped_refptr<base::TaskRunner>& main_task_runner, const base::Closure& quit_closure, - base::PlatformThreadId thread_id, int32_t result) { + scoped_refptr<base::SequencedTaskRunner> sender_sequence_runner, + int32_t result) { EXPECT_EQ(123, result); - // Validate the callback is invoked on the calling thread. - EXPECT_EQ(thread_id, base::PlatformThread::CurrentId()); + // Validate the callback is invoked on the calling sequence. + EXPECT_TRUE(sender_sequence_runner->RunsTasksInCurrentSequence()); // Notify the run_loop to quit. main_task_runner->PostTask(FROM_HERE, quit_closure); }); + scoped_refptr<base::SequencedTaskRunner> current_sequence_runner = + base::SequencedTaskRunnerHandle::Get(); (*thread_safe_sender) - ->Echo(123, - base::Bind(done_callback, main_task_runner, quit_closure, - base::PlatformThread::CurrentId())); + ->Echo(123, base::Bind(done_callback, main_task_runner, + quit_closure, current_sequence_runner)); }, base::SequencedTaskRunnerHandle::Get(), run_loop.QuitClosure(), thread_safe_sender); - other_thread.message_loop()->task_runner()->PostTask(FROM_HERE, run_method); + base::CreateSequencedTaskRunnerWithTraits({})->PostTask(FROM_HERE, + run_method); // Block until the method callback is called on the background thread. run_loop.Run(); @@ -1099,11 +1101,8 @@ struct ForwarderTestContext { TEST_F(AssociatedInterfaceTest, ThreadSafeAssociatedInterfacePtrWithTaskRunner) { - // Start the thread from where we'll bind the interface pointer. - base::Thread other_thread("service test thread"); - other_thread.Start(); - const scoped_refptr<base::SingleThreadTaskRunner>& other_thread_task_runner = - other_thread.message_loop()->task_runner(); + const scoped_refptr<base::SequencedTaskRunner> other_thread_task_runner = + base::CreateSequencedTaskRunnerWithTraits({}); ForwarderTestContext* context = new ForwarderTestContext(); IntegerSenderAssociatedPtrInfo sender_info; @@ -1113,7 +1112,7 @@ TEST_F(AssociatedInterfaceTest, auto setup = [](base::WaitableEvent* sender_info_bound_event, IntegerSenderAssociatedPtrInfo* sender_info, ForwarderTestContext* context) { - context->interface_impl = base::MakeUnique<IntegerSenderConnectionImpl>( + context->interface_impl = std::make_unique<IntegerSenderConnectionImpl>( MakeRequest(&context->connection_ptr)); auto sender_request = MakeRequest(sender_info); @@ -1168,7 +1167,8 @@ TEST_F(AssociatedInterfaceTest, CloseWithoutBindingAssociatedRequest) { DiscardingAssociatedPingProviderProvider ping_provider_provider; mojo::Binding<AssociatedPingProviderProvider> binding( &ping_provider_provider); - auto provider_provider = binding.CreateInterfacePtrAndBind(); + AssociatedPingProviderProviderPtr provider_provider; + binding.Bind(mojo::MakeRequest(&provider_provider)); AssociatedPingProviderAssociatedPtr provider; provider_provider->GetPingProvider(mojo::MakeRequest(&provider)); PingServiceAssociatedPtr ping; @@ -1178,9 +1178,9 @@ TEST_F(AssociatedInterfaceTest, CloseWithoutBindingAssociatedRequest) { run_loop.Run(); } -TEST_F(AssociatedInterfaceTest, GetIsolatedInterface) { +TEST_F(AssociatedInterfaceTest, AssociateWithDisconnectedPipe) { IntegerSenderAssociatedPtr sender; - GetIsolatedInterface(MakeRequest(&sender).PassHandle()); + AssociateWithDisconnectedPipe(MakeRequest(&sender).PassHandle()); sender->Send(42); } diff --git a/mojo/public/cpp/bindings/tests/bind_task_runner_unittest.cc b/mojo/public/cpp/bindings/tests/bind_task_runner_unittest.cc index 569eb518c6..17f4685af2 100644 --- a/mojo/public/cpp/bindings/tests/bind_task_runner_unittest.cc +++ b/mojo/public/cpp/bindings/tests/bind_task_runner_unittest.cc @@ -6,6 +6,7 @@ #include "base/bind.h" #include "base/callback.h" +#include "base/containers/queue.h" #include "base/message_loop/message_loop.h" #include "base/single_thread_task_runner.h" #include "base/synchronization/lock.h" @@ -31,14 +32,14 @@ class TestTaskRunner : public base::SingleThreadTaskRunner { task_ready_(base::WaitableEvent::ResetPolicy::AUTOMATIC, base::WaitableEvent::InitialState::NOT_SIGNALED) {} - bool PostNonNestableDelayedTask(const tracked_objects::Location& from_here, + bool PostNonNestableDelayedTask(const base::Location& from_here, base::OnceClosure task, base::TimeDelta delay) override { NOTREACHED(); return false; } - bool PostDelayedTask(const tracked_objects::Location& from_here, + bool PostDelayedTask(const base::Location& from_here, base::OnceClosure task, base::TimeDelta delay) override { { @@ -48,13 +49,13 @@ class TestTaskRunner : public base::SingleThreadTaskRunner { task_ready_.Signal(); return true; } - bool RunsTasksOnCurrentThread() const override { + bool RunsTasksInCurrentSequence() const override { return base::PlatformThread::CurrentRef() == thread_id_; } // Only quits when Quit() is called. void Run() { - DCHECK(RunsTasksOnCurrentThread()); + DCHECK(RunsTasksInCurrentSequence()); quit_called_ = false; while (true) { @@ -77,13 +78,13 @@ class TestTaskRunner : public base::SingleThreadTaskRunner { } void Quit() { - DCHECK(RunsTasksOnCurrentThread()); + DCHECK(RunsTasksInCurrentSequence()); quit_called_ = true; } // Waits until one task is ready and runs it. void RunOneTask() { - DCHECK(RunsTasksOnCurrentThread()); + DCHECK(RunsTasksInCurrentSequence()); while (true) { { @@ -112,7 +113,7 @@ class TestTaskRunner : public base::SingleThreadTaskRunner { // Protect |tasks_|. base::Lock lock_; - std::queue<base::OnceClosure> tasks_; + base::queue<base::OnceClosure> tasks_; DISALLOW_COPY_AND_ASSIGN(TestTaskRunner); }; diff --git a/mojo/public/cpp/bindings/tests/binding_set_unittest.cc b/mojo/public/cpp/bindings/tests/binding_set_unittest.cc index 07acfbebe0..67b6fb1819 100644 --- a/mojo/public/cpp/bindings/tests/binding_set_unittest.cc +++ b/mojo/public/cpp/bindings/tests/binding_set_unittest.cc @@ -5,11 +5,12 @@ #include <memory> #include <utility> -#include "base/message_loop/message_loop.h" #include "base/run_loop.h" +#include "mojo/core/embedder/embedder.h" #include "mojo/public/cpp/bindings/associated_binding_set.h" #include "mojo/public/cpp/bindings/binding_set.h" #include "mojo/public/cpp/bindings/strong_binding_set.h" +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" #include "mojo/public/interfaces/bindings/tests/ping_service.mojom.h" #include "mojo/public/interfaces/bindings/tests/test_associated_interfaces.mojom.h" #include "testing/gtest/include/gtest/gtest.h" @@ -18,18 +19,7 @@ namespace mojo { namespace test { namespace { -class BindingSetTest : public testing::Test { - public: - BindingSetTest() {} - ~BindingSetTest() override {} - - base::MessageLoop& loop() { return loop_; } - - private: - base::MessageLoop loop_; - - DISALLOW_COPY_AND_ASSIGN(BindingSetTest); -}; +using BindingSetTest = BindingsTestBase; template <typename BindingSetType, typename ContextType> void ExpectContextHelper(BindingSetType* binding_set, @@ -45,6 +35,45 @@ base::Closure ExpectContext(BindingSetType* binding_set, expected_context); } +template <typename BindingSetType> +void ExpectBindingIdHelper(BindingSetType* binding_set, + BindingId expected_binding_id) { + EXPECT_EQ(expected_binding_id, binding_set->dispatch_binding()); +} + +template <typename BindingSetType> +base::Closure ExpectBindingId(BindingSetType* binding_set, + BindingId expected_binding_id) { + return base::Bind(&ExpectBindingIdHelper<BindingSetType>, binding_set, + expected_binding_id); +} + +template <typename BindingSetType> +void ReportBadMessageHelper(BindingSetType* binding_set, + const std::string& error) { + binding_set->ReportBadMessage(error); +} + +template <typename BindingSetType> +base::Closure ReportBadMessage(BindingSetType* binding_set, + const std::string& error) { + return base::Bind(&ReportBadMessageHelper<BindingSetType>, binding_set, + error); +} + +template <typename BindingSetType> +void SaveBadMessageCallbackHelper(BindingSetType* binding_set, + ReportBadMessageCallback* callback) { + *callback = binding_set->GetBadMessageCallback(); +} + +template <typename BindingSetType> +base::Closure SaveBadMessageCallback(BindingSetType* binding_set, + ReportBadMessageCallback* callback) { + return base::Bind(&SaveBadMessageCallbackHelper<BindingSetType>, binding_set, + callback); +} + base::Closure Sequence(const base::Closure& first, const base::Closure& second) { return base::Bind( @@ -74,7 +103,7 @@ class PingImpl : public PingService { base::Closure ping_handler_; }; -TEST_F(BindingSetTest, BindingSetContext) { +TEST_P(BindingSetTest, BindingSetContext) { PingImpl impl; BindingSet<PingService, int> bindings; @@ -115,7 +144,48 @@ TEST_F(BindingSetTest, BindingSetContext) { EXPECT_TRUE(bindings.empty()); } -TEST_F(BindingSetTest, BindingSetConnectionErrorWithReason) { +TEST_P(BindingSetTest, BindingSetDispatchBinding) { + PingImpl impl; + + BindingSet<PingService, int> bindings; + PingServicePtr ping_a, ping_b; + BindingId id_a = bindings.AddBinding(&impl, MakeRequest(&ping_a), 1); + BindingId id_b = bindings.AddBinding(&impl, MakeRequest(&ping_b), 2); + + { + impl.set_ping_handler(ExpectBindingId(&bindings, id_a)); + base::RunLoop loop; + ping_a->Ping(loop.QuitClosure()); + loop.Run(); + } + + { + impl.set_ping_handler(ExpectBindingId(&bindings, id_b)); + base::RunLoop loop; + ping_b->Ping(loop.QuitClosure()); + loop.Run(); + } + + { + base::RunLoop loop; + bindings.set_connection_error_handler( + Sequence(ExpectBindingId(&bindings, id_a), loop.QuitClosure())); + ping_a.reset(); + loop.Run(); + } + + { + base::RunLoop loop; + bindings.set_connection_error_handler( + Sequence(ExpectBindingId(&bindings, id_b), loop.QuitClosure())); + ping_b.reset(); + loop.Run(); + } + + EXPECT_TRUE(bindings.empty()); +} + +TEST_P(BindingSetTest, BindingSetConnectionErrorWithReason) { PingImpl impl; PingServicePtr ptr; BindingSet<PingService> bindings; @@ -134,6 +204,121 @@ TEST_F(BindingSetTest, BindingSetConnectionErrorWithReason) { ptr.ResetWithReason(1024u, "bye"); } +TEST_P(BindingSetTest, BindingSetReportBadMessage) { + PingImpl impl; + + std::string last_received_error; + core::SetDefaultProcessErrorCallback( + base::Bind([](std::string* out_error, + const std::string& error) { *out_error = error; }, + &last_received_error)); + + BindingSet<PingService, int> bindings; + PingServicePtr ping_a, ping_b; + bindings.AddBinding(&impl, MakeRequest(&ping_a), 1); + bindings.AddBinding(&impl, MakeRequest(&ping_b), 2); + + { + impl.set_ping_handler(ReportBadMessage(&bindings, "message 1")); + base::RunLoop loop; + ping_a.set_connection_error_handler(loop.QuitClosure()); + ping_a->Ping(base::Bind([] {})); + loop.Run(); + EXPECT_EQ("message 1", last_received_error); + } + + { + impl.set_ping_handler(ReportBadMessage(&bindings, "message 2")); + base::RunLoop loop; + ping_b.set_connection_error_handler(loop.QuitClosure()); + ping_b->Ping(base::Bind([] {})); + loop.Run(); + EXPECT_EQ("message 2", last_received_error); + } + + EXPECT_TRUE(bindings.empty()); + + core::SetDefaultProcessErrorCallback(mojo::core::ProcessErrorCallback()); +} + +TEST_P(BindingSetTest, BindingSetGetBadMessageCallback) { + PingImpl impl; + + std::string last_received_error; + core::SetDefaultProcessErrorCallback( + base::Bind([](std::string* out_error, + const std::string& error) { *out_error = error; }, + &last_received_error)); + + BindingSet<PingService, int> bindings; + PingServicePtr ping_a, ping_b; + bindings.AddBinding(&impl, MakeRequest(&ping_a), 1); + bindings.AddBinding(&impl, MakeRequest(&ping_b), 2); + + ReportBadMessageCallback bad_message_callback_a; + ReportBadMessageCallback bad_message_callback_b; + + { + impl.set_ping_handler( + SaveBadMessageCallback(&bindings, &bad_message_callback_a)); + base::RunLoop loop; + ping_a->Ping(loop.QuitClosure()); + loop.Run(); + ping_a.reset(); + } + + { + impl.set_ping_handler( + SaveBadMessageCallback(&bindings, &bad_message_callback_b)); + base::RunLoop loop; + ping_b->Ping(loop.QuitClosure()); + loop.Run(); + } + + std::move(bad_message_callback_a).Run("message 1"); + EXPECT_EQ("message 1", last_received_error); + + { + base::RunLoop loop; + ping_b.set_connection_error_handler(loop.QuitClosure()); + std::move(bad_message_callback_b).Run("message 2"); + EXPECT_EQ("message 2", last_received_error); + loop.Run(); + } + + EXPECT_TRUE(bindings.empty()); + + core::SetDefaultProcessErrorCallback(mojo::core::ProcessErrorCallback()); +} + +TEST_P(BindingSetTest, BindingSetGetBadMessageCallbackOutlivesBindingSet) { + PingImpl impl; + + std::string last_received_error; + core::SetDefaultProcessErrorCallback( + base::Bind([](std::string* out_error, + const std::string& error) { *out_error = error; }, + &last_received_error)); + + ReportBadMessageCallback bad_message_callback; + { + BindingSet<PingService, int> bindings; + PingServicePtr ping_a; + bindings.AddBinding(&impl, MakeRequest(&ping_a), 1); + + impl.set_ping_handler( + SaveBadMessageCallback(&bindings, &bad_message_callback)); + base::RunLoop loop; + ping_a->Ping(loop.QuitClosure()); + loop.Run(); + } + + std::move(bad_message_callback).Run("message 1"); + EXPECT_EQ("message 1", last_received_error); + + core::SetDefaultProcessErrorCallback(mojo::core::ProcessErrorCallback()); +} + class PingProviderImpl : public AssociatedPingProvider, public PingService { public: PingProviderImpl() {} @@ -174,7 +359,7 @@ class PingProviderImpl : public AssociatedPingProvider, public PingService { base::Closure new_ping_handler_; }; -TEST_F(BindingSetTest, AssociatedBindingSetContext) { +TEST_P(BindingSetTest, AssociatedBindingSetContext) { AssociatedPingProviderPtr provider; PingProviderImpl impl; Binding<AssociatedPingProvider> binding(&impl, MakeRequest(&provider)); @@ -230,7 +415,7 @@ TEST_F(BindingSetTest, AssociatedBindingSetContext) { EXPECT_TRUE(impl.ping_bindings().empty()); } -TEST_F(BindingSetTest, MasterInterfaceBindingSetContext) { +TEST_P(BindingSetTest, MasterInterfaceBindingSetContext) { AssociatedPingProviderPtr provider_a, provider_b; PingProviderImpl impl; BindingSet<AssociatedPingProvider, int> bindings; @@ -275,7 +460,52 @@ TEST_F(BindingSetTest, MasterInterfaceBindingSetContext) { EXPECT_TRUE(bindings.empty()); } -TEST_F(BindingSetTest, PreDispatchHandler) { +TEST_P(BindingSetTest, MasterInterfaceBindingSetDispatchBinding) { + AssociatedPingProviderPtr provider_a, provider_b; + PingProviderImpl impl; + BindingSet<AssociatedPingProvider, int> bindings; + + BindingId id_a = bindings.AddBinding(&impl, MakeRequest(&provider_a), 1); + BindingId id_b = bindings.AddBinding(&impl, MakeRequest(&provider_b), 2); + + { + PingServiceAssociatedPtr ping; + base::RunLoop loop; + impl.set_new_ping_handler( + Sequence(ExpectBindingId(&bindings, id_a), loop.QuitClosure())); + provider_a->GetPing(MakeRequest(&ping)); + loop.Run(); + } + + { + PingServiceAssociatedPtr ping; + base::RunLoop loop; + impl.set_new_ping_handler( + Sequence(ExpectBindingId(&bindings, id_b), loop.QuitClosure())); + provider_b->GetPing(MakeRequest(&ping)); + loop.Run(); + } + + { + base::RunLoop loop; + bindings.set_connection_error_handler( + Sequence(ExpectBindingId(&bindings, id_a), loop.QuitClosure())); + provider_a.reset(); + loop.Run(); + } + + { + base::RunLoop loop; + bindings.set_connection_error_handler( + Sequence(ExpectBindingId(&bindings, id_b), loop.QuitClosure())); + provider_b.reset(); + loop.Run(); + } + + EXPECT_TRUE(bindings.empty()); +} + +TEST_P(BindingSetTest, PreDispatchHandler) { PingImpl impl; BindingSet<PingService, int> bindings; @@ -326,10 +556,11 @@ TEST_F(BindingSetTest, PreDispatchHandler) { EXPECT_TRUE(bindings.empty()); } -TEST_F(BindingSetTest, AssociatedBindingSetConnectionErrorWithReason) { +TEST_P(BindingSetTest, AssociatedBindingSetConnectionErrorWithReason) { AssociatedPingProviderPtr master_ptr; PingProviderImpl master_impl; - Binding<AssociatedPingProvider> master_binding(&master_impl, &master_ptr); + Binding<AssociatedPingProvider> master_binding(&master_impl, + MakeRequest(&master_ptr)); base::RunLoop run_loop; master_impl.ping_bindings().set_connection_error_with_reason_handler( @@ -361,15 +592,15 @@ class PingInstanceCounter : public PingService { }; int PingInstanceCounter::instance_count = 0; -TEST_F(BindingSetTest, StrongBinding_Destructor) { +TEST_P(BindingSetTest, StrongBinding_Destructor) { PingServicePtr ping_a, ping_b; - auto bindings = base::MakeUnique<StrongBindingSet<PingService>>(); + auto bindings = std::make_unique<StrongBindingSet<PingService>>(); - bindings->AddBinding(base::MakeUnique<PingInstanceCounter>(), + bindings->AddBinding(std::make_unique<PingInstanceCounter>(), mojo::MakeRequest(&ping_a)); EXPECT_EQ(1, PingInstanceCounter::instance_count); - bindings->AddBinding(base::MakeUnique<PingInstanceCounter>(), + bindings->AddBinding(std::make_unique<PingInstanceCounter>(), mojo::MakeRequest(&ping_b)); EXPECT_EQ(2, PingInstanceCounter::instance_count); @@ -377,12 +608,12 @@ TEST_F(BindingSetTest, StrongBinding_Destructor) { EXPECT_EQ(0, PingInstanceCounter::instance_count); } -TEST_F(BindingSetTest, StrongBinding_ConnectionError) { +TEST_P(BindingSetTest, StrongBinding_ConnectionError) { PingServicePtr ping_a, ping_b; StrongBindingSet<PingService> bindings; - bindings.AddBinding(base::MakeUnique<PingInstanceCounter>(), + bindings.AddBinding(std::make_unique<PingInstanceCounter>(), mojo::MakeRequest(&ping_a)); - bindings.AddBinding(base::MakeUnique<PingInstanceCounter>(), + bindings.AddBinding(std::make_unique<PingInstanceCounter>(), mojo::MakeRequest(&ping_b)); EXPECT_EQ(2, PingInstanceCounter::instance_count); @@ -395,13 +626,13 @@ TEST_F(BindingSetTest, StrongBinding_ConnectionError) { EXPECT_EQ(0, PingInstanceCounter::instance_count); } -TEST_F(BindingSetTest, StrongBinding_RemoveBinding) { +TEST_P(BindingSetTest, StrongBinding_RemoveBinding) { PingServicePtr ping_a, ping_b; StrongBindingSet<PingService> bindings; BindingId binding_id_a = bindings.AddBinding( - base::MakeUnique<PingInstanceCounter>(), mojo::MakeRequest(&ping_a)); + std::make_unique<PingInstanceCounter>(), mojo::MakeRequest(&ping_a)); BindingId binding_id_b = bindings.AddBinding( - base::MakeUnique<PingInstanceCounter>(), mojo::MakeRequest(&ping_b)); + std::make_unique<PingInstanceCounter>(), mojo::MakeRequest(&ping_b)); EXPECT_EQ(2, PingInstanceCounter::instance_count); EXPECT_TRUE(bindings.RemoveBinding(binding_id_a)); @@ -411,6 +642,8 @@ TEST_F(BindingSetTest, StrongBinding_RemoveBinding) { EXPECT_EQ(0, PingInstanceCounter::instance_count); } +INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(BindingSetTest); + } // namespace } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/binding_unittest.cc b/mojo/public/cpp/bindings/tests/binding_unittest.cc index e76993bb68..33b6a5aa81 100644 --- a/mojo/public/cpp/bindings/tests/binding_unittest.cc +++ b/mojo/public/cpp/bindings/tests/binding_unittest.cc @@ -13,9 +13,10 @@ #include "base/macros.h" #include "base/memory/ptr_util.h" #include "base/memory/weak_ptr.h" -#include "base/message_loop/message_loop.h" #include "base/run_loop.h" +#include "mojo/core/embedder/embedder.h" #include "mojo/public/cpp/bindings/strong_binding.h" +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" #include "mojo/public/interfaces/bindings/tests/ping_service.mojom.h" #include "mojo/public/interfaces/bindings/tests/sample_interfaces.mojom.h" #include "mojo/public/interfaces/bindings/tests/sample_service.mojom.h" @@ -24,19 +25,6 @@ namespace mojo { namespace { -class BindingTestBase : public testing::Test { - public: - BindingTestBase() {} - ~BindingTestBase() override {} - - base::MessageLoop& loop() { return loop_; } - - private: - base::MessageLoop loop_; - - DISALLOW_COPY_AND_ASSIGN(BindingTestBase); -}; - class ServiceImpl : public sample::Service { public: explicit ServiceImpl(bool* was_deleted = nullptr) @@ -79,9 +67,9 @@ base::Callback<void(Args...)> SetFlagAndRunClosure( // BindingTest ----------------------------------------------------------------- -using BindingTest = BindingTestBase; +using BindingTest = BindingsTestBase; -TEST_F(BindingTest, Close) { +TEST_P(BindingTest, Close) { bool called = false; sample::ServicePtr ptr; auto request = MakeRequest(&ptr); @@ -98,7 +86,7 @@ TEST_F(BindingTest, Close) { } // Tests that destroying a mojo::Binding closes the bound message pipe handle. -TEST_F(BindingTest, DestroyClosesMessagePipe) { +TEST_P(BindingTest, DestroyClosesMessagePipe) { bool encountered_error = false; ServiceImpl impl; sample::ServicePtr ptr; @@ -133,7 +121,7 @@ TEST_F(BindingTest, DestroyClosesMessagePipe) { // Tests that the binding's connection error handler gets called when the other // end is closed. -TEST_F(BindingTest, ConnectionError) { +TEST_P(BindingTest, ConnectionError) { bool called = false; { ServiceImpl impl; @@ -154,7 +142,7 @@ TEST_F(BindingTest, ConnectionError) { // Tests that calling Close doesn't result in the connection error handler being // called. -TEST_F(BindingTest, CloseDoesntCallConnectionErrorHandler) { +TEST_P(BindingTest, CloseDoesntCallConnectionErrorHandler) { ServiceImpl impl; sample::ServicePtr ptr; Binding<sample::Service> binding(&impl, MakeRequest(&ptr)); @@ -198,7 +186,7 @@ class ServiceImplWithBinding : public ServiceImpl { }; // Tests that the binding may be deleted in the connection error handler. -TEST_F(BindingTest, SelfDeleteOnConnectionError) { +TEST_P(BindingTest, SelfDeleteOnConnectionError) { bool was_deleted = false; sample::ServicePtr ptr; // This should delete itself on connection error. @@ -212,7 +200,7 @@ TEST_F(BindingTest, SelfDeleteOnConnectionError) { } // Tests that explicitly calling Unbind followed by rebinding works. -TEST_F(BindingTest, Unbind) { +TEST_P(BindingTest, Unbind) { ServiceImpl impl; sample::ServicePtr ptr; Binding<sample::Service> binding(&impl, MakeRequest(&ptr)); @@ -262,14 +250,7 @@ class IntegerAccessorImpl : public sample::IntegerAccessor { DISALLOW_COPY_AND_ASSIGN(IntegerAccessorImpl); }; -TEST_F(BindingTest, SetInterfacePtrVersion) { - IntegerAccessorImpl impl; - sample::IntegerAccessorPtr ptr; - Binding<sample::IntegerAccessor> binding(&impl, &ptr); - EXPECT_EQ(3u, ptr.version()); -} - -TEST_F(BindingTest, PauseResume) { +TEST_P(BindingTest, PauseResume) { bool called = false; base::RunLoop run_loop; sample::ServicePtr ptr; @@ -292,7 +273,7 @@ TEST_F(BindingTest, PauseResume) { } // Verifies the connection error handler is not run while a binding is paused. -TEST_F(BindingTest, ErrorHandleNotRunWhilePaused) { +TEST_P(BindingTest, ErrorHandleNotRunWhilePaused) { bool called = false; base::RunLoop run_loop; sample::ServicePtr ptr; @@ -343,7 +324,7 @@ class CallbackFilter : public MessageReceiver { ~CallbackFilter() override {} static std::unique_ptr<CallbackFilter> Wrap(const base::Closure& callback) { - return base::MakeUnique<CallbackFilter>(callback); + return std::make_unique<CallbackFilter>(callback); } // MessageReceiver: @@ -358,7 +339,7 @@ class CallbackFilter : public MessageReceiver { // Verifies that message filters are notified in the order they were added and // are always notified before a message is dispatched. -TEST_F(BindingTest, MessageFilter) { +TEST_P(BindingTest, MessageFilter) { test::PingServicePtr ptr; PingServiceImpl impl; mojo::Binding<test::PingService> binding(&impl, MakeRequest(&ptr)); @@ -389,7 +370,7 @@ void Fail() { FAIL() << "Unexpected connection error"; } -TEST_F(BindingTest, FlushForTesting) { +TEST_P(BindingTest, FlushForTesting) { bool called = false; sample::ServicePtr ptr; auto request = MakeRequest(&ptr); @@ -408,7 +389,7 @@ TEST_F(BindingTest, FlushForTesting) { EXPECT_TRUE(called); } -TEST_F(BindingTest, FlushForTestingWithClosedPeer) { +TEST_P(BindingTest, FlushForTestingWithClosedPeer) { bool called = false; sample::ServicePtr ptr; auto request = MakeRequest(&ptr); @@ -423,7 +404,7 @@ TEST_F(BindingTest, FlushForTestingWithClosedPeer) { binding.FlushForTesting(); } -TEST_F(BindingTest, ConnectionErrorWithReason) { +TEST_P(BindingTest, ConnectionErrorWithReason) { sample::ServicePtr ptr; auto request = MakeRequest(&ptr); ServiceImpl impl; @@ -455,7 +436,7 @@ struct WeakPtrImplRefTraits { template <typename T> using WeakBinding = Binding<T, WeakPtrImplRefTraits<T>>; -TEST_F(BindingTest, CustomImplPointerType) { +TEST_P(BindingTest, CustomImplPointerType) { PingServiceImpl impl; base::WeakPtrFactory<test::PingService> weak_factory(&impl); @@ -485,13 +466,76 @@ TEST_F(BindingTest, CustomImplPointerType) { } } +TEST_P(BindingTest, ReportBadMessage) { + bool called = false; + test::PingServicePtr ptr; + auto request = MakeRequest(&ptr); + base::RunLoop run_loop; + ptr.set_connection_error_handler( + SetFlagAndRunClosure(&called, run_loop.QuitClosure())); + PingServiceImpl impl; + Binding<test::PingService> binding(&impl, std::move(request)); + impl.set_ping_handler(base::Bind( + [](Binding<test::PingService>* binding) { + binding->ReportBadMessage("received bad message"); + }, + &binding)); + + std::string received_error; + core::SetDefaultProcessErrorCallback( + base::Bind([](std::string* out_error, + const std::string& error) { *out_error = error; }, + &received_error)); + + ptr->Ping(base::Bind([] {})); + EXPECT_FALSE(called); + run_loop.Run(); + EXPECT_TRUE(called); + EXPECT_EQ("received bad message", received_error); + + core::SetDefaultProcessErrorCallback(mojo::core::ProcessErrorCallback()); +} + +TEST_P(BindingTest, GetBadMessageCallback) { + test::PingServicePtr ptr; + auto request = MakeRequest(&ptr); + base::RunLoop run_loop; + PingServiceImpl impl; + ReportBadMessageCallback bad_message_callback; + + std::string received_error; + core::SetDefaultProcessErrorCallback( + base::Bind([](std::string* out_error, + const std::string& error) { *out_error = error; }, + &received_error)); + + { + Binding<test::PingService> binding(&impl, std::move(request)); + impl.set_ping_handler(base::Bind( + [](Binding<test::PingService>* binding, + ReportBadMessageCallback* out_callback) { + *out_callback = binding->GetBadMessageCallback(); + }, + &binding, &bad_message_callback)); + ptr->Ping(run_loop.QuitClosure()); + run_loop.Run(); + EXPECT_TRUE(received_error.empty()); + EXPECT_TRUE(bad_message_callback); + } + + std::move(bad_message_callback).Run("delayed bad message"); + EXPECT_EQ("delayed bad message", received_error); + + core::SetDefaultProcessErrorCallback(mojo::core::ProcessErrorCallback()); +} + // StrongBindingTest ----------------------------------------------------------- -using StrongBindingTest = BindingTestBase; +using StrongBindingTest = BindingsTestBase; // Tests that destroying a mojo::StrongBinding closes the bound message pipe // handle but does *not* destroy the implementation object. -TEST_F(StrongBindingTest, DestroyClosesMessagePipe) { +TEST_P(StrongBindingTest, DestroyClosesMessagePipe) { base::RunLoop run_loop; bool encountered_error = false; bool was_deleted = false; @@ -502,7 +546,7 @@ TEST_F(StrongBindingTest, DestroyClosesMessagePipe) { bool called = false; base::RunLoop run_loop2; - auto binding = MakeStrongBinding(base::MakeUnique<ServiceImpl>(&was_deleted), + auto binding = MakeStrongBinding(std::make_unique<ServiceImpl>(&was_deleted), std::move(request)); ptr->Frobinate( nullptr, sample::Service::BazOptions::REGULAR, nullptr, @@ -523,7 +567,7 @@ TEST_F(StrongBindingTest, DestroyClosesMessagePipe) { // Tests the typical case, where the implementation object owns the // StrongBinding (and should be destroyed on connection error). -TEST_F(StrongBindingTest, ConnectionErrorDestroysImpl) { +TEST_P(StrongBindingTest, ConnectionErrorDestroysImpl) { sample::ServicePtr ptr; bool was_deleted = false; // Will delete itself. @@ -540,12 +584,12 @@ TEST_F(StrongBindingTest, ConnectionErrorDestroysImpl) { EXPECT_TRUE(was_deleted); } -TEST_F(StrongBindingTest, FlushForTesting) { +TEST_P(StrongBindingTest, FlushForTesting) { bool called = false; bool was_deleted = false; sample::ServicePtr ptr; auto request = MakeRequest(&ptr); - auto binding = MakeStrongBinding(base::MakeUnique<ServiceImpl>(&was_deleted), + auto binding = MakeStrongBinding(std::make_unique<ServiceImpl>(&was_deleted), std::move(request)); binding->set_connection_error_handler(base::Bind(&Fail)); @@ -568,12 +612,12 @@ TEST_F(StrongBindingTest, FlushForTesting) { EXPECT_TRUE(was_deleted); } -TEST_F(StrongBindingTest, FlushForTestingWithClosedPeer) { +TEST_P(StrongBindingTest, FlushForTestingWithClosedPeer) { bool called = false; bool was_deleted = false; sample::ServicePtr ptr; auto request = MakeRequest(&ptr); - auto binding = MakeStrongBinding(base::MakeUnique<ServiceImpl>(&was_deleted), + auto binding = MakeStrongBinding(std::make_unique<ServiceImpl>(&was_deleted), std::move(request)); binding->set_connection_error_handler(SetFlagAndRunClosure(&called)); ptr.reset(); @@ -587,11 +631,11 @@ TEST_F(StrongBindingTest, FlushForTestingWithClosedPeer) { ASSERT_FALSE(binding); } -TEST_F(StrongBindingTest, ConnectionErrorWithReason) { +TEST_P(StrongBindingTest, ConnectionErrorWithReason) { sample::ServicePtr ptr; auto request = MakeRequest(&ptr); auto binding = - MakeStrongBinding(base::MakeUnique<ServiceImpl>(), std::move(request)); + MakeStrongBinding(std::make_unique<ServiceImpl>(), std::move(request)); base::RunLoop run_loop; binding->set_connection_error_with_reason_handler(base::Bind( [](const base::Closure& quit_closure, uint32_t custom_reason, @@ -607,5 +651,8 @@ TEST_F(StrongBindingTest, ConnectionErrorWithReason) { run_loop.Run(); } +INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(BindingTest); +INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(StrongBindingTest); + } // namespace } // mojo diff --git a/mojo/public/cpp/bindings/tests/bindings_perftest.cc b/mojo/public/cpp/bindings/tests/bindings_perftest.cc index 65b3c8c1d4..c54d7ca6a9 100644 --- a/mojo/public/cpp/bindings/tests/bindings_perftest.cc +++ b/mojo/public/cpp/bindings/tests/bindings_perftest.cc @@ -12,7 +12,6 @@ #include "base/time/time.h" #include "mojo/public/cpp/bindings/binding.h" #include "mojo/public/cpp/bindings/interface_endpoint_client.h" -#include "mojo/public/cpp/bindings/lib/message_builder.h" #include "mojo/public/cpp/bindings/lib/multiplex_router.h" #include "mojo/public/cpp/bindings/message.h" #include "mojo/public/cpp/test_support/test_support.h" @@ -154,8 +153,8 @@ class PingPongPaddle : public MessageReceiverWithResponderStatus { } } - internal::MessageBuilder builder(count, 0, 8, 0); - bool result = sender_->Accept(builder.message()); + Message reply(count, 0, 0, 0, nullptr); + bool result = sender_->Accept(&reply); DCHECK(result); return true; } @@ -174,8 +173,8 @@ class PingPongPaddle : public MessageReceiverWithResponderStatus { quit_closure_ = run_loop.QuitClosure(); start_time_ = base::TimeTicks::Now(); - internal::MessageBuilder builder(0, 0, 8, 0); - bool result = sender_->Accept(builder.message()); + Message message(0, 0, 0, 0, nullptr); + bool result = sender_->Accept(&message); DCHECK(result); run_loop.Run(); @@ -264,9 +263,8 @@ TEST_F(MojoBindingsPerftest, MultiplexRouterDispatchCost) { receiver.Reset(); base::TimeTicks start_time = base::TimeTicks::Now(); for (size_t j = 0; j < kIterations[i]; ++j) { - internal::MessageBuilder builder(0, 0, 8, 0); - bool result = - router->SimulateReceivingMessageForTesting(builder.message()); + Message message(0, 0, 8, 0, nullptr); + bool result = router->SimulateReceivingMessageForTesting(&message); DCHECK(result); } diff --git a/mojo/public/cpp/bindings/tests/bindings_test_base.cc b/mojo/public/cpp/bindings/tests/bindings_test_base.cc new file mode 100644 index 0000000000..02430475f6 --- /dev/null +++ b/mojo/public/cpp/bindings/tests/bindings_test_base.cc @@ -0,0 +1,40 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" + +#include "mojo/public/cpp/bindings/connector.h" + +namespace mojo { + +BindingsTestBase::BindingsTestBase() { + SetupSerializationBehavior(GetParam()); +} + +BindingsTestBase::~BindingsTestBase() = default; + +// static +void BindingsTestBase::SetupSerializationBehavior( + BindingsTestSerializationMode mode) { + switch (mode) { + case BindingsTestSerializationMode::kSerializeBeforeSend: + Connector::OverrideDefaultSerializationBehaviorForTesting( + Connector::OutgoingSerializationMode::kEager, + Connector::IncomingSerializationMode::kDispatchAsIs); + break; + case BindingsTestSerializationMode::kSerializeBeforeDispatch: + Connector::OverrideDefaultSerializationBehaviorForTesting( + Connector::OutgoingSerializationMode::kLazy, + Connector::IncomingSerializationMode :: + kSerializeBeforeDispatchForTesting); + break; + case BindingsTestSerializationMode::kNeverSerialize: + Connector::OverrideDefaultSerializationBehaviorForTesting( + Connector::OutgoingSerializationMode::kLazy, + Connector::IncomingSerializationMode::kDispatchAsIs); + break; + } +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/bindings_test_base.h b/mojo/public/cpp/bindings/tests/bindings_test_base.h new file mode 100644 index 0000000000..5f5c779579 --- /dev/null +++ b/mojo/public/cpp/bindings/tests/bindings_test_base.h @@ -0,0 +1,51 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_TESTS_BINDINGS_TEST_BASE_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_TESTS_BINDINGS_TEST_BASE_H_ + +#include "base/test/scoped_task_environment.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace mojo { + +// Used to parameterize tests which inherit from BindingsTestBase to exercise +// various message serialization-related code paths for intra-process bindings +// usage. +enum class BindingsTestSerializationMode { + // Messages should be serialized immediately before sending. + kSerializeBeforeSend, + + // Messages should be serialized immediately before dispatching. + kSerializeBeforeDispatch, + + // Messages should never be serialized. + kNeverSerialize, +}; + +class BindingsTestBase + : public testing::Test, + public testing::WithParamInterface<BindingsTestSerializationMode> { + public: + BindingsTestBase(); + ~BindingsTestBase(); + + // Helper which other test fixtures can use. + static void SetupSerializationBehavior(BindingsTestSerializationMode mode); + + private: + base::test::ScopedTaskEnvironment task_environment_; +}; + +} // namespace mojo + +#define INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(fixture) \ + INSTANTIATE_TEST_CASE_P( \ + , fixture, \ + testing::Values( \ + mojo::BindingsTestSerializationMode::kSerializeBeforeSend, \ + mojo::BindingsTestSerializationMode::kSerializeBeforeDispatch, \ + mojo::BindingsTestSerializationMode::kNeverSerialize)) + +#endif // MOJO_PUBLIC_CPP_BINDINGS_TESTS_BINDINGS_TEST_BASE_H_ diff --git a/mojo/public/cpp/bindings/tests/buffer_unittest.cc b/mojo/public/cpp/bindings/tests/buffer_unittest.cc index d75bdd0785..2370454bff 100644 --- a/mojo/public/cpp/bindings/tests/buffer_unittest.cc +++ b/mojo/public/cpp/bindings/tests/buffer_unittest.cc @@ -14,80 +14,20 @@ namespace mojo { namespace test { namespace { -bool IsZero(void* p_buf, size_t size) { - char* buf = reinterpret_cast<char*>(p_buf); - for (size_t i = 0; i < size; ++i) { - if (buf[i] != 0) - return false; - } - return true; -} - // Tests that FixedBuffer allocates memory aligned to 8 byte boundaries. TEST(FixedBufferTest, Alignment) { internal::FixedBufferForTesting buf(internal::Align(10) * 2); ASSERT_EQ(buf.size(), 16u * 2); - void* a = buf.Allocate(10); - ASSERT_TRUE(a); - EXPECT_TRUE(IsZero(a, 10)); - EXPECT_EQ(0, reinterpret_cast<ptrdiff_t>(a) % 8); + size_t a = buf.Allocate(10); + EXPECT_EQ(0u, a); - void* b = buf.Allocate(10); - ASSERT_TRUE(b); - EXPECT_TRUE(IsZero(b, 10)); - EXPECT_EQ(0, reinterpret_cast<ptrdiff_t>(b) % 8); + size_t b = buf.Allocate(10); + ASSERT_EQ(16u, b); // Any more allocations would result in an assert, but we can't test that. } -// Tests that FixedBufferForTesting::Leak passes ownership to the caller. -TEST(FixedBufferTest, Leak) { - void* ptr = nullptr; - void* buf_ptr = nullptr; - { - internal::FixedBufferForTesting buf(8); - ASSERT_EQ(8u, buf.size()); - - ptr = buf.Allocate(8); - ASSERT_TRUE(ptr); - buf_ptr = buf.Leak(); - - // The buffer should point to the first element allocated. - // TODO(mpcomplete): Is this a reasonable expectation? - EXPECT_EQ(ptr, buf_ptr); - - // The FixedBufferForTesting should be empty now. - EXPECT_EQ(0u, buf.size()); - EXPECT_FALSE(buf.Leak()); - } - - // Since we called Leak, ptr is still writable after FixedBufferForTesting - // went out of scope. - memset(ptr, 1, 8); - free(buf_ptr); -} - -#if defined(NDEBUG) && !defined(DCHECK_ALWAYS_ON) -TEST(FixedBufferTest, TooBig) { - internal::FixedBufferForTesting buf(24); - - // A little bit too large. - EXPECT_EQ(reinterpret_cast<void*>(0), buf.Allocate(32)); - - // Move the cursor forward. - EXPECT_NE(reinterpret_cast<void*>(0), buf.Allocate(16)); - - // A lot too large. - EXPECT_EQ(reinterpret_cast<void*>(0), - buf.Allocate(std::numeric_limits<size_t>::max() - 1024u)); - - // A lot too large, leading to possible integer overflow. - EXPECT_EQ(reinterpret_cast<void*>(0), - buf.Allocate(std::numeric_limits<size_t>::max() - 8u)); -} -#endif - } // namespace } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/callback_helpers_unittest.cc b/mojo/public/cpp/bindings/tests/callback_helpers_unittest.cc new file mode 100644 index 0000000000..fbb0dc7e1b --- /dev/null +++ b/mojo/public/cpp/bindings/tests/callback_helpers_unittest.cc @@ -0,0 +1,202 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/callback_helpers.h" + +#include <memory> +#include <string> +#include <utility> + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/callback.h" +#include "base/memory/ptr_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace mojo { + +namespace { + +void SetBool(bool* var, bool val) { + *var = val; +} + +void SetBoolFromRawPtr(bool* var, bool* val) { + *var = *val; +} + +void SetIntegers(int* a_var, int* b_var, int a_val, int b_val) { + *a_var = a_val; + *b_var = b_val; +} + +void SetIntegerFromUniquePtr(int* var, std::unique_ptr<int> val) { + *var = *val; +} + +void SetString(std::string* var, const std::string val) { + *var = val; +} + +void CallClosure(base::OnceClosure cl) { + std::move(cl).Run(); +} + +} // namespace + +TEST(CallbackWithDeleteTest, SetIntegers_Run) { + int a = 0; + int b = 0; + auto cb = + WrapCallbackWithDropHandler(base::BindOnce(&SetIntegers, &a, &b), + base::BindOnce(&SetIntegers, &a, &b, 3, 4)); + std::move(cb).Run(1, 2); + EXPECT_EQ(a, 1); + EXPECT_EQ(b, 2); +} + +TEST(CallbackWithDeleteTest, SetIntegers_Destruction) { + int a = 0; + int b = 0; + { + auto cb = + WrapCallbackWithDropHandler(base::BindOnce(&SetIntegers, &a, &b), + base::BindOnce(&SetIntegers, &a, &b, 3, 4)); + } + EXPECT_EQ(a, 3); + EXPECT_EQ(b, 4); +} + +TEST(CallbackWithDefaultTest, CallClosure_Run) { + int a = 0; + int b = 0; + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&CallClosure), base::BindOnce(&SetIntegers, &a, &b, 3, 4)); + std::move(cb).Run(base::BindOnce(&SetIntegers, &a, &b, 1, 2)); + EXPECT_EQ(a, 1); + EXPECT_EQ(b, 2); +} + +TEST(CallbackWithDefaultTest, CallClosure_Destruction) { + int a = 0; + int b = 0; + { + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&CallClosure), + base::BindOnce(&SetIntegers, &a, &b, 3, 4)); + } + EXPECT_EQ(a, 3); + EXPECT_EQ(b, 4); +} + +TEST(CallbackWithDefaultTest, Closure_Run) { + bool a = false; + auto cb = + WrapCallbackWithDefaultInvokeIfNotRun(base::BindOnce(&SetBool, &a, true)); + std::move(cb).Run(); + EXPECT_TRUE(a); +} + +TEST(CallbackWithDefaultTest, Closure_Destruction) { + bool a = false; + { + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&SetBool, &a, true)); + } + EXPECT_TRUE(a); +} + +TEST(CallbackWithDefaultTest, SetBool_Run) { + bool a = false; + auto cb = + WrapCallbackWithDefaultInvokeIfNotRun(base::BindOnce(&SetBool, &a), true); + std::move(cb).Run(true); + EXPECT_TRUE(a); +} + +TEST(CallbackWithDefaultTest, SetBoolFromRawPtr_Run) { + bool a = false; + bool* b = new bool(false); + bool c = true; + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&SetBoolFromRawPtr, &a), base::Owned(b)); + std::move(cb).Run(&c); + EXPECT_TRUE(a); +} + +TEST(CallbackWithDefaultTest, SetBoolFromRawPtr_Destruction) { + bool a = false; + bool* b = new bool(true); + { + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&SetBoolFromRawPtr, &a), base::Owned(b)); + } + EXPECT_TRUE(a); +} + +TEST(CallbackWithDefaultTest, SetBool_Destruction) { + bool a = false; + { + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&SetBool, &a), true); + } + EXPECT_TRUE(a); +} + +TEST(CallbackWithDefaultTest, SetIntegers_Run) { + int a = 0; + int b = 0; + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&SetIntegers, &a, &b), 3, 4); + std::move(cb).Run(1, 2); + EXPECT_EQ(a, 1); + EXPECT_EQ(b, 2); +} + +TEST(CallbackWithDefaultTest, SetIntegers_Destruction) { + int a = 0; + int b = 0; + { + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&SetIntegers, &a, &b), 3, 4); + } + EXPECT_EQ(a, 3); + EXPECT_EQ(b, 4); +} + +TEST(CallbackWithDefaultTest, SetIntegerFromUniquePtr_Run) { + int a = 0; + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&SetIntegerFromUniquePtr, &a), std::make_unique<int>(1)); + std::move(cb).Run(std::make_unique<int>(2)); + EXPECT_EQ(a, 2); +} + +TEST(CallbackWithDefaultTest, SetIntegerFromUniquePtr_Destruction) { + int a = 0; + { + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&SetIntegerFromUniquePtr, &a), std::make_unique<int>(1)); + } + EXPECT_EQ(a, 1); +} + +TEST(CallbackWithDefaultTest, SetString_Run) { + std::string a; + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&SetString, &a), "hello"); + std::move(cb).Run("world"); + EXPECT_EQ(a, "world"); +} + +TEST(CallbackWithDefaultTest, SetString_Destruction) { + std::string a; + { + auto cb = WrapCallbackWithDefaultInvokeIfNotRun( + base::BindOnce(&SetString, &a), "hello"); + } + EXPECT_EQ(a, "hello"); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/connector_unittest.cc b/mojo/public/cpp/bindings/tests/connector_unittest.cc index 74ecb7a9ee..68f51fc5bd 100644 --- a/mojo/public/cpp/bindings/tests/connector_unittest.cc +++ b/mojo/public/cpp/bindings/tests/connector_unittest.cc @@ -16,7 +16,7 @@ #include "base/run_loop.h" #include "base/threading/thread.h" #include "base/threading/thread_task_runner_handle.h" -#include "mojo/public/cpp/bindings/lib/message_builder.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/cpp/bindings/tests/message_queue.h" #include "testing/gtest/include/gtest/gtest.h" @@ -91,18 +91,17 @@ class ConnectorTest : public testing::Test { public: ConnectorTest() {} - void SetUp() override { - CreateMessagePipe(nullptr, &handle0_, &handle1_); - } + void SetUp() override { CreateMessagePipe(nullptr, &handle0_, &handle1_); } void TearDown() override {} - void AllocMessage(const char* text, Message* message) { - size_t payload_size = strlen(text) + 1; // Plus null terminator. - internal::MessageBuilder builder(1, 0, payload_size, 0); - memcpy(builder.buffer()->Allocate(payload_size), text, payload_size); - - *message = std::move(*builder.message()); + Message CreateMessage( + const char* text, + std::vector<ScopedHandle> handles = std::vector<ScopedHandle>()) { + const size_t size = strlen(text) + 1; // Plus null terminator. + Message message(1, 0, size, 0, &handles); + memcpy(message.payload_buffer()->AllocateAndGet(size), text, size); + return message; } protected: @@ -120,10 +119,7 @@ TEST_F(ConnectorTest, Basic) { base::ThreadTaskRunnerHandle::Get()); const char kText[] = "hello world"; - - Message message; - AllocMessage(kText, &message); - + Message message = CreateMessage(kText); connector0.Accept(&message); base::RunLoop run_loop; @@ -149,10 +145,7 @@ TEST_F(ConnectorTest, Basic_Synchronous) { base::ThreadTaskRunnerHandle::Get()); const char kText[] = "hello world"; - - Message message; - AllocMessage(kText, &message); - + Message message = CreateMessage(kText); connector0.Accept(&message); MessageAccumulator accumulator; @@ -181,10 +174,7 @@ TEST_F(ConnectorTest, Basic_EarlyIncomingReceiver) { connector1.set_incoming_receiver(&accumulator); const char kText[] = "hello world"; - - Message message; - AllocMessage(kText, &message); - + Message message = CreateMessage(kText); connector0.Accept(&message); run_loop.Run(); @@ -206,11 +196,8 @@ TEST_F(ConnectorTest, Basic_TwoMessages) { base::ThreadTaskRunnerHandle::Get()); const char* kText[] = {"hello", "world"}; - for (size_t i = 0; i < arraysize(kText); ++i) { - Message message; - AllocMessage(kText[i], &message); - + Message message = CreateMessage(kText[i]); connector0.Accept(&message); } @@ -241,11 +228,8 @@ TEST_F(ConnectorTest, Basic_TwoMessages_Synchronous) { base::ThreadTaskRunnerHandle::Get()); const char* kText[] = {"hello", "world"}; - for (size_t i = 0; i < arraysize(kText); ++i) { - Message message; - AllocMessage(kText[i], &message); - + Message message = CreateMessage(kText[i]); connector0.Accept(&message); } @@ -271,9 +255,7 @@ TEST_F(ConnectorTest, WriteToClosedPipe) { base::ThreadTaskRunnerHandle::Get()); const char kText[] = "hello world"; - - Message message; - AllocMessage(kText, &message); + Message message = CreateMessage(kText); // Close the other end of the pipe. handle1_.reset(); @@ -304,17 +286,13 @@ TEST_F(ConnectorTest, MessageWithHandles) { const char kText[] = "hello world"; - Message message1; - AllocMessage(kText, &message1); - MessagePipe pipe; - message1.mutable_handles()->push_back(pipe.handle0.release()); + std::vector<ScopedHandle> handles; + handles.emplace_back(ScopedHandle::From(std::move(pipe.handle0))); + Message message1 = CreateMessage(kText, std::move(handles)); connector0.Accept(&message1); - // The message should have been transferred, releasing the handles. - EXPECT_TRUE(message1.handles()->empty()); - base::RunLoop run_loop; MessageAccumulator accumulator(run_loop.QuitClosure()); connector1.set_incoming_receiver(&accumulator); @@ -333,21 +311,16 @@ TEST_F(ConnectorTest, MessageWithHandles) { // Now send a message to the transferred handle and confirm it's sent through // to the orginal pipe. - // TODO(vtl): Do we need a better way of "downcasting" the handle types? - ScopedMessagePipeHandle smph; - smph.reset(MessagePipeHandle(message_received.handles()->front().value())); - message_received.mutable_handles()->front() = Handle(); - // |smph| now owns this handle. - - Connector connector_received(std::move(smph), Connector::SINGLE_THREADED_SEND, + auto pipe_handle = ScopedMessagePipeHandle::From( + std::move(message_received.mutable_handles()->front())); + Connector connector_received(std::move(pipe_handle), + Connector::SINGLE_THREADED_SEND, base::ThreadTaskRunnerHandle::Get()); Connector connector_original(std::move(pipe.handle1), Connector::SINGLE_THREADED_SEND, base::ThreadTaskRunnerHandle::Get()); - Message message2; - AllocMessage(kText, &message2); - + Message message2 = CreateMessage(kText); connector_received.Accept(&message2); base::RunLoop run_loop2; MessageAccumulator accumulator2(run_loop2.QuitClosure()); @@ -379,10 +352,7 @@ TEST_F(ConnectorTest, WaitForIncomingMessageWithDeletion) { base::ThreadTaskRunnerHandle::Get()); const char kText[] = "hello world"; - - Message message; - AllocMessage(kText, &message); - + Message message = CreateMessage(kText); connector0.Accept(&message); ConnectorDeletingMessageAccumulator accumulator(&connector1); @@ -408,11 +378,8 @@ TEST_F(ConnectorTest, WaitForIncomingMessageWithReentrancy) { base::ThreadTaskRunnerHandle::Get()); const char* kText[] = {"hello", "world"}; - for (size_t i = 0; i < arraysize(kText); ++i) { - Message message; - AllocMessage(kText[i], &message); - + Message message = CreateMessage(kText[i]); connector0.Accept(&message); } @@ -448,22 +415,17 @@ TEST_F(ConnectorTest, RaiseError) { Connector connector0(std::move(handle0_), Connector::SINGLE_THREADED_SEND, base::ThreadTaskRunnerHandle::Get()); bool error_handler_called0 = false; - connector0.set_connection_error_handler( - base::Bind(&ForwardErrorHandler, &error_handler_called0, - run_loop.QuitClosure())); + connector0.set_connection_error_handler(base::Bind( + &ForwardErrorHandler, &error_handler_called0, run_loop.QuitClosure())); Connector connector1(std::move(handle1_), Connector::SINGLE_THREADED_SEND, base::ThreadTaskRunnerHandle::Get()); bool error_handler_called1 = false; - connector1.set_connection_error_handler( - base::Bind(&ForwardErrorHandler, &error_handler_called1, - run_loop2.QuitClosure())); + connector1.set_connection_error_handler(base::Bind( + &ForwardErrorHandler, &error_handler_called1, run_loop2.QuitClosure())); const char kText[] = "hello world"; - - Message message; - AllocMessage(kText, &message); - + Message message = CreateMessage(kText); connector0.Accept(&message); connector0.RaiseError(); @@ -514,18 +476,16 @@ TEST_F(ConnectorTest, PauseWithQueuedMessages) { const char kText[] = "hello world"; // Queue up two messages. - Message message; - AllocMessage(kText, &message); + Message message = CreateMessage(kText); connector0.Accept(&message); - AllocMessage(kText, &message); + message = CreateMessage(kText); connector0.Accept(&message); base::RunLoop run_loop; // Configure the accumulator such that it pauses after the first message is // received. - MessageAccumulator accumulator( - base::Bind(&PauseConnectorAndRunClosure, &connector1, - run_loop.QuitClosure())); + MessageAccumulator accumulator(base::Bind( + &PauseConnectorAndRunClosure, &connector1, run_loop.QuitClosure())); connector1.set_incoming_receiver(&accumulator); run_loop.Run(); @@ -537,9 +497,7 @@ TEST_F(ConnectorTest, PauseWithQueuedMessages) { void AccumulateWithNestedLoop(MessageAccumulator* accumulator, const base::Closure& closure) { - base::RunLoop nested_run_loop; - base::MessageLoop::ScopedNestableTaskAllower allow( - base::MessageLoop::current()); + base::RunLoop nested_run_loop(base::RunLoop::Type::kNestableTasksAllowed); accumulator->set_closure(nested_run_loop.QuitClosure()); nested_run_loop.Run(); closure.Run(); @@ -554,10 +512,9 @@ TEST_F(ConnectorTest, ProcessWhenNested) { const char kText[] = "hello world"; // Queue up two messages. - Message message; - AllocMessage(kText, &message); + Message message = CreateMessage(kText); connector0.Accept(&message); - AllocMessage(kText, &message); + message = CreateMessage(kText); connector0.Accept(&message); base::RunLoop run_loop; diff --git a/mojo/public/cpp/bindings/tests/constant_unittest.cc b/mojo/public/cpp/bindings/tests/constant_unittest.cc index caa6464cf4..12a8615e41 100644 --- a/mojo/public/cpp/bindings/tests/constant_unittest.cc +++ b/mojo/public/cpp/bindings/tests/constant_unittest.cc @@ -53,7 +53,7 @@ TEST(ConstantTest, InterfaceConstants) { EXPECT_EQ(base::StringPiece(InterfaceWithConstants::kStringValue), "interface test string contents"); EXPECT_EQ(base::StringPiece(InterfaceWithConstants::Name_), - "mojo::test::InterfaceWithConstants"); + "mojo.test.InterfaceWithConstants"); } } // namespace test diff --git a/mojo/public/cpp/bindings/tests/data_view_unittest.cc b/mojo/public/cpp/bindings/tests/data_view_unittest.cc index 0ebfda5d12..552bc6c86f 100644 --- a/mojo/public/cpp/bindings/tests/data_view_unittest.cc +++ b/mojo/public/cpp/bindings/tests/data_view_unittest.cc @@ -26,22 +26,18 @@ class DataViewTest : public testing::Test { struct DataViewHolder { std::unique_ptr<TestStructDataView> data_view; - std::unique_ptr<mojo::internal::FixedBufferForTesting> buf; + mojo::Message message; mojo::internal::SerializationContext context; }; std::unique_ptr<DataViewHolder> SerializeTestStruct(TestStructPtr input) { - std::unique_ptr<DataViewHolder> result(new DataViewHolder); - - size_t size = mojo::internal::PrepareToSerialize<TestStructDataView>( - input, &result->context); - - result->buf.reset(new mojo::internal::FixedBufferForTesting(size)); - internal::TestStruct_Data* data = nullptr; - mojo::internal::Serialize<TestStructDataView>(input, result->buf.get(), &data, - &result->context); - - result->data_view.reset(new TestStructDataView(data, &result->context)); + auto result = std::make_unique<DataViewHolder>(); + result->message = Message(0, 0, 0, 0, nullptr); + internal::TestStruct_Data::BufferWriter writer; + mojo::internal::Serialize<TestStructDataView>( + input, result->message.payload_buffer(), &writer, &result->context); + result->data_view = + std::make_unique<TestStructDataView>(writer.data(), &result->context); return result; } @@ -94,21 +90,24 @@ TEST_F(DataViewTest, NestedStruct) { TEST_F(DataViewTest, NativeStruct) { TestStructPtr obj(TestStruct::New()); - obj->f_native_struct = NativeStruct::New(); + obj->f_native_struct = native::NativeStruct::New(); obj->f_native_struct->data = std::vector<uint8_t>({3, 2, 1}); auto data_view_holder = SerializeTestStruct(std::move(obj)); auto& data_view = *data_view_holder->data_view; - NativeStructDataView struct_data_view; + native::NativeStructDataView struct_data_view; data_view.GetFNativeStructDataView(&struct_data_view); - ASSERT_FALSE(struct_data_view.is_null()); - ASSERT_EQ(3u, struct_data_view.size()); - EXPECT_EQ(3, struct_data_view[0]); - EXPECT_EQ(2, struct_data_view[1]); - EXPECT_EQ(1, struct_data_view[2]); - EXPECT_EQ(3, *struct_data_view.data()); + ArrayDataView<uint8_t> data_data_view; + struct_data_view.GetDataDataView(&data_data_view); + + ASSERT_FALSE(data_data_view.is_null()); + ASSERT_EQ(3u, data_data_view.size()); + EXPECT_EQ(3, data_data_view[0]); + EXPECT_EQ(2, data_data_view[1]); + EXPECT_EQ(1, data_data_view[2]); + EXPECT_EQ(3, *data_data_view.data()); } TEST_F(DataViewTest, BoolArray) { @@ -166,11 +165,11 @@ TEST_F(DataViewTest, EnumArray) { } TEST_F(DataViewTest, InterfaceArray) { - TestInterfacePtr ptr; - TestInterfaceImpl impl(MakeRequest(&ptr)); + TestInterfacePtrInfo ptr_info; + TestInterfaceImpl impl(MakeRequest(&ptr_info)); TestStructPtr obj(TestStruct::New()); - obj->f_interface_array.push_back(std::move(ptr)); + obj->f_interface_array.push_back(std::move(ptr_info)); auto data_view_holder = SerializeTestStruct(std::move(obj)); auto& data_view = *data_view_holder->data_view; diff --git a/mojo/public/cpp/bindings/tests/e2e_perftest.cc b/mojo/public/cpp/bindings/tests/e2e_perftest.cc index bc69e0f727..037b3d5910 100644 --- a/mojo/public/cpp/bindings/tests/e2e_perftest.cc +++ b/mojo/public/cpp/bindings/tests/e2e_perftest.cc @@ -9,12 +9,13 @@ #include "base/callback.h" #include "base/memory/ptr_util.h" #include "base/message_loop/message_loop.h" +#include "base/message_loop/message_loop_current.h" #include "base/run_loop.h" #include "base/strings/stringprintf.h" #include "base/test/perf_time_logger.h" #include "base/threading/thread_task_runner_handle.h" -#include "mojo/edk/embedder/embedder.h" -#include "mojo/edk/test/mojo_test_base.h" +#include "mojo/core/embedder/embedder.h" +#include "mojo/core/test/mojo_test_base.h" #include "mojo/public/cpp/bindings/strong_binding.h" #include "mojo/public/interfaces/bindings/tests/ping_service.mojom.h" #include "testing/gtest/include/gtest/gtest.h" @@ -82,7 +83,7 @@ void PingPongTest::RunTest(int iterations, int batch_size, int message_size) { current_iterations_ = 0; calls_outstanding_ = 0; - base::MessageLoop::current()->SetNestableTasksAllowed(true); + base::MessageLoopCurrent::Get()->SetNestableTasksAllowed(true); base::RunLoop run_loop; quit_closure_ = run_loop.QuitClosure(); base::ThreadTaskRunnerHandle::Get()->PostTask( @@ -112,7 +113,7 @@ void PingPongTest::OnPingDone(const std::string& reply) { DoPing(); } -class MojoE2EPerftest : public edk::test::MojoTestBase { +class MojoE2EPerftest : public core::test::MojoTestBase { public: void RunTestOnTaskRunner(base::TaskRunner* runner, MojoHandle client_mp, @@ -122,8 +123,9 @@ class MojoE2EPerftest : public edk::test::MojoTestBase { } else { base::RunLoop run_loop; runner->PostTaskAndReply( - FROM_HERE, base::Bind(&MojoE2EPerftest::RunTests, - base::Unretained(this), client_mp, test_name), + FROM_HERE, + base::Bind(&MojoE2EPerftest::RunTests, base::Unretained(this), + client_mp, test_name), run_loop.QuitClosure()); run_loop.Run(); } @@ -161,17 +163,17 @@ class MojoE2EPerftest : public edk::test::MojoTestBase { void CreateAndRunService(InterfaceRequest<test::EchoService> request, const base::Closure& cb) { - MakeStrongBinding(base::MakeUnique<EchoServiceImpl>(cb), std::move(request)); + MakeStrongBinding(std::make_unique<EchoServiceImpl>(cb), std::move(request)); } DEFINE_TEST_CLIENT_TEST_WITH_PIPE(PingService, MojoE2EPerftest, mp) { MojoHandle service_mp; EXPECT_EQ("hello", ReadMessageWithHandles(mp, &service_mp, 1)); - InterfaceRequest<test::EchoService> request; - request.Bind(ScopedMessagePipeHandle(MessagePipeHandle(service_mp))); + auto request = InterfaceRequest<test::EchoService>( + ScopedMessagePipeHandle(MessagePipeHandle(service_mp))); base::RunLoop run_loop; - edk::GetIOTaskRunner()->PostTask( + core::GetIOTaskRunner()->PostTask( FROM_HERE, base::Bind(&CreateAndRunService, base::Passed(&request), base::Bind(base::IgnoreResult(&base::TaskRunner::PostTask), @@ -181,23 +183,23 @@ DEFINE_TEST_CLIENT_TEST_WITH_PIPE(PingService, MojoE2EPerftest, mp) { } TEST_F(MojoE2EPerftest, MultiProcessEchoMainThread) { - RUN_CHILD_ON_PIPE(PingService, mp) + RunTestClient("PingService", [&](MojoHandle mp) { MojoHandle client_mp, service_mp; CreateMessagePipe(&client_mp, &service_mp); WriteMessageWithHandles(mp, "hello", &service_mp, 1); RunTestOnTaskRunner(message_loop_.task_runner().get(), client_mp, "MultiProcessEchoMainThread"); - END_CHILD() + }); } TEST_F(MojoE2EPerftest, MultiProcessEchoIoThread) { - RUN_CHILD_ON_PIPE(PingService, mp) + RunTestClient("PingService", [&](MojoHandle mp) { MojoHandle client_mp, service_mp; CreateMessagePipe(&client_mp, &service_mp); WriteMessageWithHandles(mp, "hello", &service_mp, 1); - RunTestOnTaskRunner(edk::GetIOTaskRunner().get(), client_mp, + RunTestOnTaskRunner(core::GetIOTaskRunner().get(), client_mp, "MultiProcessEchoIoThread"); - END_CHILD() + }); } } // namespace diff --git a/mojo/public/cpp/bindings/tests/handle_passing_unittest.cc b/mojo/public/cpp/bindings/tests/handle_passing_unittest.cc index ef977af935..509b0feb71 100644 --- a/mojo/public/cpp/bindings/tests/handle_passing_unittest.cc +++ b/mojo/public/cpp/bindings/tests/handle_passing_unittest.cc @@ -6,10 +6,10 @@ #include <utility> #include "base/memory/ptr_util.h" -#include "base/message_loop/message_loop.h" #include "base/run_loop.h" #include "mojo/public/cpp/bindings/binding.h" #include "mojo/public/cpp/bindings/strong_binding.h" +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" #include "mojo/public/cpp/system/wait.h" #include "mojo/public/cpp/test_support/test_utils.h" #include "mojo/public/interfaces/bindings/tests/sample_factory.mojom.h" @@ -99,8 +99,10 @@ class SampleFactoryImpl : public sample::Factory { sample::ResponsePtr response(sample::Response::New(2, std::move(pipe0))); callback.Run(std::move(response), text1); - if (request->obj) - request->obj->DoSomething(); + if (request->obj) { + imported::ImportedInterfacePtr proxy(std::move(request->obj)); + proxy->DoSomething(); + } } void DoStuff2(ScopedDataPipeConsumerHandle pipe, @@ -115,15 +117,12 @@ class SampleFactoryImpl : public sample::Factory { mojo::Wait(pipe.get(), MOJO_HANDLE_SIGNAL_READABLE, &state)); ASSERT_TRUE(state.satisfied_signals & MOJO_HANDLE_SIGNAL_READABLE); ASSERT_EQ(MOJO_RESULT_OK, - ReadDataRaw( - pipe.get(), nullptr, &data_size, MOJO_READ_DATA_FLAG_QUERY)); + pipe->ReadData(nullptr, &data_size, MOJO_READ_DATA_FLAG_QUERY)); ASSERT_NE(0, static_cast<int>(data_size)); char data[64]; ASSERT_LT(static_cast<int>(data_size), 64); - ASSERT_EQ( - MOJO_RESULT_OK, - ReadDataRaw( - pipe.get(), data, &data_size, MOJO_READ_DATA_FLAG_ALL_OR_NONE)); + ASSERT_EQ(MOJO_RESULT_OK, pipe->ReadData(data, &data_size, + MOJO_READ_DATA_FLAG_ALL_OR_NONE)); callback.Run(data); } @@ -131,7 +130,7 @@ class SampleFactoryImpl : public sample::Factory { void CreateNamedObject( InterfaceRequest<sample::NamedObject> object_request) override { EXPECT_TRUE(object_request.is_pending()); - MakeStrongBinding(base::MakeUnique<SampleNamedObjectImpl>(), + MakeStrongBinding(std::make_unique<SampleNamedObjectImpl>(), std::move(object_request)); } @@ -150,7 +149,7 @@ class SampleFactoryImpl : public sample::Factory { Binding<sample::Factory> binding_; }; -class HandlePassingTest : public testing::Test { +class HandlePassingTest : public BindingsTestBase { public: HandlePassingTest() {} @@ -158,8 +157,6 @@ class HandlePassingTest : public testing::Test { void PumpMessages() { base::RunLoop().RunUntilIdle(); } - private: - base::MessageLoop loop_; }; void DoStuff(bool* got_response, @@ -197,7 +194,7 @@ void DoStuff2(bool* got_response, closure.Run(); } -TEST_F(HandlePassingTest, Basic) { +TEST_P(HandlePassingTest, Basic) { sample::FactoryPtr factory; SampleFactoryImpl factory_impl(MakeRequest(&factory)); @@ -207,7 +204,7 @@ TEST_F(HandlePassingTest, Basic) { MessagePipe pipe1; EXPECT_TRUE(WriteTextMessage(pipe1.handle1.get(), kText2)); - imported::ImportedInterfacePtr imported; + imported::ImportedInterfacePtrInfo imported; base::RunLoop run_loop; ImportedInterfaceImpl imported_impl(MakeRequest(&imported), run_loop.QuitClosure()); @@ -232,13 +229,13 @@ TEST_F(HandlePassingTest, Basic) { EXPECT_EQ(1, ImportedInterfaceImpl::do_something_count() - count_before); } -TEST_F(HandlePassingTest, PassInvalid) { +TEST_P(HandlePassingTest, PassInvalid) { sample::FactoryPtr factory; SampleFactoryImpl factory_impl(MakeRequest(&factory)); - sample::RequestPtr request( - sample::Request::New(1, ScopedMessagePipeHandle(), base::nullopt, - imported::ImportedInterfacePtr())); + sample::RequestPtr request(sample::Request::New(1, ScopedMessagePipeHandle(), + base::nullopt, nullptr)); + bool got_response = false; std::string got_text_reply; base::RunLoop run_loop; @@ -254,7 +251,7 @@ TEST_F(HandlePassingTest, PassInvalid) { } // Verifies DataPipeConsumer can be passed and read from. -TEST_F(HandlePassingTest, DataPipe) { +TEST_P(HandlePassingTest, DataPipe) { sample::FactoryPtr factory; SampleFactoryImpl factory_impl(MakeRequest(&factory)); @@ -263,8 +260,7 @@ TEST_F(HandlePassingTest, DataPipe) { ScopedDataPipeProducerHandle producer_handle; ScopedDataPipeConsumerHandle consumer_handle; MojoCreateDataPipeOptions options = {sizeof(MojoCreateDataPipeOptions), - MOJO_CREATE_DATA_PIPE_OPTIONS_FLAG_NONE, - 1, + MOJO_CREATE_DATA_PIPE_FLAG_NONE, 1, 1024}; ASSERT_EQ(MOJO_RESULT_OK, CreateDataPipe(&options, &producer_handle, &consumer_handle)); @@ -272,10 +268,8 @@ TEST_F(HandlePassingTest, DataPipe) { // +1 for \0. uint32_t data_size = static_cast<uint32_t>(expected_text_reply.size() + 1); ASSERT_EQ(MOJO_RESULT_OK, - WriteDataRaw(producer_handle.get(), - expected_text_reply.c_str(), - &data_size, - MOJO_WRITE_DATA_FLAG_ALL_OR_NONE)); + producer_handle->WriteData(expected_text_reply.c_str(), &data_size, + MOJO_WRITE_DATA_FLAG_ALL_OR_NONE)); bool got_response = false; std::string got_text_reply; @@ -292,40 +286,14 @@ TEST_F(HandlePassingTest, DataPipe) { EXPECT_EQ(expected_text_reply, got_text_reply); } -TEST_F(HandlePassingTest, PipesAreClosed) { - sample::FactoryPtr factory; - SampleFactoryImpl factory_impl(MakeRequest(&factory)); - - MessagePipe extra_pipe; - - MojoHandle handle0_value = extra_pipe.handle0.get().value(); - MojoHandle handle1_value = extra_pipe.handle1.get().value(); - - { - std::vector<ScopedMessagePipeHandle> pipes(2); - pipes[0] = std::move(extra_pipe.handle0); - pipes[1] = std::move(extra_pipe.handle1); - - sample::RequestPtr request(sample::Request::New()); - request->more_pipes = std::move(pipes); - - factory->DoStuff(std::move(request), ScopedMessagePipeHandle(), - sample::Factory::DoStuffCallback()); - } - - // We expect the pipes to have been closed. - EXPECT_EQ(MOJO_RESULT_INVALID_ARGUMENT, MojoClose(handle0_value)); - EXPECT_EQ(MOJO_RESULT_INVALID_ARGUMENT, MojoClose(handle1_value)); -} - -TEST_F(HandlePassingTest, CreateNamedObject) { +TEST_P(HandlePassingTest, CreateNamedObject) { sample::FactoryPtr factory; SampleFactoryImpl factory_impl(MakeRequest(&factory)); sample::NamedObjectPtr object1; EXPECT_FALSE(object1); - InterfaceRequest<sample::NamedObject> object1_request(&object1); + auto object1_request = mojo::MakeRequest(&object1); EXPECT_TRUE(object1_request.is_pending()); factory->CreateNamedObject(std::move(object1_request)); EXPECT_FALSE(object1_request.is_pending()); // We've passed the request. @@ -351,6 +319,8 @@ TEST_F(HandlePassingTest, CreateNamedObject) { EXPECT_EQ(std::string("object2"), name2); } +INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(HandlePassingTest); + } // namespace } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/hash_unittest.cc b/mojo/public/cpp/bindings/tests/hash_unittest.cc index 9ce1f5bc7b..3a5bf4791c 100644 --- a/mojo/public/cpp/bindings/tests/hash_unittest.cc +++ b/mojo/public/cpp/bindings/tests/hash_unittest.cc @@ -22,14 +22,6 @@ TEST_F(HashTest, NestedStruct) { SimpleNestedStruct::New(ContainsOther::New(1)))); } -TEST_F(HashTest, UnmappedNativeStruct) { - // Just check that this template instantiation compiles. - ASSERT_EQ(::mojo::internal::Hash(::mojo::internal::kHashSeed, - UnmappedNativeStruct::New()), - ::mojo::internal::Hash(::mojo::internal::kHashSeed, - UnmappedNativeStruct::New())); -} - } // namespace } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/interface_ptr_unittest.cc b/mojo/public/cpp/bindings/tests/interface_ptr_unittest.cc index 431a844250..550a5aff56 100644 --- a/mojo/public/cpp/bindings/tests/interface_ptr_unittest.cc +++ b/mojo/public/cpp/bindings/tests/interface_ptr_unittest.cc @@ -1,3 +1,4 @@ + // Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -6,14 +7,21 @@ #include <utility> #include "base/bind.h" +#include "base/bind_helpers.h" #include "base/callback.h" +#include "base/callback_helpers.h" #include "base/memory/ptr_util.h" #include "base/message_loop/message_loop.h" #include "base/run_loop.h" +#include "base/sequenced_task_runner.h" +#include "base/task_scheduler/post_task.h" +#include "base/test/scoped_task_environment.h" #include "base/threading/sequenced_task_runner_handle.h" #include "base/threading/thread.h" +#include "base/threading/thread_task_runner_handle.h" #include "mojo/public/cpp/bindings/binding.h" #include "mojo/public/cpp/bindings/strong_binding.h" +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" #include "mojo/public/cpp/bindings/thread_safe_interface_ptr.h" #include "mojo/public/interfaces/bindings/tests/math_calculator.mojom.h" #include "mojo/public/interfaces/bindings/tests/sample_interfaces.mojom.h" @@ -192,15 +200,12 @@ class IntegerAccessorImpl : public sample::IntegerAccessor { base::Closure closure_; }; -class InterfacePtrTest : public testing::Test { +class InterfacePtrTest : public BindingsTestBase { public: InterfacePtrTest() {} ~InterfacePtrTest() override { base::RunLoop().RunUntilIdle(); } void PumpMessages() { base::RunLoop().RunUntilIdle(); } - - private: - base::MessageLoop loop_; }; void SetFlagAndRunClosure(bool* flag, const base::Closure& closure) { @@ -219,54 +224,62 @@ void ExpectValueAndRunClosure(uint32_t expected_value, closure.Run(); } -TEST_F(InterfacePtrTest, IsBound) { +TEST_P(InterfacePtrTest, IsBound) { math::CalculatorPtr calc; EXPECT_FALSE(calc.is_bound()); MathCalculatorImpl calc_impl(MakeRequest(&calc)); EXPECT_TRUE(calc.is_bound()); } -TEST_F(InterfacePtrTest, EndToEnd) { - math::CalculatorPtr calc; - MathCalculatorImpl calc_impl(MakeRequest(&calc)); - - // Suppose this is instantiated in a process that has pipe1_. - MathCalculatorUI calculator_ui(std::move(calc)); - - base::RunLoop run_loop, run_loop2; - calculator_ui.Add(2.0, run_loop.QuitClosure()); - calculator_ui.Multiply(5.0, run_loop2.QuitClosure()); - run_loop.Run(); - run_loop2.Run(); +class EndToEndInterfacePtrTest : public InterfacePtrTest { + public: + void RunTest(const scoped_refptr<base::SequencedTaskRunner> runner) { + base::RunLoop run_loop; + done_closure_ = run_loop.QuitClosure(); + done_runner_ = base::ThreadTaskRunnerHandle::Get(); + runner->PostTask(FROM_HERE, + base::Bind(&EndToEndInterfacePtrTest::RunTestImpl, + base::Unretained(this))); + run_loop.Run(); + } - EXPECT_EQ(10.0, calculator_ui.GetOutput()); -} + private: + void RunTestImpl() { + math::CalculatorPtr calc; + calc_impl_ = std::make_unique<MathCalculatorImpl>(MakeRequest(&calc)); + calculator_ui_ = std::make_unique<MathCalculatorUI>(std::move(calc)); + calculator_ui_->Add(2.0, base::Bind(&EndToEndInterfacePtrTest::AddDone, + base::Unretained(this))); + calculator_ui_->Multiply(5.0, + base::Bind(&EndToEndInterfacePtrTest::MultiplyDone, + base::Unretained(this))); + EXPECT_EQ(0.0, calculator_ui_->GetOutput()); + } -TEST_F(InterfacePtrTest, EndToEnd_Synchronous) { - math::CalculatorPtr calc; - MathCalculatorImpl calc_impl(MakeRequest(&calc)); + void AddDone() { EXPECT_EQ(2.0, calculator_ui_->GetOutput()); } - // Suppose this is instantiated in a process that has pipe1_. - MathCalculatorUI calculator_ui(std::move(calc)); + void MultiplyDone() { + EXPECT_EQ(10.0, calculator_ui_->GetOutput()); + calculator_ui_.reset(); + calc_impl_.reset(); + done_runner_->PostTask(FROM_HERE, base::ResetAndReturn(&done_closure_)); + } - EXPECT_EQ(0.0, calculator_ui.GetOutput()); + base::Closure done_closure_; + scoped_refptr<base::SingleThreadTaskRunner> done_runner_; + std::unique_ptr<MathCalculatorUI> calculator_ui_; + std::unique_ptr<MathCalculatorImpl> calc_impl_; +}; - base::RunLoop run_loop; - calculator_ui.Add(2.0, run_loop.QuitClosure()); - EXPECT_EQ(0.0, calculator_ui.GetOutput()); - calc_impl.binding()->WaitForIncomingMethodCall(); - run_loop.Run(); - EXPECT_EQ(2.0, calculator_ui.GetOutput()); +TEST_P(EndToEndInterfacePtrTest, EndToEnd) { + RunTest(base::ThreadTaskRunnerHandle::Get()); +} - base::RunLoop run_loop2; - calculator_ui.Multiply(5.0, run_loop2.QuitClosure()); - EXPECT_EQ(2.0, calculator_ui.GetOutput()); - calc_impl.binding()->WaitForIncomingMethodCall(); - run_loop2.Run(); - EXPECT_EQ(10.0, calculator_ui.GetOutput()); +TEST_P(EndToEndInterfacePtrTest, EndToEndOnSequence) { + RunTest(base::CreateSequencedTaskRunnerWithTraits({})); } -TEST_F(InterfacePtrTest, Movable) { +TEST_P(InterfacePtrTest, Movable) { math::CalculatorPtr a; math::CalculatorPtr b; MathCalculatorImpl calc_impl(MakeRequest(&b)); @@ -280,7 +293,7 @@ TEST_F(InterfacePtrTest, Movable) { EXPECT_TRUE(!b); } -TEST_F(InterfacePtrTest, Resettable) { +TEST_P(InterfacePtrTest, Resettable) { math::CalculatorPtr a; EXPECT_TRUE(!a); @@ -304,7 +317,7 @@ TEST_F(InterfacePtrTest, Resettable) { EXPECT_EQ(MOJO_RESULT_INVALID_ARGUMENT, CloseRaw(handle)); } -TEST_F(InterfacePtrTest, BindInvalidHandle) { +TEST_P(InterfacePtrTest, BindInvalidHandle) { math::CalculatorPtr ptr; EXPECT_FALSE(ptr.get()); EXPECT_FALSE(ptr); @@ -314,7 +327,7 @@ TEST_F(InterfacePtrTest, BindInvalidHandle) { EXPECT_FALSE(ptr); } -TEST_F(InterfacePtrTest, EncounteredError) { +TEST_P(InterfacePtrTest, EncounteredError) { math::CalculatorPtr proxy; MathCalculatorImpl calc_impl(MakeRequest(&proxy)); @@ -343,7 +356,7 @@ TEST_F(InterfacePtrTest, EncounteredError) { EXPECT_TRUE(calculator_ui.encountered_error()); } -TEST_F(InterfacePtrTest, EncounteredErrorCallback) { +TEST_P(InterfacePtrTest, EncounteredErrorCallback) { math::CalculatorPtr proxy; MathCalculatorImpl calc_impl(MakeRequest(&proxy)); @@ -380,7 +393,7 @@ TEST_F(InterfacePtrTest, EncounteredErrorCallback) { EXPECT_TRUE(encountered_error); } -TEST_F(InterfacePtrTest, DestroyInterfacePtrOnMethodResponse) { +TEST_P(InterfacePtrTest, DestroyInterfacePtrOnMethodResponse) { math::CalculatorPtr proxy; MathCalculatorImpl calc_impl(MakeRequest(&proxy)); @@ -395,7 +408,7 @@ TEST_F(InterfacePtrTest, DestroyInterfacePtrOnMethodResponse) { EXPECT_EQ(0, SelfDestructingMathCalculatorUI::num_instances()); } -TEST_F(InterfacePtrTest, NestedDestroyInterfacePtrOnMethodResponse) { +TEST_P(InterfacePtrTest, NestedDestroyInterfacePtrOnMethodResponse) { math::CalculatorPtr proxy; MathCalculatorImpl calc_impl(MakeRequest(&proxy)); @@ -410,7 +423,7 @@ TEST_F(InterfacePtrTest, NestedDestroyInterfacePtrOnMethodResponse) { EXPECT_EQ(0, SelfDestructingMathCalculatorUI::num_instances()); } -TEST_F(InterfacePtrTest, ReentrantWaitForIncomingMethodCall) { +TEST_P(InterfacePtrTest, ReentrantWaitForIncomingMethodCall) { sample::ServicePtr proxy; ReentrantServiceImpl impl(MakeRequest(&proxy)); @@ -428,7 +441,7 @@ TEST_F(InterfacePtrTest, ReentrantWaitForIncomingMethodCall) { EXPECT_EQ(2, impl.max_call_depth()); } -TEST_F(InterfacePtrTest, QueryVersion) { +TEST_P(InterfacePtrTest, QueryVersion) { IntegerAccessorImpl impl; sample::IntegerAccessorPtr ptr; Binding<sample::IntegerAccessor> binding(&impl, MakeRequest(&ptr)); @@ -443,7 +456,7 @@ TEST_F(InterfacePtrTest, QueryVersion) { EXPECT_EQ(3u, ptr.version()); } -TEST_F(InterfacePtrTest, RequireVersion) { +TEST_P(InterfacePtrTest, RequireVersion) { IntegerAccessorImpl impl; sample::IntegerAccessorPtr ptr; Binding<sample::IntegerAccessor> binding(&impl, MakeRequest(&ptr)); @@ -513,7 +526,7 @@ TEST(StrongConnectorTest, Math) { base::RunLoop run_loop; auto binding = - MakeStrongBinding(base::MakeUnique<StrongMathCalculatorImpl>(&destroyed), + MakeStrongBinding(std::make_unique<StrongMathCalculatorImpl>(&destroyed), MakeRequest(&calc)); binding->set_connection_error_handler(base::Bind( &SetFlagAndRunClosure, &error_received, run_loop.QuitClosure())); @@ -544,14 +557,14 @@ TEST(StrongConnectorTest, Math) { class WeakMathCalculatorImpl : public math::Calculator { public: - WeakMathCalculatorImpl(ScopedMessagePipeHandle handle, + WeakMathCalculatorImpl(math::CalculatorRequest request, bool* error_received, bool* destroyed, const base::Closure& closure) : error_received_(error_received), destroyed_(destroyed), closure_(closure), - binding_(this, std::move(handle)) { + binding_(this, std::move(request)) { binding_.set_connection_error_handler( base::Bind(&SetFlagAndRunClosure, error_received_, closure_)); } @@ -585,8 +598,9 @@ TEST(WeakConnectorTest, Math) { bool destroyed = false; MessagePipe pipe; base::RunLoop run_loop; - WeakMathCalculatorImpl impl(std::move(pipe.handle0), &error_received, - &destroyed, run_loop.QuitClosure()); + WeakMathCalculatorImpl impl(math::CalculatorRequest(std::move(pipe.handle0)), + &error_received, &destroyed, + run_loop.QuitClosure()); math::CalculatorPtr calc; calc.Bind(InterfacePtrInfo<math::Calculator>(std::move(pipe.handle1), 0u)); @@ -639,7 +653,7 @@ class BImpl : public B { private: void GetC(InterfaceRequest<C> c) override { - MakeStrongBinding(base::MakeUnique<CImpl>(d_called_, closure_), + MakeStrongBinding(std::make_unique<CImpl>(d_called_, closure_), std::move(c)); } @@ -658,7 +672,7 @@ class AImpl : public A { private: void GetB(InterfaceRequest<B> b) override { - MakeStrongBinding(base::MakeUnique<BImpl>(&d_called_, closure_), + MakeStrongBinding(std::make_unique<BImpl>(&d_called_, closure_), std::move(b)); } @@ -667,7 +681,7 @@ class AImpl : public A { base::Closure closure_; }; -TEST_F(InterfacePtrTest, Scoping) { +TEST_P(InterfacePtrTest, Scoping) { APtr a; base::RunLoop run_loop; AImpl a_impl(MakeRequest(&a), run_loop.QuitClosure()); @@ -703,16 +717,13 @@ class PingTestImpl : public sample::PingTest { }; // Tests that FuseProxy does what it's supposed to do. -TEST_F(InterfacePtrTest, Fusion) { - sample::PingTestPtr proxy; - PingTestImpl impl(MakeRequest(&proxy)); +TEST_P(InterfacePtrTest, Fusion) { + sample::PingTestPtrInfo proxy_info; + PingTestImpl impl(MakeRequest(&proxy_info)); - // Create another PingTest pipe. + // Create another PingTest pipe and fuse it to the one hanging off |impl|. sample::PingTestPtr ptr; - sample::PingTestRequest request(&ptr); - - // Fuse the new pipe to the one hanging off |impl|. - EXPECT_TRUE(FuseInterface(std::move(request), proxy.PassInterface())); + EXPECT_TRUE(FuseInterface(mojo::MakeRequest(&ptr), std::move(proxy_info))); // Ping! bool called = false; @@ -726,18 +737,18 @@ void Fail() { FAIL() << "Unexpected connection error"; } -TEST_F(InterfacePtrTest, FlushForTesting) { +TEST_P(InterfacePtrTest, FlushForTesting) { math::CalculatorPtr calc; MathCalculatorImpl calc_impl(MakeRequest(&calc)); calc.set_connection_error_handler(base::Bind(&Fail)); MathCalculatorUI calculator_ui(std::move(calc)); - calculator_ui.Add(2.0, base::Bind(&base::DoNothing)); + calculator_ui.Add(2.0, base::DoNothing()); calculator_ui.GetInterfacePtr().FlushForTesting(); EXPECT_EQ(2.0, calculator_ui.GetOutput()); - calculator_ui.Multiply(5.0, base::Bind(&base::DoNothing)); + calculator_ui.Multiply(5.0, base::DoNothing()); calculator_ui.GetInterfacePtr().FlushForTesting(); EXPECT_EQ(10.0, calculator_ui.GetOutput()); @@ -747,7 +758,7 @@ void SetBool(bool* value) { *value = true; } -TEST_F(InterfacePtrTest, FlushForTestingWithClosedPeer) { +TEST_P(InterfacePtrTest, FlushForTestingWithClosedPeer) { math::CalculatorPtr calc; MakeRequest(&calc); bool called = false; @@ -757,7 +768,7 @@ TEST_F(InterfacePtrTest, FlushForTestingWithClosedPeer) { calc.FlushForTesting(); } -TEST_F(InterfacePtrTest, ConnectionErrorWithReason) { +TEST_P(InterfacePtrTest, ConnectionErrorWithReason) { math::CalculatorPtr calc; MathCalculatorImpl calc_impl(MakeRequest(&calc)); @@ -776,7 +787,7 @@ TEST_F(InterfacePtrTest, ConnectionErrorWithReason) { run_loop.Run(); } -TEST_F(InterfacePtrTest, InterfaceRequestResetWithReason) { +TEST_P(InterfacePtrTest, InterfaceRequestResetWithReason) { math::CalculatorPtr calc; auto request = MakeRequest(&calc); @@ -795,17 +806,17 @@ TEST_F(InterfacePtrTest, InterfaceRequestResetWithReason) { run_loop.Run(); } -TEST_F(InterfacePtrTest, CallbackIsPassedInterfacePtr) { +TEST_P(InterfacePtrTest, CallbackIsPassedInterfacePtr) { sample::PingTestPtr ptr; - sample::PingTestRequest request(&ptr); + auto request = mojo::MakeRequest(&ptr); base::RunLoop run_loop; // Make a call with the proxy's lifetime bound to the response callback. sample::PingTest* raw_proxy = ptr.get(); ptr.set_connection_error_handler(run_loop.QuitClosure()); - raw_proxy->Ping( - base::Bind([](sample::PingTestPtr ptr) {}, base::Passed(&ptr))); + raw_proxy->Ping(base::Bind(base::DoNothing::Repeatedly<sample::PingTestPtr>(), + base::Passed(&ptr))); // Trigger an error on |ptr|. This will ultimately lead to the proxy's // response callbacks being destroyed, which will in turn lead to the proxy @@ -814,9 +825,9 @@ TEST_F(InterfacePtrTest, CallbackIsPassedInterfacePtr) { run_loop.Run(); } -TEST_F(InterfacePtrTest, ConnectionErrorHandlerOwnsInterfacePtr) { +TEST_P(InterfacePtrTest, ConnectionErrorHandlerOwnsInterfacePtr) { sample::PingTestPtr* ptr = new sample::PingTestPtr; - sample::PingTestRequest request(ptr); + auto request = mojo::MakeRequest(ptr); base::RunLoop run_loop; @@ -836,7 +847,7 @@ TEST_F(InterfacePtrTest, ConnectionErrorHandlerOwnsInterfacePtr) { run_loop.Run(); } -TEST_F(InterfacePtrTest, ThreadSafeInterfacePointer) { +TEST_P(InterfacePtrTest, ThreadSafeInterfacePointer) { math::CalculatorPtr ptr; MathCalculatorImpl calc_impl(MakeRequest(&ptr)); scoped_refptr<math::ThreadSafeCalculatorPtr> thread_safe_ptr = @@ -844,10 +855,6 @@ TEST_F(InterfacePtrTest, ThreadSafeInterfacePointer) { base::RunLoop run_loop; - // Create and start the thread from where we'll call the interface pointer. - base::Thread other_thread("service test thread"); - other_thread.Start(); - auto run_method = base::Bind( [](const scoped_refptr<base::TaskRunner>& main_task_runner, const base::Closure& quit_closure, @@ -855,35 +862,35 @@ TEST_F(InterfacePtrTest, ThreadSafeInterfacePointer) { auto calc_callback = base::Bind( [](const scoped_refptr<base::TaskRunner>& main_task_runner, const base::Closure& quit_closure, - base::PlatformThreadId thread_id, + scoped_refptr<base::SequencedTaskRunner> sender_sequence_runner, double result) { EXPECT_EQ(123, result); - // Validate the callback is invoked on the calling thread. - EXPECT_EQ(thread_id, base::PlatformThread::CurrentId()); + // Validate the callback is invoked on the calling sequence. + EXPECT_TRUE(sender_sequence_runner->RunsTasksInCurrentSequence()); // Notify the run_loop to quit. main_task_runner->PostTask(FROM_HERE, quit_closure); }); - (*thread_safe_ptr)->Add( - 123, base::Bind(calc_callback, main_task_runner, quit_closure, - base::PlatformThread::CurrentId())); + scoped_refptr<base::SequencedTaskRunner> current_sequence_runner = + base::SequencedTaskRunnerHandle::Get(); + (*thread_safe_ptr) + ->Add(123, base::Bind(calc_callback, main_task_runner, quit_closure, + current_sequence_runner)); }, base::SequencedTaskRunnerHandle::Get(), run_loop.QuitClosure(), thread_safe_ptr); - other_thread.message_loop()->task_runner()->PostTask(FROM_HERE, run_method); + base::CreateSequencedTaskRunnerWithTraits({})->PostTask(FROM_HERE, + run_method); // Block until the method callback is called on the background thread. run_loop.Run(); } -TEST_F(InterfacePtrTest, ThreadSafeInterfacePointerWithTaskRunner) { - // Create and start the thread from where we'll bind the interface pointer. - base::Thread other_thread("service test thread"); - other_thread.Start(); - const scoped_refptr<base::SingleThreadTaskRunner>& other_thread_task_runner = - other_thread.message_loop()->task_runner(); +TEST_P(InterfacePtrTest, ThreadSafeInterfacePointerWithTaskRunner) { + const scoped_refptr<base::SequencedTaskRunner> other_thread_task_runner = + base::CreateSequencedTaskRunnerWithTraits({}); math::CalculatorPtr ptr; - math::CalculatorRequest request(&ptr); + auto request = mojo::MakeRequest(&ptr); // Create a ThreadSafeInterfacePtr that we'll bind from a different thread. scoped_refptr<math::ThreadSafeCalculatorPtr> thread_safe_ptr = @@ -907,7 +914,7 @@ TEST_F(InterfacePtrTest, ThreadSafeInterfacePointerWithTaskRunner) { }, base::SequencedTaskRunnerHandle::Get(), run_loop.QuitClosure(), thread_safe_ptr, base::Passed(&request), &math_calc_impl); - other_thread.message_loop()->task_runner()->PostTask(FROM_HERE, run_method); + other_thread_task_runner->PostTask(FROM_HERE, run_method); run_loop.Run(); } @@ -932,6 +939,8 @@ TEST_F(InterfacePtrTest, ThreadSafeInterfacePointerWithTaskRunner) { thread_safe_ptr = nullptr; } +INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(InterfacePtrTest); + } // namespace } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/lazy_serialization_unittest.cc b/mojo/public/cpp/bindings/tests/lazy_serialization_unittest.cc new file mode 100644 index 0000000000..81f8419fab --- /dev/null +++ b/mojo/public/cpp/bindings/tests/lazy_serialization_unittest.cc @@ -0,0 +1,166 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/macros.h" +#include "base/run_loop.h" +#include "base/test/bind_test_util.h" +#include "base/test/scoped_task_environment.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" +#include "mojo/public/interfaces/bindings/tests/struct_with_traits.mojom.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace mojo { +namespace { + +class LazySerializationTest : public testing::Test { + public: + LazySerializationTest() {} + ~LazySerializationTest() override {} + + private: + base::test::ScopedTaskEnvironment task_environment_; + + DISALLOW_COPY_AND_ASSIGN(LazySerializationTest); +}; + +class TestUnserializedStructImpl : public test::TestUnserializedStruct { + public: + explicit TestUnserializedStructImpl( + test::TestUnserializedStructRequest request) + : binding_(this, std::move(request)) {} + ~TestUnserializedStructImpl() override {} + + // test::TestUnserializedStruct: + void PassUnserializedStruct( + const test::StructWithUnreachableTraitsImpl& s, + const PassUnserializedStructCallback& callback) override { + callback.Run(s); + } + + private: + mojo::Binding<test::TestUnserializedStruct> binding_; + + DISALLOW_COPY_AND_ASSIGN(TestUnserializedStructImpl); +}; + +class ForceSerializeTesterImpl : public test::ForceSerializeTester { + public: + ForceSerializeTesterImpl(test::ForceSerializeTesterRequest request) + : binding_(this, std::move(request)) {} + ~ForceSerializeTesterImpl() override = default; + + // test::ForceSerializeTester: + void SendForceSerializedStruct( + const test::StructForceSerializeImpl& s, + const SendForceSerializedStructCallback& callback) override { + callback.Run(s); + } + + void SendNestedForceSerializedStruct( + const test::StructNestedForceSerializeImpl& s, + const SendNestedForceSerializedStructCallback& callback) override { + callback.Run(s); + } + + private: + Binding<test::ForceSerializeTester> binding_; + + DISALLOW_COPY_AND_ASSIGN(ForceSerializeTesterImpl); +}; + +TEST_F(LazySerializationTest, NeverSerialize) { + // Basic sanity check to ensure that no messages are serialized by default in + // environments where lazy serialization is supported, on an interface which + // supports lazy serialization, and where both ends of the interface are in + // the same process. + + test::TestUnserializedStructPtr ptr; + TestUnserializedStructImpl impl(MakeRequest(&ptr)); + + const int32_t kTestMagicNumber = 42; + + test::StructWithUnreachableTraitsImpl data; + EXPECT_EQ(0, data.magic_number); + data.magic_number = kTestMagicNumber; + + // Send our data over the pipe and wait for it to come back. The value should + // be preserved. We know the data was never serialized because the + // StructTraits for this type will DCHECK if executed in any capacity. + int received_number = 0; + base::RunLoop loop; + ptr->PassUnserializedStruct( + data, base::Bind( + [](base::RunLoop* loop, int* received_number, + const test::StructWithUnreachableTraitsImpl& passed) { + *received_number = passed.magic_number; + loop->Quit(); + }, + &loop, &received_number)); + loop.Run(); + EXPECT_EQ(kTestMagicNumber, received_number); +} + +TEST_F(LazySerializationTest, ForceSerialize) { + // Verifies that the [force_serialize] attribute works as intended: i.e., even + // with lazy serialization enabled, messages which carry a force-serialized + // type will always serialize at call time. + + test::ForceSerializeTesterPtr tester; + ForceSerializeTesterImpl impl(mojo::MakeRequest(&tester)); + + constexpr int32_t kTestValue = 42; + + base::RunLoop loop; + test::StructForceSerializeImpl in; + in.set_value(kTestValue); + EXPECT_FALSE(in.was_serialized()); + EXPECT_FALSE(in.was_deserialized()); + tester->SendForceSerializedStruct( + in, base::BindLambdaForTesting( + [&](const test::StructForceSerializeImpl& passed) { + EXPECT_EQ(kTestValue, passed.value()); + EXPECT_TRUE(passed.was_deserialized()); + EXPECT_FALSE(passed.was_serialized()); + loop.Quit(); + })); + EXPECT_TRUE(in.was_serialized()); + EXPECT_FALSE(in.was_deserialized()); + loop.Run(); + EXPECT_TRUE(in.was_serialized()); + EXPECT_FALSE(in.was_deserialized()); +} + +TEST_F(LazySerializationTest, ForceSerializeNested) { + // Verifies that the [force_serialize] attribute works as intended in a nested + // context, i.e. when a force-serialized type is contained within a + // non-force-serialized type, + + test::ForceSerializeTesterPtr tester; + ForceSerializeTesterImpl impl(mojo::MakeRequest(&tester)); + + constexpr int32_t kTestValue = 42; + + base::RunLoop loop; + test::StructNestedForceSerializeImpl in; + in.force().set_value(kTestValue); + EXPECT_FALSE(in.was_serialized()); + EXPECT_FALSE(in.was_deserialized()); + tester->SendNestedForceSerializedStruct( + in, base::BindLambdaForTesting( + [&](const test::StructNestedForceSerializeImpl& passed) { + EXPECT_EQ(kTestValue, passed.force().value()); + EXPECT_TRUE(passed.was_deserialized()); + EXPECT_FALSE(passed.was_serialized()); + loop.Quit(); + })); + EXPECT_TRUE(in.was_serialized()); + EXPECT_FALSE(in.was_deserialized()); + loop.Run(); + EXPECT_TRUE(in.was_serialized()); + EXPECT_FALSE(in.was_deserialized()); +} + +} // namespace +} // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/map_unittest.cc b/mojo/public/cpp/bindings/tests/map_unittest.cc index 8d630a5862..104f3fd37c 100644 --- a/mojo/public/cpp/bindings/tests/map_unittest.cc +++ b/mojo/public/cpp/bindings/tests/map_unittest.cc @@ -4,9 +4,9 @@ #include <stddef.h> #include <stdint.h> -#include <unordered_map> #include <utility> +#include "base/containers/flat_map.h" #include "mojo/public/cpp/bindings/tests/rect_chromium.h" #include "mojo/public/interfaces/bindings/tests/rect.mojom.h" #include "mojo/public/interfaces/bindings/tests/test_structs.mojom.h" @@ -17,7 +17,7 @@ namespace test { namespace { TEST(MapTest, StructKey) { - std::unordered_map<RectPtr, int32_t> map; + base::flat_map<RectPtr, int32_t> map; map.insert(std::make_pair(Rect::New(1, 2, 3, 4), 123)); RectPtr key = Rect::New(1, 2, 3, 4); @@ -29,7 +29,7 @@ TEST(MapTest, StructKey) { } TEST(MapTest, TypemappedStructKey) { - std::unordered_map<ContainsHashablePtr, int32_t> map; + base::flat_map<ContainsHashablePtr, int32_t> map; map.insert( std::make_pair(ContainsHashable::New(RectChromium(1, 2, 3, 4)), 123)); diff --git a/mojo/public/cpp/bindings/tests/message_queue.h b/mojo/public/cpp/bindings/tests/message_queue.h index 8f13f7ab6d..836c60d3a4 100644 --- a/mojo/public/cpp/bindings/tests/message_queue.h +++ b/mojo/public/cpp/bindings/tests/message_queue.h @@ -5,8 +5,7 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_TESTS_MESSAGE_QUEUE_H_ #define MOJO_PUBLIC_CPP_BINDINGS_TESTS_MESSAGE_QUEUE_H_ -#include <queue> - +#include "base/containers/queue.h" #include "base/macros.h" #include "mojo/public/cpp/bindings/message.h" @@ -33,7 +32,7 @@ class MessageQueue { private: void Pop(); - std::queue<Message> queue_; + base::queue<Message> queue_; DISALLOW_COPY_AND_ASSIGN(MessageQueue); }; diff --git a/mojo/public/cpp/bindings/tests/multiplex_router_unittest.cc b/mojo/public/cpp/bindings/tests/multiplex_router_unittest.cc index 89509283c4..ff1c02db35 100644 --- a/mojo/public/cpp/bindings/tests/multiplex_router_unittest.cc +++ b/mojo/public/cpp/bindings/tests/multiplex_router_unittest.cc @@ -43,7 +43,12 @@ class MultiplexRouterTest : public testing::Test { endpoint1_ = router1_->CreateLocalEndpointHandle(id); } - void TearDown() override {} + void TearDown() override { + endpoint1_.reset(); + endpoint0_.reset(); + router1_ = nullptr; + router0_ = nullptr; + } void PumpMessages() { base::RunLoop().RunUntilIdle(); } @@ -59,11 +64,11 @@ class MultiplexRouterTest : public testing::Test { TEST_F(MultiplexRouterTest, BasicRequestResponse) { InterfaceEndpointClient client0(std::move(endpoint0_), nullptr, - base::MakeUnique<PassThroughFilter>(), false, + std::make_unique<PassThroughFilter>(), false, base::ThreadTaskRunnerHandle::Get(), 0u); ResponseGenerator generator; InterfaceEndpointClient client1(std::move(endpoint1_), &generator, - base::MakeUnique<PassThroughFilter>(), false, + std::make_unique<PassThroughFilter>(), false, base::ThreadTaskRunnerHandle::Get(), 0u); Message request; @@ -72,7 +77,7 @@ TEST_F(MultiplexRouterTest, BasicRequestResponse) { MessageQueue message_queue; base::RunLoop run_loop; client0.AcceptWithResponder( - &request, base::MakeUnique<MessageAccumulator>(&message_queue, + &request, std::make_unique<MessageAccumulator>(&message_queue, run_loop.QuitClosure())); run_loop.Run(); @@ -91,7 +96,7 @@ TEST_F(MultiplexRouterTest, BasicRequestResponse) { base::RunLoop run_loop2; client0.AcceptWithResponder( - &request2, base::MakeUnique<MessageAccumulator>(&message_queue, + &request2, std::make_unique<MessageAccumulator>(&message_queue, run_loop2.QuitClosure())); run_loop2.Run(); @@ -106,11 +111,11 @@ TEST_F(MultiplexRouterTest, BasicRequestResponse) { TEST_F(MultiplexRouterTest, BasicRequestResponse_Synchronous) { InterfaceEndpointClient client0(std::move(endpoint0_), nullptr, - base::MakeUnique<PassThroughFilter>(), false, + std::make_unique<PassThroughFilter>(), false, base::ThreadTaskRunnerHandle::Get(), 0u); ResponseGenerator generator; InterfaceEndpointClient client1(std::move(endpoint1_), &generator, - base::MakeUnique<PassThroughFilter>(), false, + std::make_unique<PassThroughFilter>(), false, base::ThreadTaskRunnerHandle::Get(), 0u); Message request; @@ -118,7 +123,7 @@ TEST_F(MultiplexRouterTest, BasicRequestResponse_Synchronous) { MessageQueue message_queue; client0.AcceptWithResponder( - &request, base::MakeUnique<MessageAccumulator>(&message_queue)); + &request, std::make_unique<MessageAccumulator>(&message_queue)); router1_->WaitForIncomingMessage(MOJO_DEADLINE_INDEFINITE); router0_->WaitForIncomingMessage(MOJO_DEADLINE_INDEFINITE); @@ -136,7 +141,7 @@ TEST_F(MultiplexRouterTest, BasicRequestResponse_Synchronous) { AllocRequestMessage(1, "hello again", &request2); client0.AcceptWithResponder( - &request2, base::MakeUnique<MessageAccumulator>(&message_queue)); + &request2, std::make_unique<MessageAccumulator>(&message_queue)); router1_->WaitForIncomingMessage(MOJO_DEADLINE_INDEFINITE); router0_->WaitForIncomingMessage(MOJO_DEADLINE_INDEFINITE); @@ -168,7 +173,7 @@ TEST_F(MultiplexRouterTest, LazyResponses) { MessageQueue message_queue; base::RunLoop run_loop2; client0.AcceptWithResponder( - &request, base::MakeUnique<MessageAccumulator>(&message_queue, + &request, std::make_unique<MessageAccumulator>(&message_queue, run_loop2.QuitClosure())); run_loop.Run(); @@ -195,7 +200,7 @@ TEST_F(MultiplexRouterTest, LazyResponses) { base::RunLoop run_loop4; client0.AcceptWithResponder( - &request2, base::MakeUnique<MessageAccumulator>(&message_queue, + &request2, std::make_unique<MessageAccumulator>(&message_queue, run_loop4.QuitClosure())); run_loop3.Run(); @@ -248,7 +253,7 @@ TEST_F(MultiplexRouterTest, MissingResponses) { MessageQueue message_queue; client0.AcceptWithResponder( - &request, base::MakeUnique<MessageAccumulator>(&message_queue)); + &request, std::make_unique<MessageAccumulator>(&message_queue)); run_loop3.Run(); // The request has been received but no response has been sent. @@ -284,10 +289,10 @@ TEST_F(MultiplexRouterTest, LateResponse) { LazyResponseGenerator generator(run_loop.QuitClosure()); { InterfaceEndpointClient client0( - std::move(endpoint0_), nullptr, base::MakeUnique<PassThroughFilter>(), + std::move(endpoint0_), nullptr, std::make_unique<PassThroughFilter>(), false, base::ThreadTaskRunnerHandle::Get(), 0u); InterfaceEndpointClient client1(std::move(endpoint1_), &generator, - base::MakeUnique<PassThroughFilter>(), + std::make_unique<PassThroughFilter>(), false, base::ThreadTaskRunnerHandle::Get(), 0u); @@ -296,7 +301,7 @@ TEST_F(MultiplexRouterTest, LateResponse) { MessageQueue message_queue; client0.AcceptWithResponder( - &request, base::MakeUnique<MessageAccumulator>(&message_queue)); + &request, std::make_unique<MessageAccumulator>(&message_queue)); run_loop.Run(); diff --git a/mojo/public/cpp/bindings/tests/native_struct_unittest.cc b/mojo/public/cpp/bindings/tests/native_struct_unittest.cc new file mode 100644 index 0000000000..6e3cbcbe36 --- /dev/null +++ b/mojo/public/cpp/bindings/tests/native_struct_unittest.cc @@ -0,0 +1,98 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <stdint.h> + +#include <vector> + +#include "base/bind.h" +#include "base/macros.h" +#include "base/run_loop.h" +#include "base/test/scoped_task_environment.h" +#include "ipc/ipc_param_traits.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" +#include "mojo/public/cpp/system/message_pipe.h" +#include "mojo/public/cpp/system/wait.h" +#include "mojo/public/interfaces/bindings/tests/test_native_types.mojom.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace mojo { + +class NativeStructTest : public BindingsTestBase, + public test::NativeTypeTester { + public: + NativeStructTest() : binding_(this, mojo::MakeRequest(&proxy_)) {} + ~NativeStructTest() override = default; + + test::NativeTypeTester* proxy() { return proxy_.get(); } + + private: + // test::NativeTypeTester: + void PassNativeStruct(const test::TestNativeStruct& s, + const PassNativeStructCallback& callback) override { + callback.Run(s); + } + + void PassNativeStructWithAttachments( + test::TestNativeStructWithAttachments s, + const PassNativeStructWithAttachmentsCallback& callback) override { + callback.Run(std::move(s)); + } + + test::NativeTypeTesterPtr proxy_; + Binding<test::NativeTypeTester> binding_; + + DISALLOW_COPY_AND_ASSIGN(NativeStructTest); +}; + +TEST_P(NativeStructTest, NativeStruct) { + test::TestNativeStruct s("hello world", 5, 42); + base::RunLoop loop; + proxy()->PassNativeStruct( + s, base::Bind( + [](test::TestNativeStruct* expected_struct, base::RunLoop* loop, + const test::TestNativeStruct& passed) { + EXPECT_EQ(expected_struct->message(), passed.message()); + EXPECT_EQ(expected_struct->x(), passed.x()); + EXPECT_EQ(expected_struct->y(), passed.y()); + loop->Quit(); + }, + &s, &loop)); + loop.Run(); +} + +TEST_P(NativeStructTest, NativeStructWithAttachments) { + mojo::MessagePipe pipe; + const std::string kTestMessage = "hey hi"; + test::TestNativeStructWithAttachments s(kTestMessage, + std::move(pipe.handle0)); + base::RunLoop loop; + proxy()->PassNativeStructWithAttachments( + std::move(s), + base::Bind( + [](const std::string& expected_message, + mojo::ScopedMessagePipeHandle peer_pipe, base::RunLoop* loop, + test::TestNativeStructWithAttachments passed) { + // To ensure that the received pipe handle is functioning, we write + // to its peer and wait for the message to be received. + WriteMessageRaw(peer_pipe.get(), "ping", 4, nullptr, 0, + MOJO_WRITE_MESSAGE_FLAG_NONE); + auto pipe = passed.PassPipe(); + EXPECT_EQ(MOJO_RESULT_OK, + Wait(pipe.get(), MOJO_HANDLE_SIGNAL_READABLE)); + std::vector<uint8_t> bytes; + EXPECT_EQ(MOJO_RESULT_OK, + ReadMessageRaw(pipe.get(), &bytes, nullptr, + MOJO_READ_MESSAGE_FLAG_NONE)); + EXPECT_EQ("ping", std::string(bytes.begin(), bytes.end())); + loop->Quit(); + }, + kTestMessage, base::Passed(&pipe.handle1), &loop)); + loop.Run(); +} + +INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(NativeStructTest); + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/pickle_unittest.cc b/mojo/public/cpp/bindings/tests/pickle_unittest.cc index a5947ce9ed..fd97c574e3 100644 --- a/mojo/public/cpp/bindings/tests/pickle_unittest.cc +++ b/mojo/public/cpp/bindings/tests/pickle_unittest.cc @@ -157,23 +157,40 @@ class PickleTest : public testing::Test { template <typename ProxyType = PicklePasser> InterfacePtr<ProxyType> ConnectToChromiumService() { InterfacePtr<ProxyType> proxy; - InterfaceRequest<ProxyType> request(&proxy); chromium_bindings_.AddBinding( &chromium_service_, - ConvertInterfaceRequest<PicklePasser>(std::move(request))); + ConvertInterfaceRequest<PicklePasser>(mojo::MakeRequest(&proxy))); return proxy; } template <typename ProxyType = blink::PicklePasser> InterfacePtr<ProxyType> ConnectToBlinkService() { InterfacePtr<ProxyType> proxy; - InterfaceRequest<ProxyType> request(&proxy); - blink_bindings_.AddBinding( - &blink_service_, - ConvertInterfaceRequest<blink::PicklePasser>(std::move(request))); + blink_bindings_.AddBinding(&blink_service_, + ConvertInterfaceRequest<blink::PicklePasser>( + mojo::MakeRequest(&proxy))); return proxy; } + protected: + static void ForceMessageSerialization(bool forced) { + // Force messages to be serialized in this test since it intentionally + // exercises StructTraits logic. + Connector::OverrideDefaultSerializationBehaviorForTesting( + forced ? Connector::OutgoingSerializationMode::kEager + : Connector::OutgoingSerializationMode::kLazy, + Connector::IncomingSerializationMode::kDispatchAsIs); + } + + class ScopedForceMessageSerialization { + public: + ScopedForceMessageSerialization() { ForceMessageSerialization(true); } + ~ScopedForceMessageSerialization() { ForceMessageSerialization(false); } + + private: + DISALLOW_COPY_AND_ASSIGN(ScopedForceMessageSerialization); + }; + private: base::MessageLoop loop_; ChromiumPicklePasserImpl chromium_service_; @@ -297,6 +314,7 @@ TEST_F(PickleTest, BlinkProxyToChromiumService) { } TEST_F(PickleTest, PickleArray) { + ScopedForceMessageSerialization force_serialization; auto proxy = ConnectToChromiumService(); auto pickles = std::vector<PickledStructChromium>(2); pickles[0].set_foo(1); @@ -328,6 +346,7 @@ TEST_F(PickleTest, PickleArray) { } TEST_F(PickleTest, PickleArrayArray) { + ScopedForceMessageSerialization force_serialization; auto proxy = ConnectToChromiumService(); auto pickle_arrays = std::vector<std::vector<PickledStructChromium>>(2); for (size_t i = 0; i < 2; ++i) @@ -374,6 +393,7 @@ TEST_F(PickleTest, PickleArrayArray) { } TEST_F(PickleTest, PickleContainer) { + ScopedForceMessageSerialization force_serialization; auto proxy = ConnectToChromiumService(); PickleContainerPtr pickle_container = PickleContainer::New(); pickle_container->f_struct.set_foo(42); diff --git a/mojo/public/cpp/bindings/tests/pickled_types_blink.cc b/mojo/public/cpp/bindings/tests/pickled_types_blink.cc index 7e556507bb..9e56bea90c 100644 --- a/mojo/public/cpp/bindings/tests/pickled_types_blink.cc +++ b/mojo/public/cpp/bindings/tests/pickled_types_blink.cc @@ -25,13 +25,6 @@ PickledStructBlink::~PickledStructBlink() {} namespace IPC { -void ParamTraits<mojo::test::PickledStructBlink>::GetSize( - base::PickleSizer* sizer, - const param_type& p) { - sizer->AddInt(); - sizer->AddInt(); -} - void ParamTraits<mojo::test::PickledStructBlink>::Write(base::Pickle* m, const param_type& p) { m->WriteInt(p.foo()); @@ -51,9 +44,6 @@ bool ParamTraits<mojo::test::PickledStructBlink>::Read( return true; } -#include "ipc/param_traits_size_macros.h" -IPC_ENUM_TRAITS_MAX_VALUE(mojo::test::PickledEnumBlink, - mojo::test::PickledEnumBlink::VALUE_1) #include "ipc/param_traits_write_macros.h" IPC_ENUM_TRAITS_MAX_VALUE(mojo::test::PickledEnumBlink, mojo::test::PickledEnumBlink::VALUE_1) diff --git a/mojo/public/cpp/bindings/tests/pickled_types_blink.h b/mojo/public/cpp/bindings/tests/pickled_types_blink.h index 37e9e70578..fc6bd4e677 100644 --- a/mojo/public/cpp/bindings/tests/pickled_types_blink.h +++ b/mojo/public/cpp/bindings/tests/pickled_types_blink.h @@ -17,7 +17,6 @@ namespace base { class Pickle; class PickleIterator; -class PickleSizer; } namespace mojo { @@ -72,7 +71,6 @@ template <> struct ParamTraits<mojo::test::PickledStructBlink> { using param_type = mojo::test::PickledStructBlink; - static void GetSize(base::PickleSizer* sizer, const param_type& p); static void Write(base::Pickle* m, const param_type& p); static bool Read(const base::Pickle* m, base::PickleIterator* iter, diff --git a/mojo/public/cpp/bindings/tests/pickled_types_chromium.cc b/mojo/public/cpp/bindings/tests/pickled_types_chromium.cc index 9957c9a4d0..aeb0be4555 100644 --- a/mojo/public/cpp/bindings/tests/pickled_types_chromium.cc +++ b/mojo/public/cpp/bindings/tests/pickled_types_chromium.cc @@ -26,13 +26,6 @@ bool operator==(const PickledStructChromium& a, namespace IPC { -void ParamTraits<mojo::test::PickledStructChromium>::GetSize( - base::PickleSizer* sizer, - const param_type& p) { - sizer->AddInt(); - sizer->AddInt(); -} - void ParamTraits<mojo::test::PickledStructChromium>::Write( base::Pickle* m, const param_type& p) { @@ -53,9 +46,6 @@ bool ParamTraits<mojo::test::PickledStructChromium>::Read( return true; } -#include "ipc/param_traits_size_macros.h" -IPC_ENUM_TRAITS_MAX_VALUE(mojo::test::PickledEnumChromium, - mojo::test::PickledEnumChromium::VALUE_2) #include "ipc/param_traits_write_macros.h" IPC_ENUM_TRAITS_MAX_VALUE(mojo::test::PickledEnumChromium, mojo::test::PickledEnumChromium::VALUE_2) diff --git a/mojo/public/cpp/bindings/tests/pickled_types_chromium.h b/mojo/public/cpp/bindings/tests/pickled_types_chromium.h index d9287b62e7..51649c278e 100644 --- a/mojo/public/cpp/bindings/tests/pickled_types_chromium.h +++ b/mojo/public/cpp/bindings/tests/pickled_types_chromium.h @@ -16,7 +16,6 @@ namespace base { class Pickle; class PickleIterator; -class PickleSizer; } namespace mojo { @@ -65,7 +64,6 @@ template <> struct ParamTraits<mojo::test::PickledStructChromium> { using param_type = mojo::test::PickledStructChromium; - static void GetSize(base::PickleSizer* sizer, const param_type& p); static void Write(base::Pickle* m, const param_type& p); static bool Read(const base::Pickle* m, base::PickleIterator* iter, diff --git a/mojo/public/cpp/bindings/tests/report_bad_message_unittest.cc b/mojo/public/cpp/bindings/tests/report_bad_message_unittest.cc index 1bf3f7a4b7..cdd9799d0e 100644 --- a/mojo/public/cpp/bindings/tests/report_bad_message_unittest.cc +++ b/mojo/public/cpp/bindings/tests/report_bad_message_unittest.cc @@ -5,11 +5,11 @@ #include "base/bind.h" #include "base/callback.h" #include "base/macros.h" -#include "base/message_loop/message_loop.h" #include "base/run_loop.h" -#include "mojo/edk/embedder/embedder.h" +#include "mojo/core/embedder/embedder.h" #include "mojo/public/cpp/bindings/binding.h" #include "mojo/public/cpp/bindings/message.h" +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" #include "mojo/public/interfaces/bindings/tests/test_bad_messages.mojom.h" #include "testing/gtest/include/gtest/gtest.h" @@ -26,7 +26,7 @@ class TestBadMessagesImpl : public TestBadMessages { binding_.Bind(std::move(request)); } - const ReportBadMessageCallback& bad_message_callback() { + ReportBadMessageCallback& bad_message_callback() { return bad_message_callback_; } @@ -57,21 +57,20 @@ class TestBadMessagesImpl : public TestBadMessages { DISALLOW_COPY_AND_ASSIGN(TestBadMessagesImpl); }; -class ReportBadMessageTest : public testing::Test { +class ReportBadMessageTest : public BindingsTestBase { public: ReportBadMessageTest() {} void SetUp() override { - mojo::edk::SetDefaultProcessErrorCallback( - base::Bind(&ReportBadMessageTest::OnProcessError, - base::Unretained(this))); + mojo::core::SetDefaultProcessErrorCallback(base::Bind( + &ReportBadMessageTest::OnProcessError, base::Unretained(this))); impl_.BindImpl(MakeRequest(&proxy_)); } void TearDown() override { - mojo::edk::SetDefaultProcessErrorCallback( - mojo::edk::ProcessErrorCallback()); + mojo::core::SetDefaultProcessErrorCallback( + mojo::core::ProcessErrorCallback()); } TestBadMessages* proxy() { return proxy_.get(); } @@ -91,10 +90,9 @@ class ReportBadMessageTest : public testing::Test { TestBadMessagesPtr proxy_; TestBadMessagesImpl impl_; base::Closure error_handler_; - base::MessageLoop message_loop; }; -TEST_F(ReportBadMessageTest, Request) { +TEST_P(ReportBadMessageTest, Request) { // Verify that basic immediate error reporting works. bool error = false; SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error)); @@ -102,7 +100,7 @@ TEST_F(ReportBadMessageTest, Request) { EXPECT_TRUE(error); } -TEST_F(ReportBadMessageTest, RequestAsync) { +TEST_P(ReportBadMessageTest, RequestAsync) { bool error = false; SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error)); @@ -115,11 +113,11 @@ TEST_F(ReportBadMessageTest, RequestAsync) { // Now we can run the callback and it should trigger a bad message report. DCHECK(!impl()->bad_message_callback().is_null()); - impl()->bad_message_callback().Run("bad!"); + std::move(impl()->bad_message_callback()).Run("bad!"); EXPECT_TRUE(error); } -TEST_F(ReportBadMessageTest, Response) { +TEST_P(ReportBadMessageTest, Response) { bool error = false; SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error)); @@ -137,7 +135,7 @@ TEST_F(ReportBadMessageTest, Response) { EXPECT_TRUE(error); } -TEST_F(ReportBadMessageTest, ResponseAsync) { +TEST_P(ReportBadMessageTest, ResponseAsync) { bool error = false; SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error)); @@ -157,11 +155,12 @@ TEST_F(ReportBadMessageTest, ResponseAsync) { // Invoking this callback should report a bad message and trigger the error // handler immediately. - bad_message_callback.Run("this message is bad and should feel bad"); + std::move(bad_message_callback) + .Run("this message is bad and should feel bad"); EXPECT_TRUE(error); } -TEST_F(ReportBadMessageTest, ResponseSync) { +TEST_P(ReportBadMessageTest, ResponseSync) { bool error = false; SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error)); @@ -173,7 +172,7 @@ TEST_F(ReportBadMessageTest, ResponseSync) { EXPECT_TRUE(error); } -TEST_F(ReportBadMessageTest, ResponseSyncDeferred) { +TEST_P(ReportBadMessageTest, ResponseSyncDeferred) { bool error = false; SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error)); @@ -185,10 +184,12 @@ TEST_F(ReportBadMessageTest, ResponseSyncDeferred) { } EXPECT_FALSE(error); - bad_message_callback.Run("nope nope nope"); + std::move(bad_message_callback).Run("nope nope nope"); EXPECT_TRUE(error); } +INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(ReportBadMessageTest); + } // namespace } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/request_response_unittest.cc b/mojo/public/cpp/bindings/tests/request_response_unittest.cc index 43b8f0dc90..7293a5c4da 100644 --- a/mojo/public/cpp/bindings/tests/request_response_unittest.cc +++ b/mojo/public/cpp/bindings/tests/request_response_unittest.cc @@ -5,9 +5,9 @@ #include <stdint.h> #include <utility> -#include "base/message_loop/message_loop.h" #include "base/run_loop.h" #include "mojo/public/cpp/bindings/binding.h" +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" #include "mojo/public/cpp/test_support/test_utils.h" #include "mojo/public/interfaces/bindings/tests/sample_import.mojom.h" #include "mojo/public/interfaces/bindings/tests/sample_interfaces.mojom.h" @@ -82,18 +82,15 @@ void RecordEnum(sample::Enum* storage, closure.Run(); } -class RequestResponseTest : public testing::Test { +class RequestResponseTest : public BindingsTestBase { public: RequestResponseTest() {} ~RequestResponseTest() override { base::RunLoop().RunUntilIdle(); } void PumpMessages() { base::RunLoop().RunUntilIdle(); } - - private: - base::MessageLoop loop_; }; -TEST_F(RequestResponseTest, EchoString) { +TEST_P(RequestResponseTest, EchoString) { sample::ProviderPtr provider; ProviderImpl provider_impl(MakeRequest(&provider)); @@ -107,7 +104,7 @@ TEST_F(RequestResponseTest, EchoString) { EXPECT_EQ(std::string("hello"), buf); } -TEST_F(RequestResponseTest, EchoStrings) { +TEST_P(RequestResponseTest, EchoStrings) { sample::ProviderPtr provider; ProviderImpl provider_impl(MakeRequest(&provider)); @@ -121,7 +118,7 @@ TEST_F(RequestResponseTest, EchoStrings) { EXPECT_EQ(std::string("hello world"), buf); } -TEST_F(RequestResponseTest, EchoMessagePipeHandle) { +TEST_P(RequestResponseTest, EchoMessagePipeHandle) { sample::ProviderPtr provider; ProviderImpl provider_impl(MakeRequest(&provider)); @@ -139,7 +136,7 @@ TEST_F(RequestResponseTest, EchoMessagePipeHandle) { EXPECT_EQ(std::string("hello"), value); } -TEST_F(RequestResponseTest, EchoEnum) { +TEST_P(RequestResponseTest, EchoEnum) { sample::ProviderPtr provider; ProviderImpl provider_impl(MakeRequest(&provider)); @@ -152,6 +149,8 @@ TEST_F(RequestResponseTest, EchoEnum) { EXPECT_EQ(sample::Enum::VALUE, value); } +INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(RequestResponseTest); + } // namespace } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/router_test_util.cc b/mojo/public/cpp/bindings/tests/router_test_util.cc index 9bab1cb360..b36de55c41 100644 --- a/mojo/public/cpp/bindings/tests/router_test_util.cc +++ b/mojo/public/cpp/bindings/tests/router_test_util.cc @@ -8,7 +8,7 @@ #include <stdint.h> #include <string.h> -#include "mojo/public/cpp/bindings/lib/message_builder.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/cpp/bindings/tests/message_queue.h" #include "testing/gtest/include/gtest/gtest.h" @@ -17,10 +17,10 @@ namespace test { void AllocRequestMessage(uint32_t name, const char* text, Message* message) { size_t payload_size = strlen(text) + 1; // Plus null terminator. - internal::MessageBuilder builder(name, Message::kFlagExpectsResponse, - payload_size, 0); - memcpy(builder.buffer()->Allocate(payload_size), text, payload_size); - *message = std::move(*builder.message()); + *message = + Message(name, Message::kFlagExpectsResponse, payload_size, 0, nullptr); + memcpy(message->payload_buffer()->AllocateAndGet(payload_size), text, + payload_size); } void AllocResponseMessage(uint32_t name, @@ -28,11 +28,10 @@ void AllocResponseMessage(uint32_t name, uint64_t request_id, Message* message) { size_t payload_size = strlen(text) + 1; // Plus null terminator. - internal::MessageBuilder builder(name, Message::kFlagIsResponse, payload_size, - 0); - builder.message()->set_request_id(request_id); - memcpy(builder.buffer()->Allocate(payload_size), text, payload_size); - *message = std::move(*builder.message()); + *message = Message(name, Message::kFlagIsResponse, payload_size, 0, nullptr); + message->set_request_id(request_id); + memcpy(message->payload_buffer()->AllocateAndGet(payload_size), text, + payload_size); } MessageAccumulator::MessageAccumulator(MessageQueue* queue, @@ -64,7 +63,7 @@ bool ResponseGenerator::AcceptWithResponder( bool result = SendResponse(message->name(), message->request_id(), reinterpret_cast<const char*>(message->payload()), responder.get()); - EXPECT_TRUE(responder->IsValid()); + EXPECT_TRUE(responder->IsConnected()); return result; } diff --git a/mojo/public/cpp/bindings/tests/router_test_util.h b/mojo/public/cpp/bindings/tests/router_test_util.h index dd6aff63da..6d18a032b1 100644 --- a/mojo/public/cpp/bindings/tests/router_test_util.h +++ b/mojo/public/cpp/bindings/tests/router_test_util.h @@ -64,7 +64,7 @@ class LazyResponseGenerator : public ResponseGenerator { bool has_responder() const { return !!responder_; } - bool responder_is_valid() const { return responder_->IsValid(); } + bool responder_is_valid() const { return responder_->IsConnected(); } void set_closure(const base::Closure& closure) { closure_ = closure; } diff --git a/mojo/public/cpp/bindings/tests/sample_service_unittest.cc b/mojo/public/cpp/bindings/tests/sample_service_unittest.cc index 1f95a27a5e..9762a8bf4b 100644 --- a/mojo/public/cpp/bindings/tests/sample_service_unittest.cc +++ b/mojo/public/cpp/bindings/tests/sample_service_unittest.cc @@ -9,6 +9,7 @@ #include <string> #include <utility> +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" #include "mojo/public/interfaces/bindings/tests/sample_service.mojom.h" #include "testing/gtest/include/gtest/gtest.h" @@ -56,7 +57,7 @@ FooPtr MakeFoo() { for (size_t i = 0; i < input_streams.size(); ++i) { MojoCreateDataPipeOptions options; options.struct_size = sizeof(MojoCreateDataPipeOptions); - options.flags = MOJO_CREATE_DATA_PIPE_OPTIONS_FLAG_NONE; + options.flags = MOJO_CREATE_DATA_PIPE_FLAG_NONE; options.element_num_bytes = 1; options.capacity_num_bytes = 1024; mojo::ScopedDataPipeProducerHandle producer; @@ -225,8 +226,8 @@ void Print(int depth, const char* name, const FooPtr& foo) { } } -void DumpHex(const uint8_t* bytes, uint32_t num_bytes) { - for (uint32_t i = 0; i < num_bytes; ++i) { +void DumpHex(const uint8_t* bytes, size_t num_bytes) { + for (size_t i = 0; i < num_bytes; ++i) { std::cout << std::setw(2) << std::setfill('0') << std::hex << uint32_t(bytes[i]); @@ -278,6 +279,8 @@ class ServiceProxyImpl : public ServiceProxy { class SimpleMessageReceiver : public mojo::MessageReceiverWithResponder { public: + bool PrefersSerializedMessages() override { return true; } + bool Accept(mojo::Message* message) override { // Imagine some IPC happened here. @@ -302,9 +305,9 @@ class SimpleMessageReceiver : public mojo::MessageReceiverWithResponder { } }; -using BindingsSampleTest = testing::Test; +using BindingsSampleTest = mojo::BindingsTestBase; -TEST_F(BindingsSampleTest, Basic) { +TEST_P(BindingsSampleTest, Basic) { SimpleMessageReceiver receiver; // User has a proxy to a Service somehow. @@ -326,7 +329,7 @@ TEST_F(BindingsSampleTest, Basic) { delete service; } -TEST_F(BindingsSampleTest, DefaultValues) { +TEST_P(BindingsSampleTest, DefaultValues) { DefaultsTestPtr defaults(DefaultsTest::New()); EXPECT_EQ(-12, defaults->a0); EXPECT_EQ(kTwelve, defaults->a1); @@ -358,5 +361,7 @@ TEST_F(BindingsSampleTest, DefaultValues) { EXPECT_EQ(-0x123456789, defaults->a25); } +INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(BindingsSampleTest); + } // namespace } // namespace sample diff --git a/mojo/public/cpp/bindings/tests/serialization_warning_unittest.cc b/mojo/public/cpp/bindings/tests/serialization_warning_unittest.cc index 275f10f9e7..37aaff32d9 100644 --- a/mojo/public/cpp/bindings/tests/serialization_warning_unittest.cc +++ b/mojo/public/cpp/bindings/tests/serialization_warning_unittest.cc @@ -51,11 +51,11 @@ class SerializationWarningTest : public testing::Test { warning_observer_.set_last_warning(mojo::internal::VALIDATION_ERROR_NONE); mojo::internal::SerializationContext context; - mojo::internal::FixedBufferForTesting buf( - mojo::internal::PrepareToSerialize<MojomType>(obj, &context)); - typename mojo::internal::MojomTypeTraits<MojomType>::Data* data; - mojo::internal::Serialize<MojomType>(obj, &buf, &data, &context); - + mojo::Message message(0, 0, 0, 0, nullptr); + typename mojo::internal::MojomTypeTraits<MojomType>::Data::BufferWriter + writer; + mojo::internal::Serialize<MojomType>(obj, message.payload_buffer(), &writer, + &context); EXPECT_EQ(expected_warning, warning_observer_.last_warning()); } @@ -66,12 +66,11 @@ class SerializationWarningTest : public testing::Test { warning_observer_.set_last_warning(mojo::internal::VALIDATION_ERROR_NONE); mojo::internal::SerializationContext context; - mojo::internal::FixedBufferForTesting buf( - mojo::internal::PrepareToSerialize<MojomType>(obj, &context)); - typename mojo::internal::MojomTypeTraits<MojomType>::Data* data; - mojo::internal::Serialize<MojomType>(obj, &buf, &data, validate_params, - &context); - + mojo::Message message(0, 0, 0, 0, nullptr); + typename mojo::internal::MojomTypeTraits<MojomType>::Data::BufferWriter + writer; + mojo::internal::Serialize<MojomType>(obj, message.payload_buffer(), &writer, + validate_params, &context); EXPECT_EQ(expected_warning, warning_observer_.last_warning()); } @@ -83,10 +82,11 @@ class SerializationWarningTest : public testing::Test { warning_observer_.set_last_warning(mojo::internal::VALIDATION_ERROR_NONE); mojo::internal::SerializationContext context; - mojo::internal::FixedBufferForTesting buf( - mojo::internal::PrepareToSerialize<MojomType>(obj, false, &context)); - typename mojo::internal::MojomTypeTraits<MojomType>::Data* data; - mojo::internal::Serialize<MojomType>(obj, &buf, &data, false, &context); + mojo::Message message(0, 0, 0, 0, nullptr); + typename mojo::internal::MojomTypeTraits<MojomType>::Data::BufferWriter + writer; + mojo::internal::Serialize<MojomType>(obj, message.payload_buffer(), &writer, + false, &context); EXPECT_EQ(expected_warning, warning_observer_.last_warning()); } diff --git a/mojo/public/cpp/bindings/tests/struct_traits_unittest.cc b/mojo/public/cpp/bindings/tests/struct_traits_unittest.cc index 77b448a215..f74aa5acff 100644 --- a/mojo/public/cpp/bindings/tests/struct_traits_unittest.cc +++ b/mojo/public/cpp/bindings/tests/struct_traits_unittest.cc @@ -135,7 +135,9 @@ class StructTraitsTest : public testing::Test, } TraitsTestServicePtr GetTraitsTestProxy() { - return traits_test_bindings_.CreateInterfacePtrAndBind(this); + TraitsTestServicePtr proxy; + traits_test_bindings_.AddBinding(this, mojo::MakeRequest(&proxy)); + return proxy; } private: @@ -394,13 +396,12 @@ TEST_F(StructTraitsTest, EchoMoveOnlyStructWithTraits) { EXPECT_EQ(MOJO_RESULT_OK, Wait(received.get(), MOJO_HANDLE_SIGNAL_READABLE)); - char buffer[10] = {0}; - uint32_t buffer_size = static_cast<uint32_t>(sizeof(buffer)); - EXPECT_EQ(MOJO_RESULT_OK, - ReadMessageRaw(received.get(), buffer, &buffer_size, nullptr, - nullptr, MOJO_READ_MESSAGE_FLAG_NONE)); - EXPECT_EQ(kHelloSize, buffer_size); - EXPECT_STREQ(kHello, buffer); + std::vector<uint8_t> bytes; + std::vector<ScopedHandle> handles; + EXPECT_EQ(MOJO_RESULT_OK, ReadMessageRaw(received.get(), &bytes, &handles, + MOJO_READ_MESSAGE_FLAG_NONE)); + EXPECT_EQ(kHelloSize, bytes.size()); + EXPECT_STREQ(kHello, reinterpret_cast<char*>(bytes.data())); } void CaptureNullableMoveOnlyStructWithTraitsImpl( @@ -490,8 +491,8 @@ TEST_F(StructTraitsTest, TypemapUniquePtr) { { base::RunLoop loop; proxy->EchoStructWithTraitsForUniquePtr( - base::MakeUnique<int>(12345), - base::Bind(&ExpectUniquePtr, base::Passed(base::MakeUnique<int>(12345)), + std::make_unique<int>(12345), + base::Bind(&ExpectUniquePtr, base::Passed(std::make_unique<int>(12345)), loop.QuitClosure())); loop.Run(); } @@ -549,5 +550,10 @@ TEST_F(StructTraitsTest, EchoUnionWithTraits) { } } +TEST_F(StructTraitsTest, DefaultValueOfEnumWithTraits) { + auto container = EnumWithTraitsContainer::New(); + EXPECT_EQ(EnumWithTraitsImpl::CUSTOM_VALUE_1, container->f_field); +} + } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/struct_unittest.cc b/mojo/public/cpp/bindings/tests/struct_unittest.cc index a687052706..c6bd169f06 100644 --- a/mojo/public/cpp/bindings/tests/struct_unittest.cc +++ b/mojo/public/cpp/bindings/tests/struct_unittest.cc @@ -28,6 +28,33 @@ void CheckRect(const Rect& rect, int32_t factor = 1) { EXPECT_EQ(20 * factor, rect.height); } +template <typename StructType> +struct SerializeStructHelperTraits { + using DataView = typename StructType::DataView; +}; + +template <> +struct SerializeStructHelperTraits<native::NativeStruct> { + using DataView = native::NativeStructDataView; +}; + +template <typename InputType, typename DataType> +size_t SerializeStruct(InputType& input, + mojo::Message* message, + mojo::internal::SerializationContext* context, + DataType** out_data) { + using StructType = typename InputType::Struct; + using DataViewType = + typename SerializeStructHelperTraits<StructType>::DataView; + *message = mojo::Message(0, 0, 0, 0, nullptr); + const size_t payload_start = message->payload_buffer()->cursor(); + typename DataType::BufferWriter writer; + mojo::internal::Serialize<DataViewType>(input, message->payload_buffer(), + &writer, context); + *out_data = writer.is_null() ? nullptr : writer.data(); + return message->payload_buffer()->cursor() - payload_start; +} + MultiVersionStructPtr MakeMultiVersionStruct() { MessagePipe pipe; return MultiVersionStruct::New(123, MakeRect(5), std::string("hello"), @@ -45,19 +72,18 @@ U SerializeAndDeserialize(T input) { using OutputDataType = typename mojo::internal::MojomTypeTraits<OutputMojomType>::Data*; + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = - mojo::internal::PrepareToSerialize<InputMojomType>(input, &context); - mojo::internal::FixedBufferForTesting buf(size + 32); InputDataType data; - mojo::internal::Serialize<InputMojomType>(input, &buf, &data, &context); + SerializeStruct(input, &message, &context, &data); // Set the subsequent area to a special value, so that we can find out if we // mistakenly access the area. - void* subsequent_area = buf.Allocate(32); + void* subsequent_area = message.payload_buffer()->AllocateAndGet(32); memset(subsequent_area, 0xAA, 32); - OutputDataType output_data = reinterpret_cast<OutputDataType>(data); + OutputDataType output_data = + reinterpret_cast<OutputDataType>(message.mutable_payload()); U output; mojo::internal::Deserialize<OutputMojomType>(output_data, &output, &context); @@ -124,15 +150,13 @@ TEST_F(StructTest, Clone) { TEST_F(StructTest, Serialization_Basic) { RectPtr rect(MakeRect()); - size_t size = mojo::internal::PrepareToSerialize<RectDataView>(rect, nullptr); - EXPECT_EQ(8U + 16U, size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::Rect_Data* data; - mojo::internal::Serialize<RectDataView>(rect, &buf, &data, nullptr); + EXPECT_EQ(8U + 16U, SerializeStruct(rect, &message, &context, &data)); RectPtr rect2; - mojo::internal::Deserialize<RectDataView>(data, &rect2, nullptr); + mojo::internal::Deserialize<RectDataView>(data, &rect2, &context); CheckRect(*rect2); } @@ -155,16 +179,14 @@ TEST_F(StructTest, Construction_StructPointers) { TEST_F(StructTest, Serialization_StructPointers) { RectPairPtr pair(RectPair::New(MakeRect(), MakeRect())); - size_t size = - mojo::internal::PrepareToSerialize<RectPairDataView>(pair, nullptr); - EXPECT_EQ(8U + 16U + 2 * (8U + 16U), size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::RectPair_Data* data; - mojo::internal::Serialize<RectPairDataView>(pair, &buf, &data, nullptr); + EXPECT_EQ(8U + 16U + 2 * (8U + 16U), + SerializeStruct(pair, &message, &context, &data)); RectPairPtr pair2; - mojo::internal::Deserialize<RectPairDataView>(data, &pair2, nullptr); + mojo::internal::Deserialize<RectPairDataView>(data, &pair2, &context); CheckRect(*pair2->first); CheckRect(*pair2->second); @@ -179,8 +201,9 @@ TEST_F(StructTest, Serialization_ArrayPointers) { NamedRegionPtr region( NamedRegion::New(std::string("region"), std::move(rects))); - size_t size = - mojo::internal::PrepareToSerialize<NamedRegionDataView>(region, nullptr); + mojo::Message message; + mojo::internal::SerializationContext context; + internal::NamedRegion_Data* data; EXPECT_EQ(8U + // header 8U + // name pointer 8U + // rects pointer @@ -190,14 +213,10 @@ TEST_F(StructTest, Serialization_ArrayPointers) { 4 * 8U + // rects payload (four pointers) 4 * (8U + // rect header 16U), // rect payload (four ints) - size); - - mojo::internal::FixedBufferForTesting buf(size); - internal::NamedRegion_Data* data; - mojo::internal::Serialize<NamedRegionDataView>(region, &buf, &data, nullptr); + SerializeStruct(region, &message, &context, &data)); NamedRegionPtr region2; - mojo::internal::Deserialize<NamedRegionDataView>(data, ®ion2, nullptr); + mojo::internal::Deserialize<NamedRegionDataView>(data, ®ion2, &context); EXPECT_EQ("region", *region2->name); @@ -212,19 +231,16 @@ TEST_F(StructTest, Serialization_NullArrayPointers) { EXPECT_FALSE(region->name); EXPECT_FALSE(region->rects); - size_t size = - mojo::internal::PrepareToSerialize<NamedRegionDataView>(region, nullptr); + mojo::Message message; + mojo::internal::SerializationContext context; + internal::NamedRegion_Data* data; EXPECT_EQ(8U + // header 8U + // name pointer 8U, // rects pointer - size); - - mojo::internal::FixedBufferForTesting buf(size); - internal::NamedRegion_Data* data; - mojo::internal::Serialize<NamedRegionDataView>(region, &buf, &data, nullptr); + SerializeStruct(region, &message, &context, &data)); NamedRegionPtr region2; - mojo::internal::Deserialize<NamedRegionDataView>(data, ®ion2, nullptr); + mojo::internal::Deserialize<NamedRegionDataView>(data, ®ion2, &context); EXPECT_FALSE(region2->name); EXPECT_FALSE(region2->rects); @@ -360,70 +376,57 @@ TEST_F(StructTest, Versioning_NewToOld) { // Serialization test for native struct. TEST_F(StructTest, Serialization_NativeStruct) { - using Data = mojo::internal::NativeStruct_Data; + using Data = native::internal::NativeStruct_Data; { // Serialization of a null native struct. - NativeStructPtr native; - size_t size = mojo::internal::PrepareToSerialize<NativeStructDataView>( - native, nullptr); - EXPECT_EQ(0u, size); - mojo::internal::FixedBufferForTesting buf(size); + native::NativeStructPtr native; + mojo::Message message; + mojo::internal::SerializationContext context; Data* data = nullptr; - mojo::internal::Serialize<NativeStructDataView>(std::move(native), &buf, - &data, nullptr); - + EXPECT_EQ(0u, SerializeStruct(native, &message, &context, &data)); EXPECT_EQ(nullptr, data); - NativeStructPtr output_native; - mojo::internal::Deserialize<NativeStructDataView>(data, &output_native, - nullptr); + native::NativeStructPtr output_native; + mojo::internal::Deserialize<native::NativeStructDataView>( + data, &output_native, &context); EXPECT_TRUE(output_native.is_null()); } { // Serialization of a native struct with null data. - NativeStructPtr native(NativeStruct::New()); - size_t size = mojo::internal::PrepareToSerialize<NativeStructDataView>( - native, nullptr); - EXPECT_EQ(0u, size); - mojo::internal::FixedBufferForTesting buf(size); + native::NativeStructPtr native(native::NativeStruct::New()); + mojo::Message message; + mojo::internal::SerializationContext context; Data* data = nullptr; - mojo::internal::Serialize<NativeStructDataView>(std::move(native), &buf, - &data, nullptr); + EXPECT_EQ(32u, SerializeStruct(native, &message, &context, &data)); + EXPECT_EQ(0u, data->data.Get()->size()); - EXPECT_EQ(nullptr, data); - - NativeStructPtr output_native; - mojo::internal::Deserialize<NativeStructDataView>(data, &output_native, - nullptr); - EXPECT_TRUE(output_native.is_null()); + native::NativeStructPtr output_native; + mojo::internal::Deserialize<native::NativeStructDataView>( + data, &output_native, &context); + EXPECT_TRUE(output_native->data.empty()); } { - NativeStructPtr native(NativeStruct::New()); + native::NativeStructPtr native(native::NativeStruct::New()); native->data = std::vector<uint8_t>{'X', 'Y'}; - size_t size = mojo::internal::PrepareToSerialize<NativeStructDataView>( - native, nullptr); - EXPECT_EQ(16u, size); - mojo::internal::FixedBufferForTesting buf(size); - + mojo::Message message; + mojo::internal::SerializationContext context; Data* data = nullptr; - mojo::internal::Serialize<NativeStructDataView>(std::move(native), &buf, - &data, nullptr); - - EXPECT_NE(nullptr, data); + EXPECT_EQ(40u, SerializeStruct(native, &message, &context, &data)); + EXPECT_EQ(2u, data->data.Get()->size()); - NativeStructPtr output_native; - mojo::internal::Deserialize<NativeStructDataView>(data, &output_native, - nullptr); + native::NativeStructPtr output_native; + mojo::internal::Deserialize<native::NativeStructDataView>( + data, &output_native, &context); ASSERT_TRUE(output_native); - ASSERT_FALSE(output_native->data->empty()); - EXPECT_EQ(2u, output_native->data->size()); - EXPECT_EQ('X', (*output_native->data)[0]); - EXPECT_EQ('Y', (*output_native->data)[1]); + ASSERT_FALSE(output_native->data.empty()); + EXPECT_EQ(2u, output_native->data.size()); + EXPECT_EQ('X', output_native->data[0]); + EXPECT_EQ('Y', output_native->data[1]); } } diff --git a/mojo/public/cpp/bindings/tests/struct_with_traits.typemap b/mojo/public/cpp/bindings/tests/struct_with_traits.typemap index 752ce44b58..fccaf2b486 100644 --- a/mojo/public/cpp/bindings/tests/struct_with_traits.typemap +++ b/mojo/public/cpp/bindings/tests/struct_with_traits.typemap @@ -18,9 +18,12 @@ deps = [ type_mappings = [ "mojo.test.EnumWithTraits=mojo::test::EnumWithTraitsImpl", "mojo.test.StructWithTraits=mojo::test::StructWithTraitsImpl", + "mojo.test.StructWithUnreachableTraits=mojo::test::StructWithUnreachableTraitsImpl", "mojo.test.NestedStructWithTraits=mojo::test::NestedStructWithTraitsImpl", "mojo.test.TrivialStructWithTraits=mojo::test::TrivialStructWithTraitsImpl[copyable_pass_by_value]", "mojo.test.MoveOnlyStructWithTraits=mojo::test::MoveOnlyStructWithTraitsImpl[move_only]", "mojo.test.StructWithTraitsForUniquePtr=std::unique_ptr<int>[move_only,nullable_is_same_type]", "mojo.test.UnionWithTraits=std::unique_ptr<mojo::test::UnionWithTraitsBase>[move_only,nullable_is_same_type]", + "mojo.test.StructForceSerialize=mojo::test::StructForceSerializeImpl[force_serialize]", + "mojo.test.StructNestedForceSerialize=mojo::test::StructNestedForceSerializeImpl", ] diff --git a/mojo/public/cpp/bindings/tests/struct_with_traits_impl.cc b/mojo/public/cpp/bindings/tests/struct_with_traits_impl.cc index cbdd4bfde7..e537830a77 100644 --- a/mojo/public/cpp/bindings/tests/struct_with_traits_impl.cc +++ b/mojo/public/cpp/bindings/tests/struct_with_traits_impl.cc @@ -7,30 +7,38 @@ namespace mojo { namespace test { -NestedStructWithTraitsImpl::NestedStructWithTraitsImpl() {} +NestedStructWithTraitsImpl::NestedStructWithTraitsImpl() = default; NestedStructWithTraitsImpl::NestedStructWithTraitsImpl(int32_t in_value) : value(in_value) {} -StructWithTraitsImpl::StructWithTraitsImpl() {} +StructWithTraitsImpl::StructWithTraitsImpl() = default; -StructWithTraitsImpl::~StructWithTraitsImpl() {} +StructWithTraitsImpl::~StructWithTraitsImpl() = default; StructWithTraitsImpl::StructWithTraitsImpl(const StructWithTraitsImpl& other) = default; -MoveOnlyStructWithTraitsImpl::MoveOnlyStructWithTraitsImpl() {} +MoveOnlyStructWithTraitsImpl::MoveOnlyStructWithTraitsImpl() = default; MoveOnlyStructWithTraitsImpl::MoveOnlyStructWithTraitsImpl( MoveOnlyStructWithTraitsImpl&& other) = default; -MoveOnlyStructWithTraitsImpl::~MoveOnlyStructWithTraitsImpl() {} +MoveOnlyStructWithTraitsImpl::~MoveOnlyStructWithTraitsImpl() = default; MoveOnlyStructWithTraitsImpl& MoveOnlyStructWithTraitsImpl::operator=( MoveOnlyStructWithTraitsImpl&& other) = default; -UnionWithTraitsInt32::~UnionWithTraitsInt32() {} +UnionWithTraitsInt32::~UnionWithTraitsInt32() = default; -UnionWithTraitsStruct::~UnionWithTraitsStruct() {} +UnionWithTraitsStruct::~UnionWithTraitsStruct() = default; + +StructForceSerializeImpl::StructForceSerializeImpl() = default; + +StructForceSerializeImpl::~StructForceSerializeImpl() = default; + +StructNestedForceSerializeImpl::StructNestedForceSerializeImpl() = default; + +StructNestedForceSerializeImpl::~StructNestedForceSerializeImpl() = default; } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/struct_with_traits_impl.h b/mojo/public/cpp/bindings/tests/struct_with_traits_impl.h index 7b007cc083..e8d96120a2 100644 --- a/mojo/public/cpp/bindings/tests/struct_with_traits_impl.h +++ b/mojo/public/cpp/bindings/tests/struct_with_traits_impl.h @@ -98,6 +98,14 @@ class StructWithTraitsImpl { std::map<std::string, NestedStructWithTraitsImpl> struct_map_; }; +// A type which corresponds nominally to the +// mojo::test::StructWithUnreachableTraits mojom type. Used to test that said +// type is never serialized, i.e. objects of this type are simply copied into +// a message as-is when written to an intra-process interface. +struct StructWithUnreachableTraitsImpl { + int32_t magic_number = 0; +}; + // A type which knows how to look like a mojo::test::TrivialStructWithTraits // mojom type by way of mojo::StructTraits. struct TrivialStructWithTraitsImpl { @@ -162,6 +170,46 @@ class UnionWithTraitsStruct : public UnionWithTraitsBase { NestedStructWithTraitsImpl struct_; }; +class StructForceSerializeImpl { + public: + StructForceSerializeImpl(); + ~StructForceSerializeImpl(); + + void set_value(int32_t value) { value_ = value; } + int32_t value() const { return value_; } + + void set_was_serialized() const { was_serialized_ = true; } + bool was_serialized() const { return was_serialized_; } + + void set_was_deserialized() { was_deserialized_ = true; } + bool was_deserialized() const { return was_deserialized_; } + + private: + int32_t value_ = 0; + mutable bool was_serialized_ = false; + bool was_deserialized_ = false; +}; + +class StructNestedForceSerializeImpl { + public: + StructNestedForceSerializeImpl(); + ~StructNestedForceSerializeImpl(); + + StructForceSerializeImpl& force() { return force_; } + const StructForceSerializeImpl& force() const { return force_; } + + void set_was_serialized() const { was_serialized_ = true; } + bool was_serialized() const { return was_serialized_; } + + void set_was_deserialized() { was_deserialized_ = true; } + bool was_deserialized() const { return was_deserialized_; } + + private: + StructForceSerializeImpl force_; + mutable bool was_serialized_ = false; + bool was_deserialized_ = false; +}; + } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/struct_with_traits_impl_traits.cc b/mojo/public/cpp/bindings/tests/struct_with_traits_impl_traits.cc index 6b770b1a49..2586d8d0d4 100644 --- a/mojo/public/cpp/bindings/tests/struct_with_traits_impl_traits.cc +++ b/mojo/public/cpp/bindings/tests/struct_with_traits_impl_traits.cc @@ -5,39 +5,11 @@ #include "mojo/public/cpp/bindings/tests/struct_with_traits_impl_traits.h" namespace mojo { -namespace { - -struct Context { - int32_t value; -}; - -} // namespace - -// static -void* StructTraits<test::NestedStructWithTraitsDataView, - test::NestedStructWithTraitsImpl>:: - SetUpContext(const test::NestedStructWithTraitsImpl& input) { - Context* context = new Context; - context->value = input.value; - return context; -} - -// static -void StructTraits<test::NestedStructWithTraitsDataView, - test::NestedStructWithTraitsImpl>:: - TearDownContext(const test::NestedStructWithTraitsImpl& input, - void* context) { - Context* context_obj = static_cast<Context*>(context); - CHECK_EQ(context_obj->value, input.value); - delete context_obj; -} // static int32_t StructTraits<test::NestedStructWithTraitsDataView, test::NestedStructWithTraitsImpl>:: - value(const test::NestedStructWithTraitsImpl& input, void* context) { - Context* context_obj = static_cast<Context*>(context); - CHECK_EQ(context_obj->value, input.value); + value(const test::NestedStructWithTraitsImpl& input) { return input.value; } diff --git a/mojo/public/cpp/bindings/tests/struct_with_traits_impl_traits.h b/mojo/public/cpp/bindings/tests/struct_with_traits_impl_traits.h index adcad8aa9e..69d344f7df 100644 --- a/mojo/public/cpp/bindings/tests/struct_with_traits_impl_traits.h +++ b/mojo/public/cpp/bindings/tests/struct_with_traits_impl_traits.h @@ -20,12 +20,7 @@ namespace mojo { template <> struct StructTraits<test::NestedStructWithTraitsDataView, test::NestedStructWithTraitsImpl> { - static void* SetUpContext(const test::NestedStructWithTraitsImpl& input); - static void TearDownContext(const test::NestedStructWithTraitsImpl& input, - void* context); - - static int32_t value(const test::NestedStructWithTraitsImpl& input, - void* context); + static int32_t value(const test::NestedStructWithTraitsImpl& input); static bool Read(test::NestedStructWithTraitsDataView data, test::NestedStructWithTraitsImpl* output); @@ -99,6 +94,22 @@ struct StructTraits<test::StructWithTraitsDataView, }; template <> +struct StructTraits<test::StructWithUnreachableTraitsDataView, + test::StructWithUnreachableTraitsImpl> { + public: + static bool ignore_me(const test::StructWithUnreachableTraitsImpl& input) { + NOTREACHED(); + return false; + } + + static bool Read(test::StructWithUnreachableTraitsDataView data, + test::StructWithUnreachableTraitsImpl* out) { + NOTREACHED(); + return false; + } +}; + +template <> struct StructTraits<test::TrivialStructWithTraitsDataView, test::TrivialStructWithTraitsImpl> { // Deserialization to test::TrivialStructTraitsImpl. @@ -191,6 +202,40 @@ struct UnionTraits<test::UnionWithTraitsDataView, } }; +template <> +struct StructTraits<test::StructForceSerializeDataView, + test::StructForceSerializeImpl> { + static int32_t value(const test::StructForceSerializeImpl& impl) { + impl.set_was_serialized(); + return impl.value(); + } + + static bool Read(test::StructForceSerializeDataView data, + test::StructForceSerializeImpl* out) { + out->set_value(data.value()); + out->set_was_deserialized(); + return true; + } +}; + +template <> +struct StructTraits<test::StructNestedForceSerializeDataView, + test::StructNestedForceSerializeImpl> { + static const test::StructForceSerializeImpl& force( + const test::StructNestedForceSerializeImpl& impl) { + impl.set_was_serialized(); + return impl.force(); + } + + static bool Read(test::StructNestedForceSerializeDataView data, + test::StructNestedForceSerializeImpl* out) { + if (!data.ReadForce(&out->force())) + return false; + out->set_was_deserialized(); + return true; + } +}; + } // namespace mojo #endif // MOJO_PUBLIC_CPP_BINDINGS_TESTS_STRUCT_WITH_TRAITS_IMPL_TRAITS_H_ diff --git a/mojo/public/cpp/bindings/tests/sync_handle_registry_unittest.cc b/mojo/public/cpp/bindings/tests/sync_handle_registry_unittest.cc new file mode 100644 index 0000000000..2f17fc7fbf --- /dev/null +++ b/mojo/public/cpp/bindings/tests/sync_handle_registry_unittest.cc @@ -0,0 +1,258 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <memory> + +#include "mojo/public/cpp/bindings/sync_handle_registry.h" +#include "base/bind.h" +#include "base/memory/ref_counted.h" +#include "base/synchronization/waitable_event.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace mojo { + +class SyncHandleRegistryTest : public testing::Test { + public: + SyncHandleRegistryTest() : registry_(SyncHandleRegistry::current()) {} + + const scoped_refptr<SyncHandleRegistry>& registry() { return registry_; } + + private: + scoped_refptr<SyncHandleRegistry> registry_; + + DISALLOW_COPY_AND_ASSIGN(SyncHandleRegistryTest); +}; + +TEST_F(SyncHandleRegistryTest, DuplicateEventRegistration) { + bool called1 = false; + bool called2 = false; + auto callback = [](bool* called) { *called = true; }; + auto callback1 = base::Bind(callback, &called1); + auto callback2 = base::Bind(callback, &called2); + + base::WaitableEvent e(base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::SIGNALED); + registry()->RegisterEvent(&e, callback1); + registry()->RegisterEvent(&e, callback2); + + const bool* stop_flags[] = {&called1, &called2}; + registry()->Wait(stop_flags, 2); + + EXPECT_TRUE(called1); + EXPECT_TRUE(called2); + registry()->UnregisterEvent(&e, callback1); + + called1 = false; + called2 = false; + + registry()->Wait(stop_flags, 2); + + EXPECT_FALSE(called1); + EXPECT_TRUE(called2); + + registry()->UnregisterEvent(&e, callback2); +} + +TEST_F(SyncHandleRegistryTest, UnregisterDuplicateEventInNestedWait) { + base::WaitableEvent e(base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::SIGNALED); + bool called1 = false; + bool called2 = false; + bool called3 = false; + auto callback1 = base::Bind([](bool* called) { *called = true; }, &called1); + auto callback2 = base::Bind( + [](base::WaitableEvent* e, const base::Closure& other_callback, + scoped_refptr<SyncHandleRegistry> registry, bool* called) { + registry->UnregisterEvent(e, other_callback); + *called = true; + }, + &e, callback1, registry(), &called2); + auto callback3 = base::Bind([](bool* called) { *called = true; }, &called3); + + registry()->RegisterEvent(&e, callback1); + registry()->RegisterEvent(&e, callback2); + registry()->RegisterEvent(&e, callback3); + + const bool* stop_flags[] = {&called1, &called2, &called3}; + registry()->Wait(stop_flags, 3); + + // We don't make any assumptions about the order in which callbacks run, so + // we can't check |called1| - it may or may not get set depending on internal + // details. All we know is |called2| should be set, and a subsequent wait + // should definitely NOT set |called1|. + EXPECT_TRUE(called2); + EXPECT_TRUE(called3); + + called1 = false; + called2 = false; + called3 = false; + + registry()->UnregisterEvent(&e, callback2); + registry()->Wait(stop_flags, 3); + + EXPECT_FALSE(called1); + EXPECT_FALSE(called2); + EXPECT_TRUE(called3); +} + +TEST_F(SyncHandleRegistryTest, UnregisterAndRegisterForNewEventInCallback) { + auto e = std::make_unique<base::WaitableEvent>( + base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::SIGNALED); + bool called = false; + base::Closure callback_holder; + auto callback = base::Bind( + [](std::unique_ptr<base::WaitableEvent>* e, + base::Closure* callback_holder, + scoped_refptr<SyncHandleRegistry> registry, bool* called) { + EXPECT_FALSE(*called); + + registry->UnregisterEvent(e->get(), *callback_holder); + e->reset(); + *called = true; + + base::WaitableEvent nested_event( + base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::SIGNALED); + bool nested_called = false; + auto nested_callback = + base::Bind([](bool* called) { *called = true; }, &nested_called); + registry->RegisterEvent(&nested_event, nested_callback); + const bool* stop_flag = &nested_called; + registry->Wait(&stop_flag, 1); + registry->UnregisterEvent(&nested_event, nested_callback); + }, + &e, &callback_holder, registry(), &called); + callback_holder = callback; + + registry()->RegisterEvent(e.get(), callback); + + const bool* stop_flag = &called; + registry()->Wait(&stop_flag, 1); + EXPECT_TRUE(called); +} + +TEST_F(SyncHandleRegistryTest, UnregisterAndRegisterForSameEventInCallback) { + base::WaitableEvent e(base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::SIGNALED); + bool called = false; + base::Closure callback_holder; + auto callback = base::Bind( + [](base::WaitableEvent* e, base::Closure* callback_holder, + scoped_refptr<SyncHandleRegistry> registry, bool* called) { + EXPECT_FALSE(*called); + + registry->UnregisterEvent(e, *callback_holder); + *called = true; + + bool nested_called = false; + auto nested_callback = + base::Bind([](bool* called) { *called = true; }, &nested_called); + registry->RegisterEvent(e, nested_callback); + const bool* stop_flag = &nested_called; + registry->Wait(&stop_flag, 1); + registry->UnregisterEvent(e, nested_callback); + + EXPECT_TRUE(nested_called); + }, + &e, &callback_holder, registry(), &called); + callback_holder = callback; + + registry()->RegisterEvent(&e, callback); + + const bool* stop_flag = &called; + registry()->Wait(&stop_flag, 1); + EXPECT_TRUE(called); +} + +TEST_F(SyncHandleRegistryTest, RegisterDuplicateEventFromWithinCallback) { + base::WaitableEvent e(base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::SIGNALED); + bool called = false; + int call_count = 0; + auto callback = base::Bind( + [](base::WaitableEvent* e, scoped_refptr<SyncHandleRegistry> registry, + bool* called, int* call_count) { + // Don't re-enter. + ++(*call_count); + if (*called) + return; + + *called = true; + + bool called2 = false; + auto callback2 = + base::Bind([](bool* called) { *called = true; }, &called2); + registry->RegisterEvent(e, callback2); + + const bool* stop_flag = &called2; + registry->Wait(&stop_flag, 1); + + registry->UnregisterEvent(e, callback2); + }, + &e, registry(), &called, &call_count); + + registry()->RegisterEvent(&e, callback); + + const bool* stop_flag = &called; + registry()->Wait(&stop_flag, 1); + + EXPECT_TRUE(called); + EXPECT_EQ(2, call_count); + + registry()->UnregisterEvent(&e, callback); +} + +TEST_F(SyncHandleRegistryTest, UnregisterUniqueEventInNestedWait) { + auto e1 = std::make_unique<base::WaitableEvent>( + base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::NOT_SIGNALED); + base::WaitableEvent e2(base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::SIGNALED); + bool called1 = false; + bool called2 = false; + auto callback1 = base::Bind([](bool* called) { *called = true; }, &called1); + auto callback2 = base::Bind( + [](std::unique_ptr<base::WaitableEvent>* e1, + const base::Closure& other_callback, + scoped_refptr<SyncHandleRegistry> registry, bool* called) { + // Prevent re-entrancy. + if (*called) + return; + + registry->UnregisterEvent(e1->get(), other_callback); + *called = true; + e1->reset(); + + // Nest another wait. + bool called3 = false; + auto callback3 = + base::Bind([](bool* called) { *called = true; }, &called3); + base::WaitableEvent e3(base::WaitableEvent::ResetPolicy::MANUAL, + base::WaitableEvent::InitialState::SIGNALED); + registry->RegisterEvent(&e3, callback3); + + // This nested Wait() must not attempt to wait on |e1| since it has + // been unregistered. This would crash otherwise, since |e1| has been + // deleted. See http://crbug.com/761097. + const bool* stop_flags[] = {&called3}; + registry->Wait(stop_flags, 1); + + EXPECT_TRUE(called3); + registry->UnregisterEvent(&e3, callback3); + }, + &e1, callback1, registry(), &called2); + + registry()->RegisterEvent(e1.get(), callback1); + registry()->RegisterEvent(&e2, callback2); + + const bool* stop_flags[] = {&called1, &called2}; + registry()->Wait(stop_flags, 2); + + EXPECT_TRUE(called2); + + registry()->UnregisterEvent(&e2, callback2); +} + +} // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/sync_method_unittest.cc b/mojo/public/cpp/bindings/tests/sync_method_unittest.cc index 084e080ad3..f09e732fda 100644 --- a/mojo/public/cpp/bindings/tests/sync_method_unittest.cc +++ b/mojo/public/cpp/bindings/tests/sync_method_unittest.cc @@ -7,11 +7,14 @@ #include "base/bind.h" #include "base/logging.h" #include "base/macros.h" -#include "base/message_loop/message_loop.h" #include "base/run_loop.h" +#include "base/sequence_token.h" +#include "base/task_scheduler/post_task.h" +#include "base/test/scoped_task_environment.h" #include "base/threading/thread.h" #include "mojo/public/cpp/bindings/associated_binding.h" #include "mojo/public/cpp/bindings/binding.h" +#include "mojo/public/cpp/bindings/tests/bindings_test_base.h" #include "mojo/public/interfaces/bindings/tests/test_sync_methods.mojom.h" #include "testing/gtest/include/gtest/gtest.h" @@ -238,9 +241,12 @@ class PtrWrapper { DISALLOW_COPY_AND_ASSIGN(PtrWrapper); }; -// The type parameter for SyncMethodCommonTests for varying the Interface and -// whether to use InterfacePtr or ThreadSafeInterfacePtr. -template <typename InterfaceT, bool use_thread_safe_ptr> +// The type parameter for SyncMethodCommonTests and +// SyncMethodOnSequenceCommonTests for varying the Interface and whether to use +// InterfacePtr or ThreadSafeInterfacePtr. +template <typename InterfaceT, + bool use_thread_safe_ptr, + BindingsTestSerializationMode serialization_mode> struct TestParams { using Interface = InterfaceT; static const bool kIsThreadSafeInterfacePtrTest = use_thread_safe_ptr; @@ -253,18 +259,20 @@ struct TestParams { return PtrWrapper<Interface>(std::move(ptr)); } } + + static const BindingsTestSerializationMode kSerializationMode = + serialization_mode; }; template <typename Interface> -class TestSyncServiceThread { +class TestSyncServiceSequence { public: - TestSyncServiceThread() - : thread_("TestSyncServiceThread"), ping_called_(false) { - thread_.Start(); - } + TestSyncServiceSequence() + : task_runner_(base::CreateSequencedTaskRunnerWithTraits({})), + ping_called_(false) {} void SetUp(InterfaceRequest<Interface> request) { - CHECK(thread_.task_runner()->BelongsToCurrentThread()); + CHECK(task_runner()->RunsTasksInCurrentSequence()); impl_.reset(new ImplTypeFor<Interface>(std::move(request))); impl_->set_ping_handler( [this](const typename Interface::PingCallback& callback) { @@ -277,25 +285,25 @@ class TestSyncServiceThread { } void TearDown() { - CHECK(thread_.task_runner()->BelongsToCurrentThread()); + CHECK(task_runner()->RunsTasksInCurrentSequence()); impl_.reset(); } - base::Thread* thread() { return &thread_; } + base::SequencedTaskRunner* task_runner() { return task_runner_.get(); } bool ping_called() const { base::AutoLock locker(lock_); return ping_called_; } private: - base::Thread thread_; + scoped_refptr<base::SequencedTaskRunner> task_runner_; std::unique_ptr<ImplTypeFor<Interface>> impl_; mutable base::Lock lock_; bool ping_called_; - DISALLOW_COPY_AND_ASSIGN(TestSyncServiceThread); + DISALLOW_COPY_AND_ASSIGN(TestSyncServiceSequence); }; class SyncMethodTest : public testing::Test { @@ -304,14 +312,18 @@ class SyncMethodTest : public testing::Test { ~SyncMethodTest() override { base::RunLoop().RunUntilIdle(); } protected: - base::MessageLoop loop_; + base::test::ScopedTaskEnvironment task_environment; }; -template <typename T> +template <typename TypeParam> class SyncMethodCommonTest : public SyncMethodTest { public: SyncMethodCommonTest() {} ~SyncMethodCommonTest() override {} + + void SetUp() override { + BindingsTestBase::SetupSerializationBehavior(TypeParam::kSerializationMode); + } }; class SyncMethodAssociatedTest : public SyncMethodTest { @@ -386,16 +398,125 @@ TestSync::AsyncEchoCallback BindAsyncEchoCallback(Func func) { return base::Bind(&CallAsyncEchoCallback<Func>, func); } +class SequencedTaskRunnerTestBase; + +void RunTestOnSequencedTaskRunner( + std::unique_ptr<SequencedTaskRunnerTestBase> test); + +class SequencedTaskRunnerTestBase { + public: + virtual ~SequencedTaskRunnerTestBase() = default; + + void RunTest() { + SetUp(); + Run(); + } + + virtual void Run() = 0; + + virtual void SetUp() {} + virtual void TearDown() {} + + protected: + void Done() { + TearDown(); + task_runner_->PostTask(FROM_HERE, quit_closure_); + delete this; + } + + base::Closure DoneClosure() { + return base::Bind(&SequencedTaskRunnerTestBase::Done, + base::Unretained(this)); + } + + private: + friend void RunTestOnSequencedTaskRunner( + std::unique_ptr<SequencedTaskRunnerTestBase> test); + + void Init(const base::Closure& quit_closure) { + task_runner_ = base::SequencedTaskRunnerHandle::Get(); + quit_closure_ = quit_closure; + } + + scoped_refptr<base::SequencedTaskRunner> task_runner_; + base::Closure quit_closure_; +}; + +// A helper class to launch tests on a SequencedTaskRunner. This is necessary +// so gtest can instantiate copies for each |TypeParam|. +template <typename TypeParam> +class SequencedTaskRunnerTestLauncher : public testing::Test { + base::test::ScopedTaskEnvironment task_environment; +}; + +// Similar to SyncMethodCommonTest, but the test body runs on a +// SequencedTaskRunner. +template <typename TypeParam> +class SyncMethodOnSequenceCommonTest : public SequencedTaskRunnerTestBase { + public: + void SetUp() override { + BindingsTestBase::SetupSerializationBehavior(TypeParam::kSerializationMode); + impl_ = std::make_unique<ImplTypeFor<typename TypeParam::Interface>>( + MakeRequest(&ptr_)); + } + + protected: + InterfacePtr<typename TypeParam::Interface> ptr_; + std::unique_ptr<ImplTypeFor<typename TypeParam::Interface>> impl_; +}; + +void RunTestOnSequencedTaskRunner( + std::unique_ptr<SequencedTaskRunnerTestBase> test) { + base::RunLoop run_loop; + test->Init(run_loop.QuitClosure()); + base::CreateSequencedTaskRunnerWithTraits({base::WithBaseSyncPrimitives()}) + ->PostTask(FROM_HERE, base::Bind(&SequencedTaskRunnerTestBase::RunTest, + base::Unretained(test.release()))); + run_loop.Run(); +} + // TestSync (without associated interfaces) and TestSyncMaster (with associated // interfaces) exercise MultiplexRouter with different configurations. // Each test is run once with an InterfacePtr and once with a // ThreadSafeInterfacePtr to ensure that they behave the same with respect to -// sync calls. -using InterfaceTypes = testing::Types<TestParams<TestSync, true>, - TestParams<TestSync, false>, - TestParams<TestSyncMaster, true>, - TestParams<TestSyncMaster, false>>; +// sync calls. Finally, all such combinations are tested in different message +// serialization modes. +using InterfaceTypes = testing::Types< + TestParams<TestSync, + true, + BindingsTestSerializationMode::kSerializeBeforeSend>, + TestParams<TestSync, + false, + BindingsTestSerializationMode::kSerializeBeforeSend>, + TestParams<TestSyncMaster, + true, + BindingsTestSerializationMode::kSerializeBeforeSend>, + TestParams<TestSyncMaster, + false, + BindingsTestSerializationMode::kSerializeBeforeSend>, + TestParams<TestSync, + true, + BindingsTestSerializationMode::kSerializeBeforeDispatch>, + TestParams<TestSync, + false, + BindingsTestSerializationMode::kSerializeBeforeDispatch>, + TestParams<TestSyncMaster, + true, + BindingsTestSerializationMode::kSerializeBeforeDispatch>, + TestParams<TestSyncMaster, + false, + BindingsTestSerializationMode::kSerializeBeforeDispatch>, + TestParams<TestSync, true, BindingsTestSerializationMode::kNeverSerialize>, + TestParams<TestSync, false, BindingsTestSerializationMode::kNeverSerialize>, + TestParams<TestSyncMaster, + true, + BindingsTestSerializationMode::kNeverSerialize>, + TestParams<TestSyncMaster, + false, + BindingsTestSerializationMode::kNeverSerialize>>; + TYPED_TEST_CASE(SyncMethodCommonTest, InterfaceTypes); +TYPED_TEST_CASE(SequencedTaskRunnerTestLauncher, InterfaceTypes); TYPED_TEST(SyncMethodCommonTest, CallSyncMethodAsynchronously) { using Interface = typename TypeParam::Interface; @@ -409,29 +530,65 @@ TYPED_TEST(SyncMethodCommonTest, CallSyncMethodAsynchronously) { run_loop.Run(); } +#define SEQUENCED_TASK_RUNNER_TYPED_TEST_NAME(fixture_name, name) \ + fixture_name##name##_SequencedTaskRunnerTestSuffix + +#define SEQUENCED_TASK_RUNNER_TYPED_TEST(fixture_name, name) \ + template <typename TypeParam> \ + class SEQUENCED_TASK_RUNNER_TYPED_TEST_NAME(fixture_name, name) \ + : public fixture_name<TypeParam> { \ + void Run() override; \ + }; \ + TYPED_TEST(SequencedTaskRunnerTestLauncher, name) { \ + RunTestOnSequencedTaskRunner( \ + std::make_unique<SEQUENCED_TASK_RUNNER_TYPED_TEST_NAME( \ + fixture_name, name) < TypeParam>> ()); \ + } \ + template <typename TypeParam> \ + void SEQUENCED_TASK_RUNNER_TYPED_TEST_NAME(fixture_name, \ + name)<TypeParam>::Run() + +#define SEQUENCED_TASK_RUNNER_TYPED_TEST_F(fixture_name, name) \ + template <typename TypeParam> \ + class SEQUENCED_TASK_RUNNER_TYPED_TEST_NAME(fixture_name, name); \ + TYPED_TEST(SequencedTaskRunnerTestLauncher, name) { \ + RunTestOnSequencedTaskRunner( \ + std::make_unique<SEQUENCED_TASK_RUNNER_TYPED_TEST_NAME( \ + fixture_name, name) < TypeParam>> ()); \ + } \ + template <typename TypeParam> \ + class SEQUENCED_TASK_RUNNER_TYPED_TEST_NAME(fixture_name, name) \ + : public fixture_name<TypeParam> + +SEQUENCED_TASK_RUNNER_TYPED_TEST(SyncMethodOnSequenceCommonTest, + CallSyncMethodAsynchronously) { + this->ptr_->Echo( + 123, base::Bind(&ExpectValueAndRunClosure, 123, this->DoneClosure())); +} + TYPED_TEST(SyncMethodCommonTest, BasicSyncCalls) { using Interface = typename TypeParam::Interface; InterfacePtr<Interface> interface_ptr; InterfaceRequest<Interface> request = MakeRequest(&interface_ptr); auto ptr = TypeParam::Wrap(std::move(interface_ptr)); - TestSyncServiceThread<Interface> service_thread; - service_thread.thread()->task_runner()->PostTask( + TestSyncServiceSequence<Interface> service_sequence; + service_sequence.task_runner()->PostTask( FROM_HERE, - base::Bind(&TestSyncServiceThread<Interface>::SetUp, - base::Unretained(&service_thread), base::Passed(&request))); + base::Bind(&TestSyncServiceSequence<Interface>::SetUp, + base::Unretained(&service_sequence), base::Passed(&request))); ASSERT_TRUE(ptr->Ping()); - ASSERT_TRUE(service_thread.ping_called()); + ASSERT_TRUE(service_sequence.ping_called()); int32_t output_value = -1; ASSERT_TRUE(ptr->Echo(42, &output_value)); ASSERT_EQ(42, output_value); base::RunLoop run_loop; - service_thread.thread()->task_runner()->PostTaskAndReply( + service_sequence.task_runner()->PostTaskAndReply( FROM_HERE, - base::Bind(&TestSyncServiceThread<Interface>::TearDown, - base::Unretained(&service_thread)), + base::Bind(&TestSyncServiceSequence<Interface>::TearDown, + base::Unretained(&service_sequence)), run_loop.QuitClosure()); run_loop.Run(); } @@ -450,6 +607,17 @@ TYPED_TEST(SyncMethodCommonTest, ReenteredBySyncMethodBinding) { EXPECT_EQ(42, output_value); } +SEQUENCED_TASK_RUNNER_TYPED_TEST(SyncMethodOnSequenceCommonTest, + ReenteredBySyncMethodBinding) { + // Test that an interface pointer waiting for a sync call response can be + // reentered by a binding serving sync methods on the same thread. + + int32_t output_value = -1; + ASSERT_TRUE(this->ptr_->Echo(42, &output_value)); + EXPECT_EQ(42, output_value); + this->Done(); +} + TYPED_TEST(SyncMethodCommonTest, InterfacePtrDestroyedDuringSyncCall) { // Test that it won't result in crash or hang if an interface pointer is // destroyed while it is waiting for a sync call response. @@ -465,6 +633,20 @@ TYPED_TEST(SyncMethodCommonTest, InterfacePtrDestroyedDuringSyncCall) { ASSERT_FALSE(ptr->Ping()); } +SEQUENCED_TASK_RUNNER_TYPED_TEST(SyncMethodOnSequenceCommonTest, + InterfacePtrDestroyedDuringSyncCall) { + // Test that it won't result in crash or hang if an interface pointer is + // destroyed while it is waiting for a sync call response. + + auto* ptr = &this->ptr_; + this->impl_->set_ping_handler([ptr](const TestSync::PingCallback& callback) { + ptr->reset(); + callback.Run(); + }); + ASSERT_FALSE(this->ptr_->Ping()); + this->Done(); +} + TYPED_TEST(SyncMethodCommonTest, BindingDestroyedDuringSyncCall) { // Test that it won't result in crash or hang if a binding is // closed (and therefore the message pipe handle is closed) while the @@ -481,6 +663,22 @@ TYPED_TEST(SyncMethodCommonTest, BindingDestroyedDuringSyncCall) { ASSERT_FALSE(ptr->Ping()); } +SEQUENCED_TASK_RUNNER_TYPED_TEST(SyncMethodOnSequenceCommonTest, + BindingDestroyedDuringSyncCall) { + // Test that it won't result in crash or hang if a binding is + // closed (and therefore the message pipe handle is closed) while the + // corresponding interface pointer is waiting for a sync call response. + + auto& impl = *this->impl_; + this->impl_->set_ping_handler( + [&impl](const TestSync::PingCallback& callback) { + impl.binding()->Close(); + callback.Run(); + }); + ASSERT_FALSE(this->ptr_->Ping()); + this->Done(); +} + TYPED_TEST(SyncMethodCommonTest, NestedSyncCallsWithInOrderResponses) { // Test that we can call a sync method on an interface ptr, while there is // already a sync call ongoing. The responses arrive in order. @@ -509,6 +707,34 @@ TYPED_TEST(SyncMethodCommonTest, NestedSyncCallsWithInOrderResponses) { EXPECT_EQ(123, result_value); } +SEQUENCED_TASK_RUNNER_TYPED_TEST(SyncMethodOnSequenceCommonTest, + NestedSyncCallsWithInOrderResponses) { + // Test that we can call a sync method on an interface ptr, while there is + // already a sync call ongoing. The responses arrive in order. + + // The same variable is used to store the output of the two sync calls, in + // order to test that responses are handled in the correct order. + int32_t result_value = -1; + + bool first_call = true; + auto& ptr = this->ptr_; + auto& impl = *this->impl_; + impl.set_echo_handler( + [&first_call, &ptr, &result_value]( + int32_t value, const TestSync::EchoCallback& callback) { + if (first_call) { + first_call = false; + ASSERT_TRUE(ptr->Echo(456, &result_value)); + EXPECT_EQ(456, result_value); + } + callback.Run(value); + }); + + ASSERT_TRUE(ptr->Echo(123, &result_value)); + EXPECT_EQ(123, result_value); + this->Done(); +} + TYPED_TEST(SyncMethodCommonTest, NestedSyncCallsWithOutOfOrderResponses) { // Test that we can call a sync method on an interface ptr, while there is // already a sync call ongoing. The responses arrive out of order. @@ -537,6 +763,34 @@ TYPED_TEST(SyncMethodCommonTest, NestedSyncCallsWithOutOfOrderResponses) { EXPECT_EQ(123, result_value); } +SEQUENCED_TASK_RUNNER_TYPED_TEST(SyncMethodOnSequenceCommonTest, + NestedSyncCallsWithOutOfOrderResponses) { + // Test that we can call a sync method on an interface ptr, while there is + // already a sync call ongoing. The responses arrive out of order. + + // The same variable is used to store the output of the two sync calls, in + // order to test that responses are handled in the correct order. + int32_t result_value = -1; + + bool first_call = true; + auto& ptr = this->ptr_; + auto& impl = *this->impl_; + impl.set_echo_handler( + [&first_call, &ptr, &result_value]( + int32_t value, const TestSync::EchoCallback& callback) { + callback.Run(value); + if (first_call) { + first_call = false; + ASSERT_TRUE(ptr->Echo(456, &result_value)); + EXPECT_EQ(456, result_value); + } + }); + + ASSERT_TRUE(ptr->Echo(123, &result_value)); + EXPECT_EQ(123, result_value); + this->Done(); +} + TYPED_TEST(SyncMethodCommonTest, AsyncResponseQueuedDuringSyncCall) { // Test that while an interface pointer is waiting for the response to a sync // call, async responses are queued until the sync call completes. @@ -594,6 +848,52 @@ TYPED_TEST(SyncMethodCommonTest, AsyncResponseQueuedDuringSyncCall) { EXPECT_TRUE(async_echo_response_dispatched); } +SEQUENCED_TASK_RUNNER_TYPED_TEST_F(SyncMethodOnSequenceCommonTest, + AsyncResponseQueuedDuringSyncCall) { + // Test that while an interface pointer is waiting for the response to a sync + // call, async responses are queued until the sync call completes. + + void Run() override { + this->impl_->set_async_echo_handler( + [this](int32_t value, const TestSync::AsyncEchoCallback& callback) { + async_echo_request_value_ = value; + async_echo_request_callback_ = callback; + OnAsyncEchoReceived(); + }); + + this->ptr_->AsyncEcho(123, BindAsyncEchoCallback([this](int32_t result) { + async_echo_response_dispatched_ = true; + EXPECT_EQ(123, result); + EXPECT_TRUE(async_echo_response_dispatched_); + this->Done(); + })); + } + + // Called when the AsyncEcho request reaches the service side. + void OnAsyncEchoReceived() { + this->impl_->set_echo_handler( + [this](int32_t value, const TestSync::EchoCallback& callback) { + // Send back the async response first. + EXPECT_FALSE(async_echo_request_callback_.is_null()); + async_echo_request_callback_.Run(async_echo_request_value_); + + callback.Run(value); + }); + + int32_t result_value = -1; + ASSERT_TRUE(this->ptr_->Echo(456, &result_value)); + EXPECT_EQ(456, result_value); + + // Although the AsyncEcho response arrives before the Echo response, it + // should be queued and not yet dispatched. + EXPECT_FALSE(async_echo_response_dispatched_); + } + + int32_t async_echo_request_value_ = -1; + TestSync::AsyncEchoCallback async_echo_request_callback_; + bool async_echo_response_dispatched_ = false; +}; + TYPED_TEST(SyncMethodCommonTest, AsyncRequestQueuedDuringSyncCall) { // Test that while an interface pointer is waiting for the response to a sync // call, async requests for a binding running on the same thread are queued @@ -645,6 +945,44 @@ TYPED_TEST(SyncMethodCommonTest, AsyncRequestQueuedDuringSyncCall) { EXPECT_TRUE(async_echo_response_dispatched); } +SEQUENCED_TASK_RUNNER_TYPED_TEST_F(SyncMethodOnSequenceCommonTest, + AsyncRequestQueuedDuringSyncCall) { + // Test that while an interface pointer is waiting for the response to a sync + // call, async requests for a binding running on the same thread are queued + // until the sync call completes. + void Run() override { + this->impl_->set_async_echo_handler( + [this](int32_t value, const TestSync::AsyncEchoCallback& callback) { + async_echo_request_dispatched_ = true; + callback.Run(value); + }); + + this->ptr_->AsyncEcho(123, BindAsyncEchoCallback([this](int32_t result) { + EXPECT_EQ(123, result); + this->Done(); + })); + + this->impl_->set_echo_handler( + [this](int32_t value, const TestSync::EchoCallback& callback) { + // Although the AsyncEcho request is sent before the Echo request, it + // shouldn't be dispatched yet at this point, because there is an + // ongoing + // sync call on the same thread. + EXPECT_FALSE(async_echo_request_dispatched_); + callback.Run(value); + }); + + int32_t result_value = -1; + ASSERT_TRUE(this->ptr_->Echo(456, &result_value)); + EXPECT_EQ(456, result_value); + + // Although the AsyncEcho request is sent before the Echo request, it + // shouldn't be dispatched yet. + EXPECT_FALSE(async_echo_request_dispatched_); + } + bool async_echo_request_dispatched_ = false; +}; + TYPED_TEST(SyncMethodCommonTest, QueuedMessagesProcessedBeforeErrorNotification) { // Test that while an interface pointer is waiting for the response to a sync @@ -675,19 +1013,17 @@ TYPED_TEST(SyncMethodCommonTest, bool async_echo_response_dispatched = false; bool connection_error_dispatched = false; base::RunLoop run_loop2; - ptr->AsyncEcho( - 123, - BindAsyncEchoCallback( - [&async_echo_response_dispatched, &connection_error_dispatched, &ptr, - &run_loop2](int32_t result) { - async_echo_response_dispatched = true; - // At this point, error notification should not be dispatched - // yet. - EXPECT_FALSE(connection_error_dispatched); - EXPECT_FALSE(ptr.encountered_error()); - EXPECT_EQ(123, result); - run_loop2.Quit(); - })); + ptr->AsyncEcho(123, BindAsyncEchoCallback([&async_echo_response_dispatched, + &connection_error_dispatched, &ptr, + &run_loop2](int32_t result) { + async_echo_response_dispatched = true; + // At this point, error notification should not be dispatched + // yet. + EXPECT_FALSE(connection_error_dispatched); + EXPECT_FALSE(ptr.encountered_error()); + EXPECT_EQ(123, result); + run_loop2.Quit(); + })); // Run until the AsyncEcho request reaches the service side. run_loop1.Run(); @@ -702,9 +1038,9 @@ TYPED_TEST(SyncMethodCommonTest, }); base::RunLoop run_loop3; - ptr.set_connection_error_handler( - base::Bind(&SetFlagAndRunClosure, &connection_error_dispatched, - run_loop3.QuitClosure())); + ptr.set_connection_error_handler(base::Bind(&SetFlagAndRunClosure, + &connection_error_dispatched, + run_loop3.QuitClosure())); int32_t result_value = -1; ASSERT_FALSE(ptr->Echo(456, &result_value)); @@ -728,6 +1064,74 @@ TYPED_TEST(SyncMethodCommonTest, EXPECT_TRUE(ptr.encountered_error()); } +SEQUENCED_TASK_RUNNER_TYPED_TEST_F( + SyncMethodOnSequenceCommonTest, + QueuedMessagesProcessedBeforeErrorNotification) { + // Test that while an interface pointer is waiting for the response to a sync + // call, async responses are queued. If the message pipe is disconnected + // before the queued messages are processed, the connection error + // notification is delayed until all the queued messages are processed. + + void Run() override { + this->impl_->set_async_echo_handler( + [this](int32_t value, const TestSync::AsyncEchoCallback& callback) { + OnAsyncEchoReachedService(value, callback); + }); + + this->ptr_->AsyncEcho(123, BindAsyncEchoCallback([this](int32_t result) { + async_echo_response_dispatched_ = true; + // At this point, error notification should not be + // dispatched + // yet. + EXPECT_FALSE(connection_error_dispatched_); + EXPECT_FALSE(this->ptr_.encountered_error()); + EXPECT_EQ(123, result); + EXPECT_TRUE(async_echo_response_dispatched_); + })); + } + + void OnAsyncEchoReachedService(int32_t value, + const TestSync::AsyncEchoCallback& callback) { + async_echo_request_value_ = value; + async_echo_request_callback_ = callback; + this->impl_->set_echo_handler( + [this](int32_t value, const TestSync::EchoCallback& callback) { + // Send back the async response first. + EXPECT_FALSE(async_echo_request_callback_.is_null()); + async_echo_request_callback_.Run(async_echo_request_value_); + + this->impl_->binding()->Close(); + }); + + this->ptr_.set_connection_error_handler( + base::Bind(&SetFlagAndRunClosure, &connection_error_dispatched_, + LambdaBinder<>::BindLambda( + [this]() { OnErrorNotificationDispatched(); }))); + + int32_t result_value = -1; + ASSERT_FALSE(this->ptr_->Echo(456, &result_value)); + EXPECT_EQ(-1, result_value); + ASSERT_FALSE(connection_error_dispatched_); + EXPECT_FALSE(this->ptr_.encountered_error()); + + // Although the AsyncEcho response arrives before the Echo response, it + // should + // be queued and not yet dispatched. + EXPECT_FALSE(async_echo_response_dispatched_); + } + + void OnErrorNotificationDispatched() { + ASSERT_TRUE(connection_error_dispatched_); + EXPECT_TRUE(this->ptr_.encountered_error()); + this->Done(); + } + + int32_t async_echo_request_value_ = -1; + TestSync::AsyncEchoCallback async_echo_request_callback_; + bool async_echo_response_dispatched_ = false; + bool connection_error_dispatched_ = false; +}; + TYPED_TEST(SyncMethodCommonTest, InvalidMessageDuringSyncCall) { // Test that while an interface pointer is waiting for the response to a sync // call, an invalid incoming message will disconnect the message pipe, cause @@ -742,7 +1146,8 @@ TYPED_TEST(SyncMethodCommonTest, InvalidMessageDuringSyncCall) { auto ptr = TypeParam::Wrap(std::move(interface_ptr)); MessagePipeHandle raw_binding_handle = pipe.handle1.get(); - ImplTypeFor<Interface> impl(MakeRequest<Interface>(std::move(pipe.handle1))); + ImplTypeFor<Interface> impl( + InterfaceRequest<Interface>(std::move(pipe.handle1))); impl.set_echo_handler([&raw_binding_handle]( int32_t value, const TestSync::EchoCallback& callback) { @@ -775,6 +1180,50 @@ TYPED_TEST(SyncMethodCommonTest, InvalidMessageDuringSyncCall) { } } +SEQUENCED_TASK_RUNNER_TYPED_TEST_F(SyncMethodOnSequenceCommonTest, + InvalidMessageDuringSyncCall) { + // Test that while an interface pointer is waiting for the response to a sync + // call, an invalid incoming message will disconnect the message pipe, cause + // the sync call to return false, and run the connection error handler + // asynchronously. + + void Run() override { + MessagePipe pipe; + + using InterfaceType = typename TypeParam::Interface; + this->ptr_.Bind( + InterfacePtrInfo<InterfaceType>(std::move(pipe.handle0), 0u)); + + MessagePipeHandle raw_binding_handle = pipe.handle1.get(); + this->impl_ = std::make_unique<ImplTypeFor<InterfaceType>>( + InterfaceRequest<InterfaceType>(std::move(pipe.handle1))); + + this->impl_->set_echo_handler( + [raw_binding_handle](int32_t value, + const TestSync::EchoCallback& callback) { + // Write a 1-byte message, which is considered invalid. + char invalid_message = 0; + MojoResult result = + WriteMessageRaw(raw_binding_handle, &invalid_message, 1u, nullptr, + 0u, MOJO_WRITE_MESSAGE_FLAG_NONE); + ASSERT_EQ(MOJO_RESULT_OK, result); + callback.Run(value); + }); + + this->ptr_.set_connection_error_handler( + LambdaBinder<>::BindLambda([this]() { + connection_error_dispatched_ = true; + this->Done(); + })); + + int32_t result_value = -1; + ASSERT_FALSE(this->ptr_->Echo(456, &result_value)); + EXPECT_EQ(-1, result_value); + ASSERT_FALSE(connection_error_dispatched_); + } + bool connection_error_dispatched_ = false; +}; + TEST_F(SyncMethodAssociatedTest, ReenteredBySyncMethodAssoBindingOfSameRouter) { // Test that an interface pointer waiting for a sync call response can be // reentered by an associated binding serving sync methods on the same thread. diff --git a/mojo/public/cpp/bindings/tests/test_helpers_unittest.cc b/mojo/public/cpp/bindings/tests/test_helpers_unittest.cc new file mode 100644 index 0000000000..4595dea4ed --- /dev/null +++ b/mojo/public/cpp/bindings/tests/test_helpers_unittest.cc @@ -0,0 +1,128 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/macros.h" +#include "base/test/scoped_task_environment.h" +#include "mojo/public/cpp/bindings/binding.h" +#include "mojo/public/cpp/system/message_pipe.h" +#include "mojo/public/cpp/system/wait.h" +#include "mojo/public/interfaces/bindings/tests/ping_service.mojom.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace mojo { +namespace { + +class TestHelperTest : public testing::Test { + public: + TestHelperTest() = default; + ~TestHelperTest() override = default; + + private: + base::test::ScopedTaskEnvironment task_environment_; + + DISALLOW_COPY_AND_ASSIGN(TestHelperTest); +}; + +class PingImpl : public test::PingService { + public: + explicit PingImpl(test::PingServiceRequest request) + : binding_(this, std::move(request)) {} + ~PingImpl() override = default; + + bool pinged() const { return pinged_; } + + // test::PingService: + void Ping(const PingCallback& callback) override { + pinged_ = true; + callback.Run(); + } + + private: + bool pinged_ = false; + Binding<test::PingService> binding_; + + DISALLOW_COPY_AND_ASSIGN(PingImpl); +}; + +class EchoImpl : public test::EchoService { + public: + explicit EchoImpl(test::EchoServiceRequest request) + : binding_(this, std::move(request)) {} + ~EchoImpl() override = default; + + // test::EchoService: + void Echo(const std::string& message, const EchoCallback& callback) override { + callback.Run(message); + } + + private: + Binding<test::EchoService> binding_; + + DISALLOW_COPY_AND_ASSIGN(EchoImpl); +}; + +class TrampolineImpl : public test::HandleTrampoline { + public: + explicit TrampolineImpl(test::HandleTrampolineRequest request) + : binding_(this, std::move(request)) {} + ~TrampolineImpl() override = default; + + // test::HandleTrampoline: + void BounceOne(ScopedMessagePipeHandle one, + const BounceOneCallback& callback) override { + callback.Run(std::move(one)); + } + + void BounceTwo(ScopedMessagePipeHandle one, + ScopedMessagePipeHandle two, + const BounceTwoCallback& callback) override { + callback.Run(std::move(one), std::move(two)); + } + + private: + Binding<test::HandleTrampoline> binding_; + + DISALLOW_COPY_AND_ASSIGN(TrampolineImpl); +}; + +TEST_F(TestHelperTest, AsyncWaiter) { + test::PingServicePtr ping; + PingImpl ping_impl(MakeRequest(&ping)); + + test::PingServiceAsyncWaiter wait_for_ping(ping.get()); + EXPECT_FALSE(ping_impl.pinged()); + wait_for_ping.Ping(); + EXPECT_TRUE(ping_impl.pinged()); + + test::EchoServicePtr echo; + EchoImpl echo_impl(MakeRequest(&echo)); + + test::EchoServiceAsyncWaiter wait_for_echo(echo.get()); + const std::string kTestString = "a machine that goes 'ping'"; + std::string response; + wait_for_echo.Echo(kTestString, &response); + EXPECT_EQ(kTestString, response); + + test::HandleTrampolinePtr trampoline; + TrampolineImpl trampoline_impl(MakeRequest(&trampoline)); + + test::HandleTrampolineAsyncWaiter wait_for_trampoline(trampoline.get()); + MessagePipe pipe; + ScopedMessagePipeHandle handle0, handle1; + WriteMessageRaw(pipe.handle0.get(), kTestString.data(), kTestString.size(), + nullptr, 0, MOJO_WRITE_MESSAGE_FLAG_NONE); + wait_for_trampoline.BounceOne(std::move(pipe.handle0), &handle0); + wait_for_trampoline.BounceTwo(std::move(handle0), std::move(pipe.handle1), + &handle0, &handle1); + + // Verify that our pipe handles are the same as the original pipe. + Wait(handle1.get(), MOJO_HANDLE_SIGNAL_READABLE); + std::vector<uint8_t> payload; + ReadMessageRaw(handle1.get(), &payload, nullptr, MOJO_READ_MESSAGE_FLAG_NONE); + std::string original_message(payload.begin(), payload.end()); + EXPECT_EQ(kTestString, original_message); +} + +} // namespace +} // namespace mojo diff --git a/mojo/public/cpp/bindings/tests/test_native_types.cc b/mojo/public/cpp/bindings/tests/test_native_types.cc new file mode 100644 index 0000000000..b11cc23172 --- /dev/null +++ b/mojo/public/cpp/bindings/tests/test_native_types.cc @@ -0,0 +1,99 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "mojo/public/cpp/bindings/tests/test_native_types.h" + +#include "base/macros.h" +#include "ipc/ipc_mojo_message_helper.h" + +namespace mojo { +namespace test { + +TestNativeStruct::TestNativeStruct() = default; + +TestNativeStruct::TestNativeStruct(const std::string& message, int x, int y) + : message_(message), x_(x), y_(y) {} + +TestNativeStruct::~TestNativeStruct() = default; + +TestNativeStructWithAttachments::TestNativeStructWithAttachments() = default; + +TestNativeStructWithAttachments::TestNativeStructWithAttachments( + TestNativeStructWithAttachments&& other) = default; + +TestNativeStructWithAttachments::TestNativeStructWithAttachments( + const std::string& message, + mojo::ScopedMessagePipeHandle pipe) + : message_(message), pipe_(std::move(pipe)) {} + +TestNativeStructWithAttachments::~TestNativeStructWithAttachments() = default; + +TestNativeStructWithAttachments& TestNativeStructWithAttachments::operator=( + TestNativeStructWithAttachments&& other) = default; + +} // namespace test +} // namespace mojo + +namespace IPC { + +// static +void ParamTraits<mojo::test::TestNativeStruct>::Write(base::Pickle* m, + const param_type& p) { + m->WriteString(p.message()); + m->WriteInt(p.x()); + m->WriteInt(p.y()); +} + +// static +bool ParamTraits<mojo::test::TestNativeStruct>::Read(const base::Pickle* m, + base::PickleIterator* iter, + param_type* r) { + std::string message; + if (!iter->ReadString(&message)) + return false; + int x, y; + if (!iter->ReadInt(&x) || !iter->ReadInt(&y)) + return false; + r->set_message(message); + r->set_x(x); + r->set_y(y); + return true; +} + +// static +void ParamTraits<mojo::test::TestNativeStruct>::Log(const param_type& p, + std::string* l) {} + +// static +void ParamTraits<mojo::test::TestNativeStructWithAttachments>::Write( + Message* m, + const param_type& p) { + m->WriteString(p.message()); + IPC::MojoMessageHelper::WriteMessagePipeTo(m, p.PassPipe()); +} + +// static +bool ParamTraits<mojo::test::TestNativeStructWithAttachments>::Read( + const Message* m, + base::PickleIterator* iter, + param_type* r) { + std::string message; + if (!iter->ReadString(&message)) + return false; + r->set_message(message); + + mojo::ScopedMessagePipeHandle pipe; + if (!IPC::MojoMessageHelper::ReadMessagePipeFrom(m, iter, &pipe)) + return false; + + r->set_pipe(std::move(pipe)); + return true; +} + +// static +void ParamTraits<mojo::test::TestNativeStructWithAttachments>::Log( + const param_type& p, + std::string* l) {} + +} // namespace IPC diff --git a/mojo/public/cpp/bindings/tests/test_native_types.h b/mojo/public/cpp/bindings/tests/test_native_types.h new file mode 100644 index 0000000000..9ef2f902b5 --- /dev/null +++ b/mojo/public/cpp/bindings/tests/test_native_types.h @@ -0,0 +1,89 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef MOJO_PUBLIC_CPP_BINDINGS_TESTS_BINDINGS_TEST_NATIVE_TYPES_H_ +#define MOJO_PUBLIC_CPP_BINDINGS_TESTS_BINDINGS_TEST_NATIVE_TYPES_H_ + +#include <string> + +#include "base/macros.h" +#include "ipc/ipc_message.h" +#include "ipc/ipc_param_traits.h" +#include "mojo/public/cpp/system/message_pipe.h" + +namespace mojo { +namespace test { + +class TestNativeStruct { + public: + TestNativeStruct(); + TestNativeStruct(const std::string& message, int x, int y); + ~TestNativeStruct(); + + const std::string& message() const { return message_; } + void set_message(const std::string& message) { message_ = message; } + + int x() const { return x_; } + void set_x(int x) { x_ = x; } + + int y() const { return y_; } + void set_y(int y) { y_ = y; } + + private: + std::string message_; + int x_, y_; +}; + +class TestNativeStructWithAttachments { + public: + TestNativeStructWithAttachments(); + TestNativeStructWithAttachments(TestNativeStructWithAttachments&& other); + TestNativeStructWithAttachments(const std::string& message, + ScopedMessagePipeHandle pipe); + ~TestNativeStructWithAttachments(); + + TestNativeStructWithAttachments& operator=( + TestNativeStructWithAttachments&& other); + + const std::string& message() const { return message_; } + void set_message(const std::string& message) { message_ = message; } + + void set_pipe(mojo::ScopedMessagePipeHandle pipe) { pipe_ = std::move(pipe); } + mojo::ScopedMessagePipeHandle PassPipe() const { return std::move(pipe_); } + + private: + std::string message_; + mutable mojo::ScopedMessagePipeHandle pipe_; + + DISALLOW_COPY_AND_ASSIGN(TestNativeStructWithAttachments); +}; + +} // namespace test +} // namespace mojo + +namespace IPC { + +template <> +struct ParamTraits<mojo::test::TestNativeStruct> { + using param_type = mojo::test::TestNativeStruct; + + static void Write(base::Pickle* m, const param_type& p); + static bool Read(const base::Pickle* m, + base::PickleIterator* iter, + param_type* r); + static void Log(const param_type& p, std::string* l); +}; + +template <> +struct ParamTraits<mojo::test::TestNativeStructWithAttachments> { + using param_type = mojo::test::TestNativeStructWithAttachments; + + static void Write(Message* m, const param_type& p); + static bool Read(const Message* m, base::PickleIterator* iter, param_type* r); + static void Log(const param_type& p, std::string* l); +}; + +} // namespace IPC + +#endif // MOJO_PUBLIC_CPP_BINDINGS_TESTS_BINDINGS_TEST_NATIVE_TYPES_H_ diff --git a/mojo/public/cpp/bindings/tests/test_native_types_chromium.typemap b/mojo/public/cpp/bindings/tests/test_native_types_chromium.typemap index 50e8076a50..da99a1a8d7 100644 --- a/mojo/public/cpp/bindings/tests/test_native_types_chromium.typemap +++ b/mojo/public/cpp/bindings/tests/test_native_types_chromium.typemap @@ -3,9 +3,13 @@ # found in the LICENSE file. mojom = "//mojo/public/interfaces/bindings/tests/test_native_types.mojom" -public_headers = [ "//mojo/public/cpp/bindings/tests/pickled_types_chromium.h" ] +public_headers = [ + "//mojo/public/cpp/bindings/tests/pickled_types_chromium.h", + "//mojo/public/cpp/bindings/tests/test_native_types.h", +] sources = [ "//mojo/public/cpp/bindings/tests/pickled_types_chromium.cc", + "//mojo/public/cpp/bindings/tests/test_native_types.cc", ] deps = [ "//ipc", @@ -14,4 +18,6 @@ deps = [ type_mappings = [ "mojo.test.PickledEnum=mojo::test::PickledEnumChromium", "mojo.test.PickledStruct=mojo::test::PickledStructChromium[move_only]", + "mojo.test.TestNativeStructMojom=mojo::test::TestNativeStruct", + "mojo.test.TestNativeStructWithAttachmentsMojom=mojo::test::TestNativeStructWithAttachments[move_only]", ] diff --git a/mojo/public/cpp/bindings/tests/union_unittest.cc b/mojo/public/cpp/bindings/tests/union_unittest.cc index bdf27dfff3..a45d17d7c5 100644 --- a/mojo/public/cpp/bindings/tests/union_unittest.cc +++ b/mojo/public/cpp/bindings/tests/union_unittest.cc @@ -7,6 +7,7 @@ #include <utility> #include <vector> +#include "base/containers/flat_map.h" #include "base/message_loop/message_loop.h" #include "base/run_loop.h" #include "mojo/public/cpp/bindings/binding.h" @@ -15,6 +16,7 @@ #include "mojo/public/cpp/bindings/lib/serialization.h" #include "mojo/public/cpp/bindings/lib/validation_context.h" #include "mojo/public/cpp/bindings/lib/validation_errors.h" +#include "mojo/public/cpp/bindings/message.h" #include "mojo/public/cpp/test_support/test_utils.h" #include "mojo/public/interfaces/bindings/tests/test_structs.mojom.h" #include "mojo/public/interfaces/bindings/tests/test_unions.mojom.h" @@ -23,6 +25,55 @@ namespace mojo { namespace test { +template <typename InputType, typename DataType> +size_t SerializeStruct(InputType& input, + mojo::Message* message, + mojo::internal::SerializationContext* context, + DataType** out_data) { + using StructType = typename InputType::Struct; + using DataViewType = typename StructType::DataView; + *message = mojo::Message(0, 0, 0, 0, nullptr); + const size_t payload_start = message->payload_buffer()->cursor(); + typename DataType::BufferWriter writer; + mojo::internal::Serialize<DataViewType>(input, message->payload_buffer(), + &writer, context); + *out_data = writer.is_null() ? nullptr : writer.data(); + return message->payload_buffer()->cursor() - payload_start; +} + +template <typename InputType, typename DataType> +size_t SerializeUnion(InputType& input, + mojo::Message* message, + mojo::internal::SerializationContext* context, + DataType** out_data = nullptr) { + using StructType = typename InputType::Struct; + using DataViewType = typename StructType::DataView; + *message = mojo::Message(0, 0, 0, 0, nullptr); + const size_t payload_start = message->payload_buffer()->cursor(); + typename DataType::BufferWriter writer; + mojo::internal::Serialize<DataViewType>(input, message->payload_buffer(), + &writer, false, context); + *out_data = writer.is_null() ? nullptr : writer.data(); + return message->payload_buffer()->cursor() - payload_start; +} + +template <typename DataViewType, typename InputType> +size_t SerializeArray(InputType& input, + bool nullable_elements, + mojo::Message* message, + mojo::internal::SerializationContext* context, + typename DataViewType::Data_** out_data) { + *message = mojo::Message(0, 0, 0, 0, nullptr); + const size_t payload_start = message->payload_buffer()->cursor(); + typename DataViewType::Data_::BufferWriter writer; + mojo::internal::ContainerValidateParams validate_params(0, nullable_elements, + nullptr); + mojo::internal::Serialize<DataViewType>(input, message->payload_buffer(), + &writer, &validate_params, context); + *out_data = writer.is_null() ? nullptr : writer.data(); + return message->payload_buffer()->cursor() - payload_start; +} + TEST(UnionTest, PlainOldDataGetterSetter) { PodUnionPtr pod(PodUnion::New()); @@ -53,8 +104,8 @@ TEST(UnionTest, PlainOldDataGetterSetter) { EXPECT_TRUE(pod->is_f_int32()); EXPECT_EQ(pod->which(), PodUnion::Tag::F_INT32); - pod->set_f_uint32(static_cast<uint32_t>(15)); - EXPECT_EQ(static_cast<uint32_t>(15), pod->get_f_uint32()); + pod->set_f_uint32(uint32_t{15}); + EXPECT_EQ(uint32_t{15}, pod->get_f_uint32()); EXPECT_TRUE(pod->is_f_uint32()); EXPECT_EQ(pod->which(), PodUnion::Tag::F_UINT32); @@ -63,8 +114,8 @@ TEST(UnionTest, PlainOldDataGetterSetter) { EXPECT_TRUE(pod->is_f_int64()); EXPECT_EQ(pod->which(), PodUnion::Tag::F_INT64); - pod->set_f_uint64(static_cast<uint64_t>(17)); - EXPECT_EQ(static_cast<uint64_t>(17), pod->get_f_uint64()); + pod->set_f_uint64(uint64_t{17}); + EXPECT_EQ(uint64_t{17}, pod->get_f_uint64()); EXPECT_TRUE(pod->is_f_uint64()); EXPECT_EQ(pod->which(), PodUnion::Tag::F_UINT64); @@ -91,6 +142,70 @@ TEST(UnionTest, PlainOldDataGetterSetter) { EXPECT_EQ(pod->which(), PodUnion::Tag::F_ENUM); } +TEST(UnionTest, PlainOldDataFactoryFunction) { + PodUnionPtr pod = PodUnion::NewFInt8(11); + EXPECT_EQ(11, pod->get_f_int8()); + EXPECT_TRUE(pod->is_f_int8()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_INT8); + + pod = PodUnion::NewFInt16(12); + EXPECT_EQ(12, pod->get_f_int16()); + EXPECT_TRUE(pod->is_f_int16()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_INT16); + + pod = PodUnion::NewFUint16(13); + EXPECT_EQ(13, pod->get_f_uint16()); + EXPECT_TRUE(pod->is_f_uint16()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_UINT16); + + pod = PodUnion::NewFInt32(14); + EXPECT_EQ(14, pod->get_f_int32()); + EXPECT_TRUE(pod->is_f_int32()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_INT32); + + pod = PodUnion::NewFUint32(15); + EXPECT_EQ(uint32_t{15}, pod->get_f_uint32()); + EXPECT_TRUE(pod->is_f_uint32()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_UINT32); + + pod = PodUnion::NewFInt64(16); + EXPECT_EQ(16, pod->get_f_int64()); + EXPECT_TRUE(pod->is_f_int64()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_INT64); + + pod = PodUnion::NewFUint64(17); + EXPECT_EQ(uint64_t{17}, pod->get_f_uint64()); + EXPECT_TRUE(pod->is_f_uint64()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_UINT64); + + pod = PodUnion::NewFFloat(1.5); + EXPECT_EQ(1.5, pod->get_f_float()); + EXPECT_TRUE(pod->is_f_float()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_FLOAT); + + pod = PodUnion::NewFDouble(1.9); + EXPECT_EQ(1.9, pod->get_f_double()); + EXPECT_TRUE(pod->is_f_double()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_DOUBLE); + + pod = PodUnion::NewFBool(true); + EXPECT_TRUE(pod->get_f_bool()); + pod = PodUnion::NewFBool(false); + EXPECT_FALSE(pod->get_f_bool()); + EXPECT_TRUE(pod->is_f_bool()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_BOOL); + + pod = PodUnion::NewFEnum(AnEnum::SECOND); + EXPECT_EQ(AnEnum::SECOND, pod->get_f_enum()); + EXPECT_TRUE(pod->is_f_enum()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_ENUM); + + pod = PodUnion::NewFEnum(AnEnum::FIRST); + EXPECT_EQ(AnEnum::FIRST, pod->get_f_enum()); + EXPECT_TRUE(pod->is_f_enum()); + EXPECT_EQ(pod->which(), PodUnion::Tag::F_ENUM); +} + TEST(UnionTest, PodEquals) { PodUnionPtr pod1(PodUnion::New()); PodUnionPtr pod2(PodUnion::New()); @@ -120,15 +235,10 @@ TEST(UnionTest, PodSerialization) { PodUnionPtr pod1(PodUnion::New()); pod1->set_f_int8(10); + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<PodUnionDataView>( - pod1, false, &context); - EXPECT_EQ(16U, size); - - mojo::internal::FixedBufferForTesting buf(size); internal::PodUnion_Data* data = nullptr; - mojo::internal::Serialize<PodUnionDataView>(pod1, &buf, &data, false, - &context); + EXPECT_EQ(16U, SerializeUnion(pod1, &message, &context, &data)); PodUnionPtr pod2; mojo::internal::Deserialize<PodUnionDataView>(data, &pod2, &context); @@ -139,17 +249,12 @@ TEST(UnionTest, PodSerialization) { } TEST(UnionTest, EnumSerialization) { - PodUnionPtr pod1(PodUnion::New()); - pod1->set_f_enum(AnEnum::SECOND); + PodUnionPtr pod1(PodUnion::NewFEnum(AnEnum::SECOND)); - size_t size = mojo::internal::PrepareToSerialize<PodUnionDataView>( - pod1, false, nullptr); - EXPECT_EQ(16U, size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::PodUnion_Data* data = nullptr; - mojo::internal::Serialize<PodUnionDataView>(pod1, &buf, &data, false, - nullptr); + EXPECT_EQ(16U, SerializeUnion(pod1, &message, &context, &data)); PodUnionPtr pod2; mojo::internal::Deserialize<PodUnionDataView>(data, &pod2, nullptr); @@ -163,62 +268,52 @@ TEST(UnionTest, PodValidation) { PodUnionPtr pod(PodUnion::New()); pod->set_f_int8(10); - size_t size = - mojo::internal::PrepareToSerialize<PodUnionDataView>(pod, false, nullptr); - EXPECT_EQ(16U, size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::PodUnion_Data* data = nullptr; - mojo::internal::Serialize<PodUnionDataView>(pod, &buf, &data, false, nullptr); + const size_t size = SerializeUnion(pod, &message, &context, &data); + EXPECT_EQ(16U, size); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); EXPECT_TRUE( - internal::PodUnion_Data::Validate(raw_buf, &validation_context, false)); - free(raw_buf); + internal::PodUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, SerializeNotNull) { PodUnionPtr pod(PodUnion::New()); pod->set_f_int8(0); - size_t size = - mojo::internal::PrepareToSerialize<PodUnionDataView>(pod, false, nullptr); - mojo::internal::FixedBufferForTesting buf(size); + + mojo::Message message; + mojo::internal::SerializationContext context; internal::PodUnion_Data* data = nullptr; - mojo::internal::Serialize<PodUnionDataView>(pod, &buf, &data, false, nullptr); + SerializeUnion(pod, &message, &context, &data); EXPECT_FALSE(data->is_null()); } TEST(UnionTest, SerializeIsNullInlined) { PodUnionPtr pod; - size_t size = - mojo::internal::PrepareToSerialize<PodUnionDataView>(pod, false, nullptr); - EXPECT_EQ(16U, size); - mojo::internal::FixedBufferForTesting buf(size); - internal::PodUnion_Data* data = internal::PodUnion_Data::New(&buf); - - // Check that dirty output buffers are handled correctly by serialization. - data->size = 16U; - data->tag = PodUnion::Tag::F_UINT16; - data->data.f_f_int16 = 20; - mojo::internal::Serialize<PodUnionDataView>(pod, &buf, &data, true, nullptr); - EXPECT_TRUE(data->is_null()); + mojo::internal::FixedBufferForTesting buffer(16); + internal::PodUnion_Data::BufferWriter writer; + writer.Allocate(&buffer); + mojo::internal::SerializationContext context; + mojo::internal::Serialize<PodUnionDataView>(pod, &buffer, &writer, true, + &context); + EXPECT_TRUE(writer.data()->is_null()); + EXPECT_EQ(16U, buffer.cursor()); PodUnionPtr pod2; - mojo::internal::Deserialize<PodUnionDataView>(data, &pod2, nullptr); + mojo::internal::Deserialize<PodUnionDataView>(writer.data(), &pod2, nullptr); EXPECT_TRUE(pod2.is_null()); } TEST(UnionTest, SerializeIsNullNotInlined) { PodUnionPtr pod; - size_t size = - mojo::internal::PrepareToSerialize<PodUnionDataView>(pod, false, nullptr); - EXPECT_EQ(16U, size); - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::PodUnion_Data* data = nullptr; - mojo::internal::Serialize<PodUnionDataView>(pod, &buf, &data, false, nullptr); + EXPECT_EQ(0u, SerializeUnion(pod, &message, &context, &data)); EXPECT_EQ(nullptr, data); } @@ -229,85 +324,58 @@ TEST(UnionTest, NullValidation) { buf, &validation_context, false)); } -TEST(UnionTest, OutOfAlignmentValidation) { - size_t size = sizeof(internal::PodUnion_Data); - // Get an aligned object and shift the alignment. - mojo::internal::FixedBufferForTesting aligned_buf(size + 1); - void* raw_buf = aligned_buf.Leak(); - char* buf = reinterpret_cast<char*>(raw_buf) + 1; - - internal::PodUnion_Data* data = - reinterpret_cast<internal::PodUnion_Data*>(buf); - mojo::internal::ValidationContext validation_context( - data, static_cast<uint32_t>(size), 0, 0); - EXPECT_FALSE(internal::PodUnion_Data::Validate( - buf, &validation_context, false)); - free(raw_buf); -} - TEST(UnionTest, OOBValidation) { - size_t size = sizeof(internal::PodUnion_Data) - 1; - mojo::internal::FixedBufferForTesting buf(size); - internal::PodUnion_Data* data = internal::PodUnion_Data::New(&buf); + constexpr size_t size = sizeof(internal::PodUnion_Data) - 1; + mojo::Message message(0, 0, size, 0, nullptr); + internal::PodUnion_Data::BufferWriter writer; + writer.Allocate(message.payload_buffer()); mojo::internal::ValidationContext validation_context( - data, static_cast<uint32_t>(size), 0, 0); - void* raw_buf = buf.Leak(); - EXPECT_FALSE( - internal::PodUnion_Data::Validate(raw_buf, &validation_context, false)); - free(raw_buf); + writer.data(), static_cast<uint32_t>(size), 0, 0); + EXPECT_FALSE(internal::PodUnion_Data::Validate(writer.data(), + &validation_context, false)); } TEST(UnionTest, UnknownTagValidation) { - size_t size = sizeof(internal::PodUnion_Data); - mojo::internal::FixedBufferForTesting buf(size); - internal::PodUnion_Data* data = internal::PodUnion_Data::New(&buf); - data->tag = static_cast<internal::PodUnion_Data::PodUnion_Tag>(0xFFFFFF); + constexpr size_t size = sizeof(internal::PodUnion_Data); + mojo::Message message(0, 0, size, 0, nullptr); + internal::PodUnion_Data::BufferWriter writer; + writer.Allocate(message.payload_buffer()); + writer->tag = static_cast<internal::PodUnion_Data::PodUnion_Tag>(0xFFFFFF); mojo::internal::ValidationContext validation_context( - data, static_cast<uint32_t>(size), 0, 0); - void* raw_buf = buf.Leak(); - EXPECT_FALSE( - internal::PodUnion_Data::Validate(raw_buf, &validation_context, false)); - free(raw_buf); + writer.data(), static_cast<uint32_t>(size), 0, 0); + EXPECT_FALSE(internal::PodUnion_Data::Validate(writer.data(), + &validation_context, false)); } TEST(UnionTest, UnknownEnumValueValidation) { - PodUnionPtr pod(PodUnion::New()); - pod->set_f_enum(static_cast<AnEnum>(0xFFFF)); - - size_t size = - mojo::internal::PrepareToSerialize<PodUnionDataView>(pod, false, nullptr); - EXPECT_EQ(16U, size); + PodUnionPtr pod(PodUnion::NewFEnum(static_cast<AnEnum>(0xFFFF))); - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::PodUnion_Data* data = nullptr; - mojo::internal::Serialize<PodUnionDataView>(pod, &buf, &data, false, nullptr); + const size_t size = SerializeUnion(pod, &message, &context, &data); + EXPECT_EQ(16U, size); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); EXPECT_FALSE( - internal::PodUnion_Data::Validate(raw_buf, &validation_context, false)); - free(raw_buf); + internal::PodUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, UnknownExtensibleEnumValueValidation) { - PodUnionPtr pod(PodUnion::New()); - pod->set_f_extensible_enum(static_cast<AnExtensibleEnum>(0xFFFF)); + PodUnionPtr pod( + PodUnion::NewFExtensibleEnum(static_cast<AnExtensibleEnum>(0xFFFF))); - size_t size = - mojo::internal::PrepareToSerialize<PodUnionDataView>(pod, false, nullptr); - EXPECT_EQ(16U, size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::PodUnion_Data* data = nullptr; - mojo::internal::Serialize<PodUnionDataView>(pod, &buf, &data, false, nullptr); + const size_t size = SerializeUnion(pod, &message, &context, &data); + EXPECT_EQ(16U, size); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); EXPECT_TRUE( - internal::PodUnion_Data::Validate(raw_buf, &validation_context, false)); - free(raw_buf); + internal::PodUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, StringGetterSetter) { @@ -320,12 +388,19 @@ TEST(UnionTest, StringGetterSetter) { EXPECT_EQ(pod->which(), ObjectUnion::Tag::F_STRING); } +TEST(UnionTest, StringFactoryFunction) { + std::string hello("hello world"); + ObjectUnionPtr pod(ObjectUnion::NewFString(hello)); + + EXPECT_EQ(hello, pod->get_f_string()); + EXPECT_TRUE(pod->is_f_string()); + EXPECT_EQ(pod->which(), ObjectUnion::Tag::F_STRING); +} + TEST(UnionTest, StringEquals) { - ObjectUnionPtr pod1(ObjectUnion::New()); - ObjectUnionPtr pod2(ObjectUnion::New()); + ObjectUnionPtr pod1(ObjectUnion::NewFString("hello world")); + ObjectUnionPtr pod2(ObjectUnion::NewFString("hello world")); - pod1->set_f_string("hello world"); - pod2->set_f_string("hello world"); EXPECT_TRUE(pod1.Equals(pod2)); pod2->set_f_string("hello universe"); @@ -333,10 +408,9 @@ TEST(UnionTest, StringEquals) { } TEST(UnionTest, StringClone) { - ObjectUnionPtr pod(ObjectUnion::New()); - std::string hello("hello world"); - pod->set_f_string(hello); + ObjectUnionPtr pod(ObjectUnion::NewFString(hello)); + ObjectUnionPtr pod_clone = pod.Clone(); EXPECT_EQ(hello, pod_clone->get_f_string()); EXPECT_TRUE(pod_clone->is_f_string()); @@ -344,17 +418,13 @@ TEST(UnionTest, StringClone) { } TEST(UnionTest, StringSerialization) { - ObjectUnionPtr pod1(ObjectUnion::New()); - std::string hello("hello world"); - pod1->set_f_string(hello); + ObjectUnionPtr pod1(ObjectUnion::NewFString(hello)); - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - pod1, false, nullptr); - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(pod1, &buf, &data, false, - nullptr); + SerializeUnion(pod1, &message, &context, &data); ObjectUnionPtr pod2; mojo::internal::Deserialize<ObjectUnionDataView>(data, &pod2, nullptr); @@ -364,50 +434,47 @@ TEST(UnionTest, StringSerialization) { } TEST(UnionTest, NullStringValidation) { - size_t size = sizeof(internal::ObjectUnion_Data); - mojo::internal::FixedBufferForTesting buf(size); - internal::ObjectUnion_Data* data = internal::ObjectUnion_Data::New(&buf); - data->tag = internal::ObjectUnion_Data::ObjectUnion_Tag::F_STRING; - data->data.unknown = 0x0; + constexpr size_t size = sizeof(internal::ObjectUnion_Data); + mojo::internal::FixedBufferForTesting buffer(size); + internal::ObjectUnion_Data::BufferWriter writer; + writer.Allocate(&buffer); + writer->tag = internal::ObjectUnion_Data::ObjectUnion_Tag::F_STRING; + writer->data.unknown = 0x0; mojo::internal::ValidationContext validation_context( - data, static_cast<uint32_t>(size), 0, 0); - void* raw_buf = buf.Leak(); + writer.data(), static_cast<uint32_t>(size), 0, 0); EXPECT_FALSE(internal::ObjectUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + buffer.data(), &validation_context, false)); } TEST(UnionTest, StringPointerOverflowValidation) { - size_t size = sizeof(internal::ObjectUnion_Data); - mojo::internal::FixedBufferForTesting buf(size); - internal::ObjectUnion_Data* data = internal::ObjectUnion_Data::New(&buf); - data->tag = internal::ObjectUnion_Data::ObjectUnion_Tag::F_STRING; - data->data.unknown = 0xFFFFFFFFFFFFFFFF; + constexpr size_t size = sizeof(internal::ObjectUnion_Data); + mojo::internal::FixedBufferForTesting buffer(size); + internal::ObjectUnion_Data::BufferWriter writer; + writer.Allocate(&buffer); + writer->tag = internal::ObjectUnion_Data::ObjectUnion_Tag::F_STRING; + writer->data.unknown = 0xFFFFFFFFFFFFFFFF; mojo::internal::ValidationContext validation_context( - data, static_cast<uint32_t>(size), 0, 0); - void* raw_buf = buf.Leak(); + writer.data(), static_cast<uint32_t>(size), 0, 0); EXPECT_FALSE(internal::ObjectUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + buffer.data(), &validation_context, false)); } TEST(UnionTest, StringValidateOOB) { - size_t size = 32; - mojo::internal::FixedBufferForTesting buf(size); - internal::ObjectUnion_Data* data = internal::ObjectUnion_Data::New(&buf); - data->tag = internal::ObjectUnion_Data::ObjectUnion_Tag::F_STRING; - - data->data.f_f_string.offset = 8; - char* ptr = reinterpret_cast<char*>(&data->data.f_f_string); + constexpr size_t size = 32; + mojo::internal::FixedBufferForTesting buffer(size); + internal::ObjectUnion_Data::BufferWriter writer; + writer.Allocate(&buffer); + writer->tag = internal::ObjectUnion_Data::ObjectUnion_Tag::F_STRING; + + writer->data.f_f_string.offset = 8; + char* ptr = reinterpret_cast<char*>(&writer->data.f_f_string); mojo::internal::ArrayHeader* array_header = reinterpret_cast<mojo::internal::ArrayHeader*>(ptr + *ptr); array_header->num_bytes = 20; // This should go out of bounds. array_header->num_elements = 20; - mojo::internal::ValidationContext validation_context(data, 32, 0, 0); - void* raw_buf = buf.Leak(); + mojo::internal::ValidationContext validation_context(writer.data(), 32, 0, 0); EXPECT_FALSE(internal::ObjectUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + buffer.data(), &validation_context, false)); } // TODO(azani): Move back in array_unittest.cc when possible. @@ -434,23 +501,16 @@ TEST(UnionTest, PodUnionInArraySerialization) { array[1]->set_f_int16(12); EXPECT_EQ(2U, array.size()); - size_t size = - mojo::internal::PrepareToSerialize<ArrayDataView<PodUnionDataView>>( - array, nullptr); - EXPECT_EQ(40U, size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; mojo::internal::Array_Data<internal::PodUnion_Data>* data; - mojo::internal::ContainerValidateParams validate_params(0, false, nullptr); - mojo::internal::Serialize<ArrayDataView<PodUnionDataView>>( - array, &buf, &data, &validate_params, nullptr); + EXPECT_EQ(40U, SerializeArray<ArrayDataView<PodUnionDataView>>( + array, false, &message, &context, &data)); std::vector<PodUnionPtr> array2; mojo::internal::Deserialize<ArrayDataView<PodUnionDataView>>(data, &array2, nullptr); - EXPECT_EQ(2U, array2.size()); - EXPECT_EQ(10, array2[0]->get_f_int8()); EXPECT_EQ(12, array2[1]->get_f_int16()); } @@ -462,23 +522,16 @@ TEST(UnionTest, PodUnionInArraySerializationWithNull) { array[0]->set_f_int8(10); EXPECT_EQ(2U, array.size()); - size_t size = - mojo::internal::PrepareToSerialize<ArrayDataView<PodUnionDataView>>( - array, nullptr); - EXPECT_EQ(40U, size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; mojo::internal::Array_Data<internal::PodUnion_Data>* data; - mojo::internal::ContainerValidateParams validate_params(0, true, nullptr); - mojo::internal::Serialize<ArrayDataView<PodUnionDataView>>( - array, &buf, &data, &validate_params, nullptr); + EXPECT_EQ(40U, SerializeArray<ArrayDataView<PodUnionDataView>>( + array, true, &message, &context, &data)); std::vector<PodUnionPtr> array2; mojo::internal::Deserialize<ArrayDataView<PodUnionDataView>>(data, &array2, nullptr); - EXPECT_EQ(2U, array2.size()); - EXPECT_EQ(10, array2[0]->get_f_int8()); EXPECT_TRUE(array2[1].is_null()); } @@ -492,30 +545,23 @@ TEST(UnionTest, ObjectUnionInArraySerialization) { array[1]->set_f_string("world"); EXPECT_EQ(2U, array.size()); - size_t size = - mojo::internal::PrepareToSerialize<ArrayDataView<ObjectUnionDataView>>( - array, nullptr); - EXPECT_EQ(72U, size); - - mojo::internal::FixedBufferForTesting buf(size); - + mojo::Message message; + mojo::internal::SerializationContext context; mojo::internal::Array_Data<internal::ObjectUnion_Data>* data; - mojo::internal::ContainerValidateParams validate_params(0, false, nullptr); - mojo::internal::Serialize<ArrayDataView<ObjectUnionDataView>>( - array, &buf, &data, &validate_params, nullptr); + const size_t size = SerializeArray<ArrayDataView<ObjectUnionDataView>>( + array, false, &message, &context, &data); + EXPECT_EQ(72U, size); std::vector<char> new_buf; new_buf.resize(size); - - void* raw_buf = buf.Leak(); - memcpy(new_buf.data(), raw_buf, size); - free(raw_buf); + memcpy(new_buf.data(), data, size); data = reinterpret_cast<mojo::internal::Array_Data<internal::ObjectUnion_Data>*>( new_buf.data()); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); + mojo::internal::ContainerValidateParams validate_params(0, false, nullptr); ASSERT_TRUE(mojo::internal::Array_Data<internal::ObjectUnion_Data>::Validate( data, &validation_context, &validate_params)); @@ -546,14 +592,10 @@ TEST(UnionTest, Serialization_UnionOfPods) { small_struct->pod_union = PodUnion::New(); small_struct->pod_union->set_f_int32(10); + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<SmallStructDataView>( - small_struct, &context); - - mojo::internal::FixedBufferForTesting buf(size); internal::SmallStruct_Data* data = nullptr; - mojo::internal::Serialize<SmallStructDataView>(small_struct, &buf, &data, - &context); + SerializeStruct(small_struct, &message, &context, &data); SmallStructPtr deserialized; mojo::internal::Deserialize<SmallStructDataView>(data, &deserialized, @@ -569,13 +611,10 @@ TEST(UnionTest, Serialization_UnionOfObjects) { std::string hello("hello world"); obj_struct->obj_union->set_f_string(hello); - size_t size = mojo::internal::PrepareToSerialize<SmallObjStructDataView>( - obj_struct, nullptr); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::SmallObjStruct_Data* data = nullptr; - mojo::internal::Serialize<SmallObjStructDataView>(obj_struct, &buf, &data, - nullptr); + SerializeStruct(obj_struct, &message, &context, &data); SmallObjStructPtr deserialized; mojo::internal::Deserialize<SmallObjStructDataView>(data, &deserialized, @@ -590,21 +629,14 @@ TEST(UnionTest, Validation_UnionsInStruct) { small_struct->pod_union = PodUnion::New(); small_struct->pod_union->set_f_int32(10); + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<SmallStructDataView>( - small_struct, &context); - - mojo::internal::FixedBufferForTesting buf(size); internal::SmallStruct_Data* data = nullptr; - mojo::internal::Serialize<SmallStructDataView>(small_struct, &buf, &data, - &context); + const size_t size = SerializeStruct(small_struct, &message, &context, &data); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); - EXPECT_TRUE(internal::SmallStruct_Data::Validate( - raw_buf, &validation_context)); - free(raw_buf); + EXPECT_TRUE(internal::SmallStruct_Data::Validate(data, &validation_context)); } // Validation test of a struct union fails due to unknown union tag. @@ -613,22 +645,15 @@ TEST(UnionTest, Validation_PodUnionInStruct_Failure) { small_struct->pod_union = PodUnion::New(); small_struct->pod_union->set_f_int32(10); + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<SmallStructDataView>( - small_struct, &context); - - mojo::internal::FixedBufferForTesting buf(size); internal::SmallStruct_Data* data = nullptr; - mojo::internal::Serialize<SmallStructDataView>(small_struct, &buf, &data, - &context); + const size_t size = SerializeStruct(small_struct, &message, &context, &data); data->pod_union.tag = static_cast<internal::PodUnion_Data::PodUnion_Tag>(100); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); - EXPECT_FALSE(internal::SmallStruct_Data::Validate( - raw_buf, &validation_context)); - free(raw_buf); + EXPECT_FALSE(internal::SmallStruct_Data::Validate(data, &validation_context)); } // Validation fails due to non-nullable null union in struct. @@ -636,41 +661,29 @@ TEST(UnionTest, Validation_NullUnion_Failure) { SmallStructNonNullableUnionPtr small_struct( SmallStructNonNullableUnion::New()); - size_t size = - mojo::internal::PrepareToSerialize<SmallStructNonNullableUnionDataView>( - small_struct, nullptr); - - mojo::internal::FixedBufferForTesting buf(size); - internal::SmallStructNonNullableUnion_Data* data = - internal::SmallStructNonNullableUnion_Data::New(&buf); - - void* raw_buf = buf.Leak(); + constexpr size_t size = sizeof(internal::SmallStructNonNullableUnion_Data); + mojo::internal::FixedBufferForTesting buffer(size); + mojo::Message message; + internal::SmallStructNonNullableUnion_Data::BufferWriter writer; + writer.Allocate(&buffer); mojo::internal::ValidationContext validation_context( - data, static_cast<uint32_t>(size), 0, 0); + writer.data(), static_cast<uint32_t>(size), 0, 0); EXPECT_FALSE(internal::SmallStructNonNullableUnion_Data::Validate( - raw_buf, &validation_context)); - free(raw_buf); + writer.data(), &validation_context)); } // Validation passes with nullable null union. TEST(UnionTest, Validation_NullableUnion) { SmallStructPtr small_struct(SmallStruct::New()); + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<SmallStructDataView>( - small_struct, &context); - - mojo::internal::FixedBufferForTesting buf(size); internal::SmallStruct_Data* data = nullptr; - mojo::internal::Serialize<SmallStructDataView>(small_struct, &buf, &data, - &context); + const size_t size = SerializeStruct(small_struct, &message, &context, &data); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); - EXPECT_TRUE(internal::SmallStruct_Data::Validate( - raw_buf, &validation_context)); - free(raw_buf); + EXPECT_TRUE(internal::SmallStruct_Data::Validate(data, &validation_context)); } // TODO(azani): Move back in map_unittest.cc when possible. @@ -691,28 +704,28 @@ TEST(UnionTest, PodUnionInMap) { TEST(UnionTest, PodUnionInMapSerialization) { using MojomType = MapDataView<StringDataView, PodUnionDataView>; - std::unordered_map<std::string, PodUnionPtr> map; + base::flat_map<std::string, PodUnionPtr> map; map.insert(std::make_pair("one", PodUnion::New())); map.insert(std::make_pair("two", PodUnion::New())); map["one"]->set_f_int8(8); map["two"]->set_f_int16(16); + mojo::Message message(0, 0, 0, 0, nullptr); mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<MojomType>(map, &context); - EXPECT_EQ(120U, size); - - mojo::internal::FixedBufferForTesting buf(size); + const size_t payload_start = message.payload_buffer()->cursor(); - typename mojo::internal::MojomTypeTraits<MojomType>::Data* data; + typename mojo::internal::MojomTypeTraits<MojomType>::Data::BufferWriter + writer; mojo::internal::ContainerValidateParams validate_params( new mojo::internal::ContainerValidateParams(0, false, nullptr), new mojo::internal::ContainerValidateParams(0, false, nullptr)); - mojo::internal::Serialize<MojomType>(map, &buf, &data, &validate_params, - &context); + mojo::internal::Serialize<MojomType>(map, message.payload_buffer(), &writer, + &validate_params, &context); + EXPECT_EQ(120U, message.payload_buffer()->cursor() - payload_start); - std::unordered_map<std::string, PodUnionPtr> map2; - mojo::internal::Deserialize<MojomType>(data, &map2, &context); + base::flat_map<std::string, PodUnionPtr> map2; + mojo::internal::Deserialize<MojomType>(writer.data(), &map2, &context); EXPECT_EQ(8, map2["one"]->get_f_int8()); EXPECT_EQ(16, map2["two"]->get_f_int16()); @@ -721,26 +734,27 @@ TEST(UnionTest, PodUnionInMapSerialization) { TEST(UnionTest, PodUnionInMapSerializationWithNull) { using MojomType = MapDataView<StringDataView, PodUnionDataView>; - std::unordered_map<std::string, PodUnionPtr> map; + base::flat_map<std::string, PodUnionPtr> map; map.insert(std::make_pair("one", PodUnion::New())); map.insert(std::make_pair("two", nullptr)); map["one"]->set_f_int8(8); + mojo::Message message(0, 0, 0, 0, nullptr); mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<MojomType>(map, &context); - EXPECT_EQ(120U, size); + const size_t payload_start = message.payload_buffer()->cursor(); - mojo::internal::FixedBufferForTesting buf(size); - typename mojo::internal::MojomTypeTraits<MojomType>::Data* data; + typename mojo::internal::MojomTypeTraits<MojomType>::Data::BufferWriter + writer; mojo::internal::ContainerValidateParams validate_params( new mojo::internal::ContainerValidateParams(0, false, nullptr), new mojo::internal::ContainerValidateParams(0, true, nullptr)); - mojo::internal::Serialize<MojomType>(map, &buf, &data, &validate_params, - &context); + mojo::internal::Serialize<MojomType>(map, message.payload_buffer(), &writer, + &validate_params, &context); + EXPECT_EQ(120U, message.payload_buffer()->cursor() - payload_start); - std::unordered_map<std::string, PodUnionPtr> map2; - mojo::internal::Deserialize<MojomType>(data, &map2, &context); + base::flat_map<std::string, PodUnionPtr> map2; + mojo::internal::Deserialize<MojomType>(writer.data(), &map2, &context); EXPECT_EQ(8, map2["one"]->get_f_int8()); EXPECT_TRUE(map2["two"].is_null()); @@ -763,14 +777,10 @@ TEST(UnionTest, StructInUnionSerialization) { ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_dummy(std::move(dummy)); - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, nullptr); - EXPECT_EQ(32U, size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - nullptr); + EXPECT_EQ(32U, SerializeUnion(obj, &message, &context, &data)); ObjectUnionPtr obj2; mojo::internal::Deserialize<ObjectUnionDataView>(data, &obj2, nullptr); @@ -784,20 +794,15 @@ TEST(UnionTest, StructInUnionValidation) { ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_dummy(std::move(dummy)); - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, nullptr); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - nullptr); + const size_t size = SerializeUnion(obj, &message, &context, &data); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); - EXPECT_TRUE(internal::ObjectUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + EXPECT_TRUE( + internal::ObjectUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, StructInUnionValidationNonNullable) { @@ -808,20 +813,15 @@ TEST(UnionTest, StructInUnionValidationNonNullable) { ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_dummy(std::move(dummy)); - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, nullptr); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - nullptr); + const size_t size = SerializeUnion(obj, &message, &context, &data); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); - EXPECT_FALSE(internal::ObjectUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + EXPECT_FALSE( + internal::ObjectUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, StructInUnionValidationNullable) { @@ -830,20 +830,15 @@ TEST(UnionTest, StructInUnionValidationNullable) { ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_nullable(std::move(dummy)); - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, nullptr); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - nullptr); + const size_t size = SerializeUnion(obj, &message, &context, &data); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); - EXPECT_TRUE(internal::ObjectUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + EXPECT_TRUE( + internal::ObjectUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, ArrayInUnionGetterSetter) { @@ -866,14 +861,11 @@ TEST(UnionTest, ArrayInUnionSerialization) { ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_array_int8(std::move(array)); - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, nullptr); - EXPECT_EQ(32U, size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - nullptr); + const size_t size = SerializeUnion(obj, &message, &context, &data); + EXPECT_EQ(32U, size); ObjectUnionPtr obj2; mojo::internal::Deserialize<ObjectUnionDataView>(data, &obj2, nullptr); @@ -890,24 +882,19 @@ TEST(UnionTest, ArrayInUnionValidation) { ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_array_int8(std::move(array)); - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, nullptr); - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - nullptr); + const size_t size = SerializeUnion(obj, &message, &context, &data); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); - - EXPECT_TRUE(internal::ObjectUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + EXPECT_TRUE( + internal::ObjectUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, MapInUnionGetterSetter) { - std::unordered_map<std::string, int8_t> map; + base::flat_map<std::string, int8_t> map; map.insert({"one", 1}); map.insert({"two", 2}); @@ -919,22 +906,18 @@ TEST(UnionTest, MapInUnionGetterSetter) { } TEST(UnionTest, MapInUnionSerialization) { - std::unordered_map<std::string, int8_t> map; + base::flat_map<std::string, int8_t> map; map.insert({"one", 1}); map.insert({"two", 2}); ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_map_int8(std::move(map)); + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, &context); - EXPECT_EQ(112U, size); - - mojo::internal::FixedBufferForTesting buf(size); internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - &context); + const size_t size = SerializeUnion(obj, &message, &context, &data); + EXPECT_EQ(112U, size); ObjectUnionPtr obj2; mojo::internal::Deserialize<ObjectUnionDataView>(data, &obj2, &context); @@ -944,30 +927,23 @@ TEST(UnionTest, MapInUnionSerialization) { } TEST(UnionTest, MapInUnionValidation) { - std::unordered_map<std::string, int8_t> map; + base::flat_map<std::string, int8_t> map; map.insert({"one", 1}); map.insert({"two", 2}); ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_map_int8(std::move(map)); + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, &context); - EXPECT_EQ(112U, size); - - mojo::internal::FixedBufferForTesting buf(size); internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - &context); + const size_t size = SerializeUnion(obj, &message, &context, &data); + EXPECT_EQ(112U, size); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); - - EXPECT_TRUE(internal::ObjectUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + EXPECT_TRUE( + internal::ObjectUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, UnionInUnionGetterSetter) { @@ -980,6 +956,15 @@ TEST(UnionTest, UnionInUnionGetterSetter) { EXPECT_EQ(10, obj->get_f_pod_union()->get_f_int8()); } +TEST(UnionTest, UnionInUnionFactoryFunction) { + PodUnionPtr pod(PodUnion::New()); + pod->set_f_int8(10); + + ObjectUnionPtr obj(ObjectUnion::NewFPodUnion(std::move(pod))); + + EXPECT_EQ(10, obj->get_f_pod_union()->get_f_int8()); +} + TEST(UnionTest, UnionInUnionSerialization) { PodUnionPtr pod(PodUnion::New()); pod->set_f_int8(10); @@ -987,14 +972,11 @@ TEST(UnionTest, UnionInUnionSerialization) { ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_pod_union(std::move(pod)); - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, nullptr); - EXPECT_EQ(32U, size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - nullptr); + const size_t size = SerializeUnion(obj, &message, &context, &data); + EXPECT_EQ(32U, size); ObjectUnionPtr obj2; mojo::internal::Deserialize<ObjectUnionDataView>(data, &obj2, nullptr); @@ -1008,21 +990,16 @@ TEST(UnionTest, UnionInUnionValidation) { ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_pod_union(std::move(pod)); - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, nullptr); - EXPECT_EQ(32U, size); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - nullptr); + const size_t size = SerializeUnion(obj, &message, &context, &data); + EXPECT_EQ(32U, size); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); - EXPECT_TRUE(internal::ObjectUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + EXPECT_TRUE( + internal::ObjectUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, UnionInUnionValidationNonNullable) { @@ -1033,20 +1010,15 @@ TEST(UnionTest, UnionInUnionValidationNonNullable) { ObjectUnionPtr obj(ObjectUnion::New()); obj->set_f_pod_union(std::move(pod)); - size_t size = mojo::internal::PrepareToSerialize<ObjectUnionDataView>( - obj, false, nullptr); - - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::ObjectUnion_Data* data = nullptr; - mojo::internal::Serialize<ObjectUnionDataView>(obj, &buf, &data, false, - nullptr); + const size_t size = SerializeUnion(obj, &message, &context, &data); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 0, 0); - EXPECT_FALSE(internal::ObjectUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + EXPECT_FALSE( + internal::ObjectUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, HandleInUnionGetterSetter) { @@ -1067,6 +1039,23 @@ TEST(UnionTest, HandleInUnionGetterSetter) { EXPECT_EQ(golden, actual); } +TEST(UnionTest, HandleInUnionGetterFactoryFunction) { + ScopedMessagePipeHandle pipe0; + ScopedMessagePipeHandle pipe1; + + CreateMessagePipe(nullptr, &pipe0, &pipe1); + + HandleUnionPtr handle(HandleUnion::NewFMessagePipe(std::move(pipe1))); + + std::string golden("hello world"); + WriteTextMessage(pipe0.get(), golden); + + std::string actual; + ReadTextMessage(handle->get_f_message_pipe().get(), &actual); + + EXPECT_EQ(golden, actual); +} + TEST(UnionTest, HandleInUnionSerialization) { ScopedMessagePipeHandle pipe0; ScopedMessagePipeHandle pipe1; @@ -1076,16 +1065,12 @@ TEST(UnionTest, HandleInUnionSerialization) { HandleUnionPtr handle(HandleUnion::New()); handle->set_f_message_pipe(std::move(pipe1)); + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<HandleUnionDataView>( - handle, false, &context); - EXPECT_EQ(16U, size); - - mojo::internal::FixedBufferForTesting buf(size); internal::HandleUnion_Data* data = nullptr; - mojo::internal::Serialize<HandleUnionDataView>(handle, &buf, &data, false, - &context); - EXPECT_EQ(1U, context.handles.size()); + const size_t size = SerializeUnion(handle, &message, &context, &data); + EXPECT_EQ(16U, size); + EXPECT_EQ(1U, context.handles()->size()); HandleUnionPtr handle2(HandleUnion::New()); mojo::internal::Deserialize<HandleUnionDataView>(data, &handle2, &context); @@ -1108,22 +1093,16 @@ TEST(UnionTest, HandleInUnionValidation) { HandleUnionPtr handle(HandleUnion::New()); handle->set_f_message_pipe(std::move(pipe1)); + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<HandleUnionDataView>( - handle, false, &context); - EXPECT_EQ(16U, size); - - mojo::internal::FixedBufferForTesting buf(size); internal::HandleUnion_Data* data = nullptr; - mojo::internal::Serialize<HandleUnionDataView>(handle, &buf, &data, false, - &context); + const size_t size = SerializeUnion(handle, &message, &context, &data); + EXPECT_EQ(16U, size); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 1, 0); - EXPECT_TRUE(internal::HandleUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + EXPECT_TRUE( + internal::HandleUnion_Data::Validate(data, &validation_context, false)); } TEST(UnionTest, HandleInUnionValidationNull) { @@ -1133,22 +1112,16 @@ TEST(UnionTest, HandleInUnionValidationNull) { HandleUnionPtr handle(HandleUnion::New()); handle->set_f_message_pipe(std::move(pipe)); + mojo::Message message; mojo::internal::SerializationContext context; - size_t size = mojo::internal::PrepareToSerialize<HandleUnionDataView>( - handle, false, &context); - EXPECT_EQ(16U, size); - - mojo::internal::FixedBufferForTesting buf(size); internal::HandleUnion_Data* data = nullptr; - mojo::internal::Serialize<HandleUnionDataView>(handle, &buf, &data, false, - &context); + const size_t size = SerializeUnion(handle, &message, &context, &data); + EXPECT_EQ(16U, size); - void* raw_buf = buf.Leak(); mojo::internal::ValidationContext validation_context( data, static_cast<uint32_t>(size), 1, 0); - EXPECT_FALSE(internal::HandleUnion_Data::Validate( - raw_buf, &validation_context, false)); - free(raw_buf); + EXPECT_FALSE( + internal::HandleUnion_Data::Validate(data, &validation_context, false)); } class SmallCacheImpl : public SmallCache { @@ -1179,9 +1152,24 @@ TEST(UnionTest, InterfaceInUnion) { Binding<SmallCache> bindings(&impl, MakeRequest(&ptr)); HandleUnionPtr handle(HandleUnion::New()); - handle->set_f_small_cache(std::move(ptr)); + handle->set_f_small_cache(ptr.PassInterface()); - handle->get_f_small_cache()->SetIntValue(10); + ptr.Bind(std::move(handle->get_f_small_cache())); + ptr->SetIntValue(10); + run_loop.Run(); + EXPECT_EQ(10, impl.int_value()); +} + +TEST(UnionTest, InterfaceInUnionFactoryFunction) { + base::MessageLoop message_loop; + base::RunLoop run_loop; + SmallCacheImpl impl(run_loop.QuitClosure()); + SmallCachePtr ptr; + Binding<SmallCache> bindings(&impl, MakeRequest(&ptr)); + + HandleUnionPtr handle = HandleUnion::NewFSmallCache(ptr.PassInterface()); + ptr.Bind(std::move(handle->get_f_small_cache())); + ptr->SetIntValue(10); run_loop.Run(); EXPECT_EQ(10, impl.int_value()); } @@ -1193,23 +1181,21 @@ TEST(UnionTest, InterfaceInUnionSerialization) { SmallCachePtr ptr; Binding<SmallCache> bindings(&impl, MakeRequest(&ptr)); - mojo::internal::SerializationContext context; HandleUnionPtr handle(HandleUnion::New()); - handle->set_f_small_cache(std::move(ptr)); - size_t size = mojo::internal::PrepareToSerialize<HandleUnionDataView>( - handle, false, &context); - EXPECT_EQ(16U, size); + handle->set_f_small_cache(ptr.PassInterface()); - mojo::internal::FixedBufferForTesting buf(size); + mojo::Message message; + mojo::internal::SerializationContext context; internal::HandleUnion_Data* data = nullptr; - mojo::internal::Serialize<HandleUnionDataView>(handle, &buf, &data, false, - &context); - EXPECT_EQ(1U, context.handles.size()); + const size_t size = SerializeUnion(handle, &message, &context, &data); + EXPECT_EQ(16U, size); + EXPECT_EQ(1U, context.handles()->size()); HandleUnionPtr handle2(HandleUnion::New()); mojo::internal::Deserialize<HandleUnionDataView>(data, &handle2, &context); - handle2->get_f_small_cache()->SetIntValue(10); + ptr.Bind(std::move(handle2->get_f_small_cache())); + ptr->SetIntValue(10); run_loop.Run(); EXPECT_EQ(10, impl.int_value()); } diff --git a/mojo/public/cpp/bindings/tests/validation_unittest.cc b/mojo/public/cpp/bindings/tests/validation_unittest.cc index 7af7396d4e..2db4d9b598 100644 --- a/mojo/public/cpp/bindings/tests/validation_unittest.cc +++ b/mojo/public/cpp/bindings/tests/validation_unittest.cc @@ -10,7 +10,9 @@ #include <utility> #include <vector> +#include "base/macros.h" #include "base/message_loop/message_loop.h" +#include "base/numerics/safe_math.h" #include "base/run_loop.h" #include "base/threading/thread_task_runner_handle.h" #include "mojo/public/c/system/macros.h" @@ -23,6 +25,7 @@ #include "mojo/public/cpp/bindings/message_header_validator.h" #include "mojo/public/cpp/bindings/tests/validation_test_input_parser.h" #include "mojo/public/cpp/system/core.h" +#include "mojo/public/cpp/system/message.h" #include "mojo/public/cpp/test_support/test_support.h" #include "mojo/public/interfaces/bindings/tests/validation_test_associated_interfaces.mojom.h" #include "mojo/public/interfaces/bindings/tests/validation_test_interfaces.mojom.h" @@ -32,6 +35,25 @@ namespace mojo { namespace test { namespace { +Message CreateRawMessage(size_t size) { + ScopedMessageHandle handle; + MojoResult rv = CreateMessage(&handle); + DCHECK_EQ(MOJO_RESULT_OK, rv); + DCHECK(handle.is_valid()); + + DCHECK(base::IsValueInRangeForNumericType<uint32_t>(size)); + MojoAppendMessageDataOptions options; + options.struct_size = sizeof(options); + options.flags = MOJO_APPEND_MESSAGE_DATA_FLAG_COMMIT_SIZE; + void* buffer; + uint32_t buffer_size; + rv = MojoAppendMessageData(handle->value(), static_cast<uint32_t>(size), + nullptr, 0, &options, &buffer, &buffer_size); + DCHECK_EQ(MOJO_RESULT_OK, rv); + + return Message(std::move(handle)); +} + template <typename T> void Append(std::vector<uint8_t>* data_vector, T data) { size_t pos = data_vector->size(); @@ -150,8 +172,7 @@ bool ReadTestCase(const std::string& test, return false; } - message->Initialize(static_cast<uint32_t>(data.size()), - false /* zero_initialized */); + *message = CreateRawMessage(data.size()); if (!data.empty()) memcpy(message->mutable_data(), &data[0], data.size()); message->mutable_handles()->resize(num_handles); @@ -443,8 +464,7 @@ TEST_F(ValidationIntegrationTest, InterfacePtr) { TEST_F(ValidationIntegrationTest, Binding) { IntegrationTestInterfaceImpl interface_impl; Binding<IntegrationTestInterface> binding( - &interface_impl, - MakeRequest<IntegrationTestInterface>(testee_endpoint())); + &interface_impl, IntegrationTestInterfaceRequest(testee_endpoint())); binding.EnableTestingMode(); RunValidationTests("integration_intf_rqst", test_message_receiver()); diff --git a/mojo/public/cpp/bindings/tests/variant_test_util.h b/mojo/public/cpp/bindings/tests/variant_test_util.h index 3f6b1f1b5f..f1e75b3298 100644 --- a/mojo/public/cpp/bindings/tests/variant_test_util.h +++ b/mojo/public/cpp/bindings/tests/variant_test_util.h @@ -21,9 +21,7 @@ template <typename Interface0, typename Interface1> InterfaceRequest<Interface0> ConvertInterfaceRequest( InterfaceRequest<Interface1> request) { DCHECK_EQ(0, strcmp(Interface0::Name_, Interface1::Name_)); - InterfaceRequest<Interface0> result; - result.Bind(request.PassMessagePipe()); - return result; + return InterfaceRequest<Interface0>(request.PassMessagePipe()); } } // namespace test diff --git a/mojo/public/cpp/bindings/tests/versioning_test_service.cc b/mojo/public/cpp/bindings/tests/versioning_test_service.cc index 313a6249e2..174147883f 100644 --- a/mojo/public/cpp/bindings/tests/versioning_test_service.cc +++ b/mojo/public/cpp/bindings/tests/versioning_test_service.cc @@ -11,7 +11,7 @@ #include "mojo/public/cpp/bindings/strong_binding.h" #include "mojo/public/interfaces/bindings/tests/versioning_test_service.mojom.h" #include "services/service_manager/public/c/main.h" -#include "services/service_manager/public/cpp/interface_factory.h" +#include "services/service_manager/public/cpp/binder_registry.h" #include "services/service_manager/public/cpp/service.h" #include "services/service_manager/public/cpp/service_runner.h" @@ -94,25 +94,28 @@ class HumanResourceDatabaseImpl : public HumanResourceDatabase { StrongBinding<HumanResourceDatabase> strong_binding_; }; -class HumanResourceSystemServer - : public service_manager::Service, - public InterfaceFactory<HumanResourceDatabase> { +class HumanResourceSystemServer : public service_manager::Service { public: - HumanResourceSystemServer() {} + HumanResourceSystemServer() { + registry_.AddInterface<HumanResourceDatabase>( + base::Bind(&HumanResourceSystemServer::Create, base::Unretained(this))); + } // service_manager::Service implementation. - bool OnConnect(Connection* connection) override { - connection->AddInterface<HumanResourceDatabase>(this); - return true; + void OnBindInterface(const service_manager::BindSourceInfo& source_info, + const std::string& interface_name, + mojo::ScopedMessagePipeHandle interface_pipe) override { + registry_.BindInterface(interface_name, std::move(interface_pipe)); } - // InterfaceFactory<HumanResourceDatabase> implementation. - void Create(Connection* connection, - InterfaceRequest<HumanResourceDatabase> request) override { + void Create(HumanResourceDatabaseRequest request) { // It will be deleted automatically when the underlying pipe encounters a // connection error. new HumanResourceDatabaseImpl(std::move(request)); } + + private: + service_manager::BinderRegistry registry_; }; } // namespace versioning diff --git a/mojo/public/cpp/bindings/tests/wtf_hash_unittest.cc b/mojo/public/cpp/bindings/tests/wtf_hash_unittest.cc index 959d25b368..458df2c9f1 100644 --- a/mojo/public/cpp/bindings/tests/wtf_hash_unittest.cc +++ b/mojo/public/cpp/bindings/tests/wtf_hash_unittest.cc @@ -7,7 +7,7 @@ #include "mojo/public/interfaces/bindings/tests/test_structs.mojom-blink.h" #include "mojo/public/interfaces/bindings/tests/test_wtf_types.mojom-blink.h" #include "testing/gtest/include/gtest/gtest.h" -#include "third_party/WebKit/Source/wtf/HashFunctions.h" +#include "third_party/blink/renderer/platform/wtf/hash_functions.h" namespace mojo { namespace test { @@ -25,33 +25,25 @@ TEST_F(WTFHashTest, NestedStruct) { blink::SimpleNestedStruct::New(blink::ContainsOther::New(1)))); } -TEST_F(WTFHashTest, UnmappedNativeStruct) { - // Just check that this template instantiation compiles. - ASSERT_EQ(::mojo::internal::Hash(::mojo::internal::kHashSeed, - blink::UnmappedNativeStruct::New()), - ::mojo::internal::Hash(::mojo::internal::kHashSeed, - blink::UnmappedNativeStruct::New())); -} - TEST_F(WTFHashTest, Enum) { // Just check that this template instantiation compiles. // Top-level. - ASSERT_EQ(WTF::DefaultHash<blink::TopLevelEnum>::Hash().hash( + ASSERT_EQ(WTF::DefaultHash<blink::TopLevelEnum>::Hash().GetHash( blink::TopLevelEnum::E0), - WTF::DefaultHash<blink::TopLevelEnum>::Hash().hash( + WTF::DefaultHash<blink::TopLevelEnum>::Hash().GetHash( blink::TopLevelEnum::E0)); // Nested in struct. - ASSERT_EQ(WTF::DefaultHash<blink::TestWTFStruct::NestedEnum>::Hash().hash( + ASSERT_EQ(WTF::DefaultHash<blink::TestWTFStruct::NestedEnum>::Hash().GetHash( blink::TestWTFStruct::NestedEnum::E0), - WTF::DefaultHash<blink::TestWTFStruct::NestedEnum>::Hash().hash( + WTF::DefaultHash<blink::TestWTFStruct::NestedEnum>::Hash().GetHash( blink::TestWTFStruct::NestedEnum::E0)); // Nested in interface. - ASSERT_EQ(WTF::DefaultHash<blink::TestWTF::NestedEnum>::Hash().hash( + ASSERT_EQ(WTF::DefaultHash<blink::TestWTF::NestedEnum>::Hash().GetHash( blink::TestWTF::NestedEnum::E0), - WTF::DefaultHash<blink::TestWTF::NestedEnum>::Hash().hash( + WTF::DefaultHash<blink::TestWTF::NestedEnum>::Hash().GetHash( blink::TestWTF::NestedEnum::E0)); } diff --git a/mojo/public/cpp/bindings/tests/wtf_types_unittest.cc b/mojo/public/cpp/bindings/tests/wtf_types_unittest.cc index 363ef7cdab..1064e41b8b 100644 --- a/mojo/public/cpp/bindings/tests/wtf_types_unittest.cc +++ b/mojo/public/cpp/bindings/tests/wtf_types_unittest.cc @@ -14,7 +14,7 @@ #include "mojo/public/interfaces/bindings/tests/test_wtf_types.mojom-blink.h" #include "mojo/public/interfaces/bindings/tests/test_wtf_types.mojom.h" #include "testing/gtest/include/gtest/gtest.h" -#include "third_party/WebKit/Source/wtf/text/StringHash.h" +#include "third_party/blink/renderer/platform/wtf/text/string_hash.h" namespace mojo { namespace test { @@ -44,8 +44,7 @@ class TestWTFImpl : public TestWTF { void EchoStringMap( const base::Optional< - std::unordered_map<std::string, base::Optional<std::string>>>& - str_map, + base::flat_map<std::string, base::Optional<std::string>>>& str_map, const EchoStringMapCallback& callback) override { callback.Run(std::move(str_map)); } @@ -68,7 +67,7 @@ WTF::Vector<WTF::String> ConstructStringArray() { // strs[1] is empty. strs[1] = ""; strs[2] = kHelloWorld; - strs[3] = WTF::String::fromUTF8(kUTF8HelloWorld); + strs[3] = WTF::String::FromUTF8(kUTF8HelloWorld); return strs; } @@ -78,7 +77,7 @@ WTF::HashMap<WTF::String, WTF::String> ConstructStringMap() { // A null string as value. str_map.insert("0", WTF::String()); str_map.insert("1", kHelloWorld); - str_map.insert("2", WTF::String::fromUTF8(kUTF8HelloWorld)); + str_map.insert("2", WTF::String::FromUTF8(kUTF8HelloWorld)); return str_map; } @@ -90,17 +89,17 @@ void ExpectString(const WTF::String& expected_string, closure.Run(); } -void ExpectStringArray(WTF::Optional<WTF::Vector<WTF::String>>* expected_arr, +void ExpectStringArray(base::Optional<WTF::Vector<WTF::String>>* expected_arr, const base::Closure& closure, - const WTF::Optional<WTF::Vector<WTF::String>>& arr) { + const base::Optional<WTF::Vector<WTF::String>>& arr) { EXPECT_EQ(*expected_arr, arr); closure.Run(); } void ExpectStringMap( - WTF::Optional<WTF::HashMap<WTF::String, WTF::String>>* expected_map, + base::Optional<WTF::HashMap<WTF::String, WTF::String>>* expected_map, const base::Closure& closure, - const WTF::Optional<WTF::HashMap<WTF::String, WTF::String>>& map) { + const base::Optional<WTF::HashMap<WTF::String, WTF::String>>& map) { EXPECT_EQ(*expected_map, map); closure.Run(); } @@ -113,19 +112,43 @@ TEST_F(WTFTypesTest, Serialization_WTFVectorToWTFVector) { WTF::Vector<WTF::String> strs = ConstructStringArray(); auto cloned_strs = strs; + mojo::Message message(0, 0, 0, 0, nullptr); mojo::internal::SerializationContext context; - size_t size = - mojo::internal::PrepareToSerialize<MojomType>(cloned_strs, &context); - - mojo::internal::FixedBufferForTesting buf(size); - typename mojo::internal::MojomTypeTraits<MojomType>::Data* data; + typename mojo::internal::MojomTypeTraits<MojomType>::Data::BufferWriter + writer; mojo::internal::ContainerValidateParams validate_params( 0, true, new mojo::internal::ContainerValidateParams(0, false, nullptr)); - mojo::internal::Serialize<MojomType>(cloned_strs, &buf, &data, - &validate_params, &context); + mojo::internal::Serialize<MojomType>(cloned_strs, message.payload_buffer(), + &writer, &validate_params, &context); WTF::Vector<WTF::String> strs2; - mojo::internal::Deserialize<MojomType>(data, &strs2, &context); + mojo::internal::Deserialize<MojomType>(writer.data(), &strs2, &context); + + EXPECT_EQ(strs, strs2); +} + +TEST_F(WTFTypesTest, Serialization_WTFVectorInlineCapacity) { + using MojomType = ArrayDataView<StringDataView>; + + WTF::Vector<WTF::String, 1> strs(4); + // strs[0] is null. + // strs[1] is empty. + strs[1] = ""; + strs[2] = kHelloWorld; + strs[3] = WTF::String::FromUTF8(kUTF8HelloWorld); + auto cloned_strs = strs; + + mojo::Message message(0, 0, 0, 0, nullptr); + mojo::internal::SerializationContext context; + typename mojo::internal::MojomTypeTraits<MojomType>::Data::BufferWriter + writer; + mojo::internal::ContainerValidateParams validate_params( + 0, true, new mojo::internal::ContainerValidateParams(0, false, nullptr)); + mojo::internal::Serialize<MojomType>(cloned_strs, message.payload_buffer(), + &writer, &validate_params, &context); + + WTF::Vector<WTF::String, 1> strs2; + mojo::internal::Deserialize<MojomType>(writer.data(), &strs2, &context); EXPECT_EQ(strs, strs2); } @@ -136,19 +159,17 @@ TEST_F(WTFTypesTest, Serialization_WTFVectorToStlVector) { WTF::Vector<WTF::String> strs = ConstructStringArray(); auto cloned_strs = strs; + mojo::Message message(0, 0, 0, 0, nullptr); mojo::internal::SerializationContext context; - size_t size = - mojo::internal::PrepareToSerialize<MojomType>(cloned_strs, &context); - - mojo::internal::FixedBufferForTesting buf(size); - typename mojo::internal::MojomTypeTraits<MojomType>::Data* data; + typename mojo::internal::MojomTypeTraits<MojomType>::Data::BufferWriter + writer; mojo::internal::ContainerValidateParams validate_params( 0, true, new mojo::internal::ContainerValidateParams(0, false, nullptr)); - mojo::internal::Serialize<MojomType>(cloned_strs, &buf, &data, - &validate_params, &context); + mojo::internal::Serialize<MojomType>(cloned_strs, message.payload_buffer(), + &writer, &validate_params, &context); std::vector<base::Optional<std::string>> strs2; - mojo::internal::Deserialize<MojomType>(data, &strs2, &context); + mojo::internal::Deserialize<MojomType>(writer.data(), &strs2, &context); ASSERT_EQ(4u, strs2.size()); EXPECT_FALSE(strs2[0]); @@ -192,7 +213,7 @@ TEST_F(WTFTypesTest, SendStringArray) { blink::TestWTFPtr ptr; TestWTFImpl impl(ConvertInterfaceRequest<TestWTF>(MakeRequest(&ptr))); - WTF::Optional<WTF::Vector<WTF::String>> arrs[3]; + base::Optional<WTF::Vector<WTF::String>> arrs[3]; // arrs[0] is empty. arrs[0].emplace(); // arrs[1] is null. @@ -200,13 +221,13 @@ TEST_F(WTFTypesTest, SendStringArray) { for (size_t i = 0; i < arraysize(arrs); ++i) { base::RunLoop loop; - // Test that a WTF::Optional<WTF::Vector<WTF::String>> is unchanged after + // Test that a base::Optional<WTF::Vector<WTF::String>> is unchanged after // the following conversion: // - serialized; // - deserialized as // base::Optional<std::vector<base::Optional<std::string>>>; // - serialized; - // - deserialized as WTF::Optional<WTF::Vector<WTF::String>>. + // - deserialized as base::Optional<WTF::Vector<WTF::String>>. ptr->EchoStringArray( arrs[i], base::Bind(&ExpectStringArray, base::Unretained(&arrs[i]), loop.QuitClosure())); @@ -218,7 +239,7 @@ TEST_F(WTFTypesTest, SendStringMap) { blink::TestWTFPtr ptr; TestWTFImpl impl(ConvertInterfaceRequest<TestWTF>(MakeRequest(&ptr))); - WTF::Optional<WTF::HashMap<WTF::String, WTF::String>> maps[3]; + base::Optional<WTF::HashMap<WTF::String, WTF::String>> maps[3]; // maps[0] is empty. maps[0].emplace(); // maps[1] is null. @@ -226,13 +247,13 @@ TEST_F(WTFTypesTest, SendStringMap) { for (size_t i = 0; i < arraysize(maps); ++i) { base::RunLoop loop; - // Test that a WTF::Optional<WTF::HashMap<WTF::String, WTF::String>> is + // Test that a base::Optional<WTF::HashMap<WTF::String, WTF::String>> is // unchanged after the following conversion: // - serialized; // - deserialized as base::Optional< - // std::unordered_map<std::string, base::Optional<std::string>>>; + // base::flat_map<std::string, base::Optional<std::string>>>; // - serialized; - // - deserialized as WTF::Optional<WTF::HashMap<WTF::String, + // - deserialized as base::Optional<WTF::HashMap<WTF::String, // WTF::String>>. ptr->EchoStringMap(maps[i], base::Bind(&ExpectStringMap, base::Unretained(&maps[i]), @@ -241,5 +262,20 @@ TEST_F(WTFTypesTest, SendStringMap) { } } +TEST_F(WTFTypesTest, NestedStruct_CloneAndEquals) { + auto a = blink::TestWTFStructWrapper::New(); + a->nested_struct = blink::TestWTFStruct::New("foo", 1); + a->array_struct.push_back(blink::TestWTFStruct::New("bar", 2)); + a->array_struct.push_back(blink::TestWTFStruct::New("bar", 3)); + a->map_struct.insert(blink::TestWTFStruct::New("baz", 4), + blink::TestWTFStruct::New("baz", 5)); + auto b = a.Clone(); + EXPECT_EQ(a, b); + EXPECT_EQ(2u, b->array_struct.size()); + EXPECT_EQ(1u, b->map_struct.size()); + EXPECT_NE(blink::TestWTFStructWrapper::New(), a); + EXPECT_NE(blink::TestWTFStructWrapper::New(), b); +} + } // namespace test } // namespace mojo diff --git a/mojo/public/cpp/bindings/thread_safe_interface_ptr.h b/mojo/public/cpp/bindings/thread_safe_interface_ptr.h index 740687f379..0a5c7f6f43 100644 --- a/mojo/public/cpp/bindings/thread_safe_interface_ptr.h +++ b/mojo/public/cpp/bindings/thread_safe_interface_ptr.h @@ -13,7 +13,7 @@ #include "base/stl_util.h" #include "base/synchronization/waitable_event.h" #include "base/task_runner.h" -#include "base/threading/thread_task_runner_handle.h" +#include "base/threading/sequenced_task_runner_handle.h" #include "mojo/public/cpp/bindings/associated_group.h" #include "mojo/public/cpp/bindings/associated_interface_ptr.h" #include "mojo/public/cpp/bindings/interface_ptr.h" @@ -22,22 +22,22 @@ #include "mojo/public/cpp/bindings/sync_event_watcher.h" // ThreadSafeInterfacePtr wraps a non-thread-safe InterfacePtr and proxies -// messages to it. Async calls are posted to the thread that the InteracePtr is -// bound to, and the responses are posted back. Sync calls are dispatched -// directly if the call is made on the thread that the wrapped InterfacePtr is +// messages to it. Async calls are posted to the sequence that the InteracePtr +// is bound to, and the responses are posted back. Sync calls are dispatched +// directly if the call is made on the sequence that the wrapped InterfacePtr is // bound to, or posted otherwise. It's important to be aware that sync calls -// block both the calling thread and the InterfacePtr thread. That means that -// you cannot make sync calls through a ThreadSafeInterfacePtr if the -// underlying InterfacePtr is bound to a thread that cannot block, like the IO +// block both the calling sequence and the InterfacePtr sequence. That means +// that you cannot make sync calls through a ThreadSafeInterfacePtr if the +// underlying InterfacePtr is bound to a sequence that cannot block, like the IO // thread. namespace mojo { -// Instances of this class may be used from any thread to serialize |Interface| -// messages and forward them elsewhere. In general you should use one of the -// ThreadSafeInterfacePtrBase helper aliases defined below, but this type may be -// useful if you need/want to manually manage the lifetime of the underlying -// proxy object which will be used to ultimately send messages. +// Instances of this class may be used from any sequence to serialize +// |Interface| messages and forward them elsewhere. In general you should use +// one of the ThreadSafeInterfacePtrBase helper aliases defined below, but this +// type may be useful if you need/want to manually manage the lifetime of the +// underlying proxy object which will be used to ultimately send messages. template <typename Interface> class ThreadSafeForwarder : public MessageReceiverWithResponder { public: @@ -50,9 +50,10 @@ class ThreadSafeForwarder : public MessageReceiverWithResponder { // |forward| or |forward_with_responder| by posting to |task_runner|. // // Any message sent through this forwarding interface will dispatch its reply, - // if any, back to the thread which called the corresponding interface method. + // if any, back to the sequence which called the corresponding interface + // method. ThreadSafeForwarder( - const scoped_refptr<base::SingleThreadTaskRunner>& task_runner, + const scoped_refptr<base::SequencedTaskRunner>& task_runner, const ForwardMessageCallback& forward, const ForwardMessageWithResponderCallback& forward_with_responder, const AssociatedGroup& associated_group) @@ -74,6 +75,13 @@ class ThreadSafeForwarder : public MessageReceiverWithResponder { private: // MessageReceiverWithResponder implementation: + bool PrefersSerializedMessages() override { + // TSIP is primarily used because it emulates legacy IPC threading behavior. + // In practice this means it's only for cross-process messaging and we can + // just always assume messages should be serialized. + return true; + } + bool Accept(Message* message) override { if (!message->associated_endpoint_handles()->empty()) { // If this DCHECK fails, it is likely because: @@ -104,10 +112,10 @@ class ThreadSafeForwarder : public MessageReceiverWithResponder { } // Async messages are always posted (even if |task_runner_| runs tasks on - // this thread) to guarantee that two async calls can't be reordered. + // this sequence) to guarantee that two async calls can't be reordered. if (!message->has_flag(Message::kFlagIsSync)) { auto reply_forwarder = - base::MakeUnique<ForwardToCallingThread>(std::move(responder)); + std::make_unique<ForwardToCallingThread>(std::move(responder)); task_runner_->PostTask( FROM_HERE, base::Bind(forward_with_responder_, base::Passed(message), base::Passed(&reply_forwarder))); @@ -116,17 +124,17 @@ class ThreadSafeForwarder : public MessageReceiverWithResponder { SyncCallRestrictions::AssertSyncCallAllowed(); - // If the InterfacePtr is bound to this thread, dispatch it directly. - if (task_runner_->RunsTasksOnCurrentThread()) { + // If the InterfacePtr is bound to this sequence, dispatch it directly. + if (task_runner_->RunsTasksInCurrentSequence()) { forward_with_responder_.Run(std::move(*message), std::move(responder)); return true; } - // If the InterfacePtr is bound on another thread, post the call. - // TODO(yzshen, watk): We block both this thread and the InterfacePtr - // thread. Ideally only this thread would block. - auto response = make_scoped_refptr(new SyncResponseInfo()); - auto response_signaler = base::MakeUnique<SyncResponseSignaler>(response); + // If the InterfacePtr is bound on another sequence, post the call. + // TODO(yzshen, watk): We block both this sequence and the InterfacePtr + // sequence. Ideally only this sequence would block. + auto response = base::MakeRefCounted<SyncResponseInfo>(); + auto response_signaler = std::make_unique<SyncResponseSignaler>(response); task_runner_->PostTask( FROM_HERE, base::Bind(forward_with_responder_, base::Passed(message), base::Passed(&response_signaler))); @@ -144,7 +152,8 @@ class ThreadSafeForwarder : public MessageReceiverWithResponder { bool event_signaled = false; SyncEventWatcher watcher(&response->event, base::Bind(assign_true, &event_signaled)); - watcher.SyncWatch(&event_signaled); + const bool* stop_flags[] = {&event_signaled}; + watcher.SyncWatch(stop_flags, 1); { base::AutoLock l(sync_calls->lock); @@ -157,7 +166,7 @@ class ThreadSafeForwarder : public MessageReceiverWithResponder { return true; } - // Data that we need to share between the threads involved in a sync call. + // Data that we need to share between the sequences involved in a sync call. struct SyncResponseInfo : public base::RefCountedThreadSafe<SyncResponseInfo> { Message message; @@ -183,7 +192,7 @@ class ThreadSafeForwarder : public MessageReceiverWithResponder { response_->event.Signal(); } - bool Accept(Message* message) { + bool Accept(Message* message) override { response_->message = std::move(*message); response_->received = true; response_->event.Signal(); @@ -208,10 +217,13 @@ class ThreadSafeForwarder : public MessageReceiverWithResponder { public: explicit ForwardToCallingThread(std::unique_ptr<MessageReceiver> responder) : responder_(std::move(responder)), - caller_task_runner_(base::ThreadTaskRunnerHandle::Get()) {} + caller_task_runner_(base::SequencedTaskRunnerHandle::Get()) {} + ~ForwardToCallingThread() override { + caller_task_runner_->DeleteSoon(FROM_HERE, std::move(responder_)); + } private: - bool Accept(Message* message) { + bool Accept(Message* message) override { // The current instance will be deleted when this method returns, so we // have to relinquish the responder's ownership so it does not get // deleted. @@ -230,11 +242,11 @@ class ThreadSafeForwarder : public MessageReceiverWithResponder { } std::unique_ptr<MessageReceiver> responder_; - scoped_refptr<base::SingleThreadTaskRunner> caller_task_runner_; + scoped_refptr<base::SequencedTaskRunner> caller_task_runner_; }; ProxyType proxy_; - const scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + const scoped_refptr<base::SequencedTaskRunner> task_runner_; const ForwardMessageCallback forward_; const ForwardMessageWithResponderCallback forward_with_responder_; AssociatedGroup associated_group_; @@ -256,9 +268,9 @@ class ThreadSafeInterfacePtrBase : forwarder_(std::move(forwarder)) {} // Creates a ThreadSafeInterfacePtrBase wrapping an underlying non-thread-safe - // InterfacePtrType which is bound to the calling thread. All messages sent + // InterfacePtrType which is bound to the calling sequence. All messages sent // via this thread-safe proxy will internally be sent by first posting to this - // (the calling) thread's TaskRunner. + // (the calling) sequence's TaskRunner. static scoped_refptr<ThreadSafeInterfacePtrBase> Create( InterfacePtrType interface_ptr) { scoped_refptr<PtrWrapper> wrapper = @@ -272,7 +284,7 @@ class ThreadSafeInterfacePtrBase // that TaskRunner. static scoped_refptr<ThreadSafeInterfacePtrBase> Create( PtrInfoType ptr_info, - const scoped_refptr<base::SingleThreadTaskRunner>& bind_task_runner) { + const scoped_refptr<base::SequencedTaskRunner>& bind_task_runner) { scoped_refptr<PtrWrapper> wrapper = new PtrWrapper(bind_task_runner); wrapper->BindOnTaskRunner(std::move(ptr_info)); return new ThreadSafeInterfacePtrBase(wrapper->CreateForwarder()); @@ -289,19 +301,19 @@ class ThreadSafeInterfacePtrBase struct PtrWrapperDeleter; // Helper class which owns an |InterfacePtrType| instance on an appropriate - // thread. This is kept alive as long its bound within some + // sequence. This is kept alive as long its bound within some // ThreadSafeForwarder's callbacks. class PtrWrapper : public base::RefCountedThreadSafe<PtrWrapper, PtrWrapperDeleter> { public: explicit PtrWrapper(InterfacePtrType ptr) - : PtrWrapper(base::ThreadTaskRunnerHandle::Get()) { + : PtrWrapper(base::SequencedTaskRunnerHandle::Get()) { ptr_ = std::move(ptr); associated_group_ = *ptr_.internal_state()->associated_group(); } explicit PtrWrapper( - const scoped_refptr<base::SingleThreadTaskRunner>& task_runner) + const scoped_refptr<base::SequencedTaskRunner>& task_runner) : task_runner_(task_runner) {} void BindOnTaskRunner(AssociatedInterfacePtrInfo<InterfaceType> ptr_info) { @@ -316,14 +328,14 @@ class ThreadSafeInterfacePtrBase // endpoints on this interface (at least not immediately). In order to fix // this, we need to create a MultiplexRouter immediately and bind it to // the interface pointer on the |task_runner_|. Therefore, MultiplexRouter - // should be able to be created on a thread different than the one that it - // is supposed to listen on. crbug.com/682334 + // should be able to be created on a sequence different than the one that + // it is supposed to listen on. crbug.com/682334 task_runner_->PostTask(FROM_HERE, base::Bind(&PtrWrapper::Bind, this, base::Passed(&ptr_info))); } std::unique_ptr<ThreadSafeForwarder<InterfaceType>> CreateForwarder() { - return base::MakeUnique<ThreadSafeForwarder<InterfaceType>>( + return std::make_unique<ThreadSafeForwarder<InterfaceType>>( task_runner_, base::Bind(&PtrWrapper::Accept, this), base::Bind(&PtrWrapper::AcceptWithResponder, this), associated_group_); @@ -335,7 +347,7 @@ class ThreadSafeInterfacePtrBase ~PtrWrapper() {} void Bind(PtrInfoType ptr_info) { - DCHECK(task_runner_->RunsTasksOnCurrentThread()); + DCHECK(task_runner_->RunsTasksInCurrentSequence()); ptr_.Bind(std::move(ptr_info)); } @@ -350,7 +362,7 @@ class ThreadSafeInterfacePtrBase } void DeleteOnCorrectThread() const { - if (!task_runner_->RunsTasksOnCurrentThread()) { + if (!task_runner_->RunsTasksInCurrentSequence()) { // NOTE: This is only called when there are no more references to // |this|, so binding it unretained is both safe and necessary. task_runner_->PostTask(FROM_HERE, @@ -362,7 +374,7 @@ class ThreadSafeInterfacePtrBase } InterfacePtrType ptr_; - const scoped_refptr<base::SingleThreadTaskRunner> task_runner_; + const scoped_refptr<base::SequencedTaskRunner> task_runner_; AssociatedGroup associated_group_; DISALLOW_COPY_AND_ASSIGN(PtrWrapper); diff --git a/mojo/public/cpp/bindings/union_traits.h b/mojo/public/cpp/bindings/union_traits.h index 292ee58f27..243addd7ed 100644 --- a/mojo/public/cpp/bindings/union_traits.h +++ b/mojo/public/cpp/bindings/union_traits.h @@ -5,6 +5,8 @@ #ifndef MOJO_PUBLIC_CPP_BINDINGS_UNION_TRAITS_H_ #define MOJO_PUBLIC_CPP_BINDINGS_UNION_TRAITS_H_ +#include "mojo/public/cpp/bindings/lib/template_util.h" + namespace mojo { // This must be specialized for any type |T| to be serialized/deserialized as @@ -32,7 +34,11 @@ namespace mojo { // will be called. // template <typename DataViewType, typename T> -struct UnionTraits; +struct UnionTraits { + static_assert(internal::AlwaysFalse<T>::value, + "Cannot find the mojo::UnionTraits specialization. Did you " + "forget to include the corresponding header file?"); +}; } // namespace mojo diff --git a/mojo/public/cpp/bindings/unique_ptr_impl_ref_traits.h b/mojo/public/cpp/bindings/unique_ptr_impl_ref_traits.h index f1ac097396..ca7fe930c6 100644 --- a/mojo/public/cpp/bindings/unique_ptr_impl_ref_traits.h +++ b/mojo/public/cpp/bindings/unique_ptr_impl_ref_traits.h @@ -9,9 +9,9 @@ namespace mojo { // Traits for a binding's implementation reference type. // This corresponds to a unique_ptr reference type. -template <typename Interface> +template <typename Interface, typename Deleter = std::default_delete<Interface>> struct UniquePtrImplRefTraits { - using PointerType = std::unique_ptr<Interface>; + using PointerType = std::unique_ptr<Interface, Deleter>; static bool IsNull(const PointerType& ptr) { return !ptr; } static Interface* GetRawPointer(PointerType* ptr) { return ptr->get(); } |