diff options
author | android-build-team Robot <android-build-team-robot@google.com> | 2021-06-19 12:01:48 +0000 |
---|---|---|
committer | android-build-team Robot <android-build-team-robot@google.com> | 2021-06-19 12:01:48 +0000 |
commit | 9063677c3726d766de3c5fc3a2cd3feabd840d08 (patch) | |
tree | 8f6ec66173bb45a55a65c8449f382c2c16b3b13a | |
parent | 9e260fa59a43bc04701216c51053922d445fefd8 (diff) | |
parent | 6901182adc9a85d7f514df4d68605a179c195f75 (diff) | |
download | libbrillo-android-mainline-12.0.0_r29.tar.gz |
Snap for 7474514 from 6901182adc9a85d7f514df4d68605a179c195f75 to mainline-media-releaseandroid-mainline-12.0.0_r89android-mainline-12.0.0_r74android-mainline-12.0.0_r62android-mainline-12.0.0_r46android-mainline-12.0.0_r29android-mainline-12.0.0_r12android-mainline-12.0.0_r119android-mainline-12.0.0_r104android12-mainline-media-release
Change-Id: I1756a28b97a568bbb2c02438730597bb2aa59efe
255 files changed, 11217 insertions, 3308 deletions
@@ -16,13 +16,42 @@ // by setting BRILLO_USE_* values. Note that we define local variables like // local_use_* to prevent leaking our default setting for other packages. +package { + default_applicable_licenses: ["external_libbrillo_license"], +} + +// Added automatically by a large-scale-change that took the approach of +// 'apply every license found to every target'. While this makes sure we respect +// every license restriction, it may not be entirely correct. +// +// e.g. GPL in an MIT project might only apply to the contrib/ directory. +// +// Please consider splitting the single license below into multiple licenses, +// taking care not to lose any license_kind information, and overriding the +// default license using the 'licenses: [...]' property on targets as needed. +// +// For unused files, consider creating a 'fileGroup' with "//visibility:private" +// to attach the license to, and including a comment whether the files may be +// used in the current project. +// See: http://go/android-license-faq +license { + name: "external_libbrillo_license", + visibility: [":__subpackages__"], + license_kinds: [ + "SPDX-license-identifier-Apache-2.0", + "SPDX-license-identifier-BSD", + ], + license_text: [ + "NOTICE", + ], +} + libbrillo_core_sources = [ "brillo/backoff_entry.cc", "brillo/data_encoding.cc", "brillo/errors/error.cc", "brillo/errors/error_codes.cc", "brillo/flag_helper.cc", - "brillo/imageloader/manifest.cc", "brillo/key_value_store.cc", "brillo/message_loops/base_message_loop.cc", "brillo/message_loops/message_loop.cc", @@ -85,46 +114,48 @@ libbrillo_test_helpers_sources = [ ] libbrillo_test_sources = [ - "brillo/asynchronous_signal_handler_unittest.cc", - "brillo/backoff_entry_unittest.cc", - "brillo/data_encoding_unittest.cc", - "brillo/enum_flags_unittest.cc", - "brillo/errors/error_codes_unittest.cc", - "brillo/errors/error_unittest.cc", - "brillo/file_utils_unittest.cc", - "brillo/flag_helper_unittest.cc", - "brillo/http/http_connection_curl_unittest.cc", - "brillo/http/http_form_data_unittest.cc", - "brillo/http/http_request_unittest.cc", - "brillo/http/http_transport_curl_unittest.cc", - "brillo/http/http_utils_unittest.cc", - "brillo/imageloader/manifest_unittest.cc", - "brillo/key_value_store_unittest.cc", - "brillo/map_utils_unittest.cc", - "brillo/message_loops/base_message_loop_unittest.cc", - "brillo/message_loops/fake_message_loop_unittest.cc", - "brillo/mime_utils_unittest.cc", - "brillo/osrelease_reader_unittest.cc", - "brillo/process_reaper_unittest.cc", - "brillo/process_unittest.cc", - "brillo/secure_blob_unittest.cc", - "brillo/streams/fake_stream_unittest.cc", - "brillo/streams/file_stream_unittest.cc", - "brillo/streams/input_stream_set_unittest.cc", - "brillo/streams/memory_containers_unittest.cc", - "brillo/streams/memory_stream_unittest.cc", - "brillo/streams/openssl_stream_bio_unittests.cc", - "brillo/streams/stream_unittest.cc", - "brillo/streams/stream_utils_unittest.cc", - "brillo/strings/string_utils_unittest.cc", + "brillo/asynchronous_signal_handler_test.cc", + "brillo/backoff_entry_test.cc", + "brillo/data_encoding_test.cc", + "brillo/enum_flags_test.cc", + "brillo/errors/error_codes_test.cc", + "brillo/errors/error_test.cc", + "brillo/file_utils_test.cc", + "brillo/flag_helper_test.cc", + "brillo/http/http_connection_curl_test.cc", + "brillo/http/http_form_data_test.cc", + "brillo/http/http_request_test.cc", + "brillo/http/http_transport_curl_test.cc", + "brillo/http/http_utils_test.cc", + "brillo/key_value_store_test.cc", + "brillo/map_utils_test.cc", + "brillo/message_loops/base_message_loop_test.cc", + "brillo/message_loops/fake_message_loop_test.cc", + "brillo/mime_utils_test.cc", + "brillo/osrelease_reader_test.cc", + "brillo/process_reaper_test.cc", + "brillo/process_test.cc", + "brillo/secure_blob_test.cc", + "brillo/streams/fake_stream_test.cc", + "brillo/streams/file_stream_test.cc", + "brillo/streams/input_stream_set_test.cc", + "brillo/streams/memory_containers_test.cc", + "brillo/streams/memory_stream_test.cc", + "brillo/streams/openssl_stream_bio_test.cc", + "brillo/streams/stream_test.cc", + "brillo/streams/stream_utils_test.cc", + "brillo/strings/string_utils_test.cc", "brillo/unittest_utils.cc", - "brillo/url_utils_unittest.cc", - "brillo/value_conversion_unittest.cc", + "brillo/url_utils_test.cc", + "brillo/value_conversion_test.cc", ] libbrillo_CFLAGS = [ "-Wall", "-Werror", + "-Wno-non-virtual-dtor", + "-Wno-unused-parameter", + "-Wno-unused-variable", ] libbrillo_shared_libraries = ["libchrome"] @@ -139,8 +170,8 @@ cc_library { shared_libs: libbrillo_shared_libraries, static_libs: [ "libmodpb64", - "libgtest_prod", ], + header_libs: ["libgtest_prod_headers"], cflags: libbrillo_CFLAGS, export_include_dirs: ["."], @@ -167,7 +198,7 @@ cc_library_shared { "libbrillo", "libutils", ], - static_libs: ["libgtest_prod"], + header_libs: ["libgtest_prod_headers"], cflags: libbrillo_CFLAGS, export_include_dirs: ["."], } @@ -184,7 +215,7 @@ cc_library_shared { "libbrillo", "libminijail", ], - static_libs: ["libgtest_prod"], + header_libs: ["libgtest_prod_headers"], cflags: libbrillo_CFLAGS, export_include_dirs: ["."], } @@ -199,7 +230,7 @@ cc_library { "libcrypto", "libssl", ], - static_libs: ["libgtest_prod"], + header_libs: ["libgtest_prod_headers"], cflags: libbrillo_CFLAGS, export_include_dirs: ["."], @@ -225,7 +256,7 @@ cc_library_shared { "libbrillo-stream", "libcurl", ], - static_libs: ["libgtest_prod"], + header_libs: ["libgtest_prod_headers"], cflags: libbrillo_CFLAGS, export_include_dirs: ["."], @@ -246,7 +277,7 @@ cc_library_shared { name: "libbrillo-policy", srcs: libbrillo_policy_sources, shared_libs: libbrillo_shared_libraries, - static_libs: ["libgtest_prod"], + header_libs: ["libgtest_prod_headers"], cflags: libbrillo_CFLAGS, export_include_dirs: ["."], } diff --git a/BUILD.gn b/BUILD.gn new file mode 100644 index 0000000..8eefb58 --- /dev/null +++ b/BUILD.gn @@ -0,0 +1,669 @@ +# Copyright 2019 The Chromium OS Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import("//common-mk/deps.gni") +import("//common-mk/pkg_config.gni") +import("//common-mk/proto_library.gni") + +group("all") { + deps = [ + ":libbrillo", + ":libbrillo-glib", + ":libbrillo-test", + ":libinstallattributes", + ":libpolicy", + ] + if (use.test) { + deps += [ + ":libbrillo_tests", + ":libinstallattributes_tests", + ":libpolicy_tests", + ] + } + if (use.fuzzer) { + deps += [ + ":libbrillo_data_encoding_fuzzer", + ":libbrillo_dbus_data_serialization_fuzzer", + ":libbrillo_http_form_data_fuzzer", + ] + } +} + +default_pkg_deps = [ "libchrome-${libbase_ver}" ] +pkg_config("target_defaults_pkg_deps") { + pkg_deps = default_pkg_deps +} + +config("target_defaults") { + configs = [ ":target_defaults_pkg_deps" ] + include_dirs = [ "../libbrillo" ] + defines = [ + "USE_DBUS=${use.dbus}", + "USE_RTTI_FOR_TYPE_TAGS", + ] +} + +config("libbrillo_configs") { + include_dirs = [ "../libbrillo" ] +} + +# Properties of shared libraries which libbrillo consists of. +# Stored to variables once before actually declaring the targets, so that +# another target can collect information for making the .pc and .so files. +libbrillo_sublibs = [ + { + # |library_name| is library file name without "lib" prefix. This is needed + # for composing -l*** flags in libbrillo-${libbasever}.so. + # (Current version of GN deployed to ChromeOS doesn't have string_replace.) + library_name = "brillo-core" + if (use.dbus) { + all_dependent_pkg_deps = [ "dbus-1" ] + } + libs = [ "modp_b64" ] + sources = [ + "brillo/asynchronous_signal_handler.cc", + "brillo/backoff_entry.cc", + "brillo/daemons/daemon.cc", + "brillo/data_encoding.cc", + "brillo/errors/error.cc", + "brillo/errors/error_codes.cc", + "brillo/file_utils.cc", + "brillo/files/file_util.cc", + "brillo/files/safe_fd.cc", + "brillo/flag_helper.cc", + "brillo/key_value_store.cc", + "brillo/message_loops/base_message_loop.cc", + "brillo/message_loops/message_loop.cc", + "brillo/message_loops/message_loop_utils.cc", + "brillo/mime_utils.cc", + "brillo/osrelease_reader.cc", + "brillo/process.cc", + "brillo/process_information.cc", + "brillo/process_reaper.cc", + "brillo/scoped_umask.cc", + "brillo/secure_blob.cc", + "brillo/strings/string_utils.cc", + "brillo/syslog_logging.cc", + "brillo/timezone/tzif_parser.cc", + "brillo/type_name_undecorate.cc", + "brillo/url_utils.cc", + "brillo/userdb_utils.cc", + "brillo/value_conversion.cc", + ] + if (use.dbus) { + sources += [ + "brillo/any.cc", + "brillo/daemons/dbus_daemon.cc", + "brillo/dbus/async_event_sequencer.cc", + "brillo/dbus/data_serialization.cc", + "brillo/dbus/dbus_connection.cc", + "brillo/dbus/dbus_method_invoker.cc", + "brillo/dbus/dbus_method_response.cc", + "brillo/dbus/dbus_object.cc", + "brillo/dbus/dbus_service_watcher.cc", + "brillo/dbus/dbus_signal.cc", + "brillo/dbus/exported_object_manager.cc", + "brillo/dbus/exported_property_set.cc", + "brillo/dbus/introspectable_helper.cc", + "brillo/dbus/utils.cc", + ] + } + }, + + { + library_name = "brillo-blockdeviceutils" + deps = [ ":libbrillo-core" ] + sources = [ "brillo/blkdev_utils/loop_device.cc" ] + if (use.device_mapper) { + pkg_deps = [ "devmapper" ] + sources += [ + "brillo/blkdev_utils/device_mapper.cc", + "brillo/blkdev_utils/device_mapper_task.cc", + ] + } + }, + + { + library_name = "brillo-http" + deps = [ + ":libbrillo-core", + ":libbrillo-streams", + ] + all_dependent_pkg_deps = [ "libcurl" ] + sources = [ + "brillo/http/curl_api.cc", + "brillo/http/http_connection_curl.cc", + "brillo/http/http_form_data.cc", + "brillo/http/http_request.cc", + "brillo/http/http_transport.cc", + "brillo/http/http_transport_curl.cc", + "brillo/http/http_utils.cc", + ] + if (use.dbus) { + sources += [ "brillo/http/http_proxy.cc" ] + } + }, + + { + library_name = "brillo-streams" + deps = [ ":libbrillo-core" ] + all_dependent_pkg_deps = [ "openssl" ] + sources = [ + "brillo/streams/file_stream.cc", + "brillo/streams/input_stream_set.cc", + "brillo/streams/memory_containers.cc", + "brillo/streams/memory_stream.cc", + "brillo/streams/openssl_stream_bio.cc", + "brillo/streams/stream.cc", + "brillo/streams/stream_errors.cc", + "brillo/streams/stream_utils.cc", + "brillo/streams/tls_stream.cc", + ] + }, + + { + library_name = "brillo-cryptohome" + all_dependent_pkg_deps = [ "openssl" ] + sources = [ "brillo/cryptohome.cc" ] + }, + + { + library_name = "brillo-namespaces" + deps = [ ":libbrillo-core" ] + sources = [ + "brillo/namespaces/mount_namespace.cc", + "brillo/namespaces/platform.cc", + "brillo/scoped_mount_namespace.cc", + ] + }, + + { + library_name = "brillo-minijail" + all_dependent_pkg_deps = [ "libminijail" ] + sources = [ "brillo/minijail/minijail.cc" ] + }, + + { + library_name = "brillo-protobuf" + all_dependent_pkg_deps = [ "protobuf" ] + sources = [ "brillo/proto_file_io.cc" ] + }, +] + +if (use.udev) { + libbrillo_sublibs += [ + { + library_name = "brillo-udev" + all_dependent_pkg_deps = [ "libudev" ] + sources = [ + "brillo/udev/udev.cc", + "brillo/udev/udev_device.cc", + "brillo/udev/udev_enumerate.cc", + "brillo/udev/udev_list_entry.cc", + "brillo/udev/udev_monitor.cc", + ] + }, + ] +} + +# Generate shared libraries. +foreach(attr, libbrillo_sublibs) { + shared_library("lib" + attr.library_name) { + sources = attr.sources + if (defined(attr.deps)) { + deps = attr.deps + } + if (defined(attr.libs)) { + libs = attr.libs + } + if (defined(attr.pkg_deps)) { + pkg_deps = attr.pkg_deps + } + if (defined(attr.public_pkg_deps)) { + public_pkg_deps = attr.public_pkg_deps + } + if (defined(attr.all_dependent_pkg_deps)) { + all_dependent_pkg_deps = attr.all_dependent_pkg_deps + } + if (defined(attr.cflags)) { + cflags = attr.cflags + } + if (defined(attr.configs)) { + configs += attr.configs + } + configs += [ ":target_defaults" ] + } +} + +generate_pkg_config("libbrillo-${libbase_ver}_pc") { + name = "libbrillo" + output_name = "libbrillo-${libbase_ver}" + description = "brillo base library" + version = libbase_ver + requires_private = default_pkg_deps + foreach(sublib, libbrillo_sublibs) { + if (defined(sublib.pkg_deps)) { + requires_private += sublib.pkg_deps + } + if (defined(sublib.public_pkg_deps)) { + requires_private += sublib.public_pkg_deps + } + if (defined(sublib.all_dependent_pkg_deps)) { + requires_private += sublib.all_dependent_pkg_deps + } + } + defines = [ "USE_RTTI_FOR_TYPE_TAGS" ] + libs = [ "-lbrillo" ] +} + +generate_pkg_config("libbrillo_pc") { + name = "libbrillo" + output_name = "libbrillo" + description = "brillo base library" + version = libbase_ver + requires_private = default_pkg_deps + foreach(sublib, libbrillo_sublibs) { + if (defined(sublib.pkg_deps)) { + requires_private += sublib.pkg_deps + } + if (defined(sublib.public_pkg_deps)) { + requires_private += sublib.public_pkg_deps + } + if (defined(sublib.all_dependent_pkg_deps)) { + requires_private += sublib.all_dependent_pkg_deps + } + } + defines = [ "USE_RTTI_FOR_TYPE_TAGS" ] + libs = [ "-lbrillo" ] +} + +action("libbrillo") { + deps = [ + ":libbrillo-${libbase_ver}_pc", + ":libbrillo_pc", + ] + foreach(sublib, libbrillo_sublibs) { + deps += [ ":lib" + sublib.library_name ] + } + script = "//common-mk/write_args.py" + outputs = [ "${root_out_dir}/lib/libbrillo.so" ] + args = [ "--output" ] + outputs + [ "--" ] + [ + "GROUP", + "(", + "AS_NEEDED", + "(", + ] + foreach(sublib, libbrillo_sublibs) { + args += [ "-l" + sublib.library_name ] + } + args += [ + ")", + ")", + ] +} + +libbrillo_test_deps = [ "libbrillo-http" ] + +generate_pkg_config("libbrillo-test-${libbase_ver}_pc") { + name = "libbrillo-test" + output_name = "libbrillo-test-${libbase_ver}" + description = "brillo test library" + version = libbase_ver + + # Because libbrillo-test is static, we have to depend directly on everything. + requires = [ "libbrillo" ] + default_pkg_deps + foreach(name, libbrillo_test_deps) { + foreach(t, libbrillo_sublibs) { + if ("lib" + t.library_name == name) { + if (defined(t.pkg_deps)) { + requires += t.pkg_deps + } + if (defined(t.public_pkg_deps)) { + requires += t.public_pkg_deps + } + if (defined(t.all_dependent_pkg_deps)) { + requires += t.all_dependent_pkg_deps + } + } + } + } + libs = [ "-lbrillo-test" ] +} + +generate_pkg_config("libbrillo-test_pc") { + name = "libbrillo-test" + output_name = "libbrillo-test" + description = "brillo test library" + version = libbase_ver + + # Because libbrillo-test is static, we have to depend directly on everything. + requires = [ "libbrillo" ] + default_pkg_deps + foreach(name, libbrillo_test_deps) { + foreach(t, libbrillo_sublibs) { + if ("lib" + t.library_name == name) { + if (defined(t.pkg_deps)) { + requires += t.pkg_deps + } + if (defined(t.public_pkg_deps)) { + requires += t.public_pkg_deps + } + if (defined(t.all_dependent_pkg_deps)) { + requires += t.all_dependent_pkg_deps + } + } + } + } + libs = [ "-lbrillo-test" ] +} + +static_library("libbrillo-test") { + configs -= [ "//common-mk:use_thin_archive" ] + configs += [ + "//common-mk:nouse_thin_archive", + ":target_defaults", + ] + deps = [ + ":libbrillo-http", + ":libbrillo-test-${libbase_ver}_pc", + ":libbrillo-test_pc", + ] + foreach(name, libbrillo_test_deps) { + deps += [ ":" + name ] + } + sources = [ + "brillo/blkdev_utils/loop_device_fake.cc", + "brillo/http/http_connection_fake.cc", + "brillo/http/http_transport_fake.cc", + "brillo/message_loops/fake_message_loop.cc", + "brillo/streams/fake_stream.cc", + "brillo/unittest_utils.cc", + ] + if (use.device_mapper) { + sources += [ "brillo/blkdev_utils/device_mapper_fake.cc" ] + } +} + +shared_library("libinstallattributes") { + configs += [ ":target_defaults" ] + deps = [ + ":libinstallattributes-includes", + "../common-mk/external_dependencies:install_attributes-proto", + ] + all_dependent_pkg_deps = [ "protobuf-lite" ] + sources = [ "install_attributes/libinstallattributes.cc" ] +} + +shared_library("libpolicy") { + configs += [ ":target_defaults" ] + deps = [ + ":libinstallattributes", + ":libpolicy-includes", + "../common-mk/external_dependencies:policy-protos", + ] + all_dependent_pkg_deps = [ + "openssl", + "protobuf-lite", + ] + ldflags = [ "-Wl,--version-script,${platform2_root}/libbrillo/libpolicy.ver" ] + sources = [ + "policy/device_policy.cc", + "policy/device_policy_impl.cc", + "policy/libpolicy.cc", + "policy/policy_util.cc", + "policy/resilient_policy_util.cc", + ] +} + +libbrillo_glib_pkg_deps = [ + "glib-2.0", + "gobject-2.0", +] +if (use.dbus) { + libbrillo_glib_pkg_deps += [ + "dbus-1", + "dbus-glib-1", + ] +} + +generate_pkg_config("libbrillo-glib-${libbase_ver}_pc") { + name = "libbrillo-glib" + output_name = "libbrillo-glib-${libbase_ver}" + description = "brillo glib wrapper library" + version = libbase_ver + requires_private = libbrillo_glib_pkg_deps + libs = [ "-lbrillo-glib" ] +} + +generate_pkg_config("libbrillo-glib_pc") { + name = "libbrillo-glib" + output_name = "libbrillo-glib" + description = "brillo glib wrapper library" + version = libbase_ver + requires_private = libbrillo_glib_pkg_deps + libs = [ "-lbrillo-glib" ] +} + +shared_library("libbrillo-glib") { + configs += [ ":target_defaults" ] + deps = [ + ":libbrillo", + ":libbrillo-glib-${libbase_ver}_pc", + ":libbrillo-glib_pc", + ] + all_dependent_pkg_deps = libbrillo_glib_pkg_deps + if (use.dbus) { + sources = [ + "brillo/glib/abstract_dbus_service.cc", + "brillo/glib/dbus.cc", + ] + } + cflags = [ + # glib uses the deprecated "register" attribute in some header files. + "-Wno-deprecated-register", + ] +} + +if (use.test) { + static_library("libbrillo_static") { + configs += [ ":target_defaults" ] + deps = [ + ":libbrillo-${libbase_ver}_pc", + ":libbrillo_pc", + ":libinstallattributes", + ":libpolicy", + ] + foreach(sublib, libbrillo_sublibs) { + deps += [ ":lib" + sublib.library_name ] + } + public_configs = [ ":libbrillo_configs" ] + } + proto_library("libbrillo_tests_proto") { + proto_in_dir = "brillo/dbus" + proto_out_dir = "include/brillo/dbus" + sources = [ "${proto_in_dir}/test.proto" ] + } + executable("libbrillo_tests") { + configs += [ + "//common-mk:test", + ":target_defaults", + ] + deps = [ + ":libbrillo-glib", + ":libbrillo-test", + ":libbrillo_static", + ":libbrillo_tests_proto", + ] + pkg_deps = [ "libchrome-test-${libbase_ver}" ] + cflags = [ "-Wno-format-zero-length" ] + + if (is_debug) { + cflags += [ + "-fprofile-arcs", + "-ftest-coverage", + "-fno-inline", + ] + libs = [ "gcov" ] + } + sources = [ + "brillo/asynchronous_signal_handler_test.cc", + "brillo/backoff_entry_test.cc", + "brillo/blkdev_utils/loop_device_test.cc", + "brillo/data_encoding_test.cc", + "brillo/enum_flags_test.cc", + "brillo/errors/error_codes_test.cc", + "brillo/errors/error_test.cc", + "brillo/file_utils_test.cc", + "brillo/files/file_util_test.cc", + "brillo/files/safe_fd_test.cc", + "brillo/flag_helper_test.cc", + "brillo/glib/object_test.cc", + "brillo/http/http_connection_curl_test.cc", + "brillo/http/http_form_data_test.cc", + "brillo/http/http_request_test.cc", + "brillo/http/http_transport_curl_test.cc", + "brillo/http/http_utils_test.cc", + "brillo/key_value_store_test.cc", + "brillo/map_utils_test.cc", + "brillo/message_loops/base_message_loop_test.cc", + "brillo/message_loops/fake_message_loop_test.cc", + "brillo/message_loops/message_loop_test.cc", + "brillo/mime_utils_test.cc", + "brillo/namespaces/mount_namespace_test.cc", + "brillo/osrelease_reader_test.cc", + "brillo/process_reaper_test.cc", + "brillo/process_test.cc", + "brillo/scoped_umask_test.cc", + "brillo/secure_blob_test.cc", + "brillo/streams/fake_stream_test.cc", + "brillo/streams/file_stream_test.cc", + "brillo/streams/input_stream_set_test.cc", + "brillo/streams/memory_containers_test.cc", + "brillo/streams/memory_stream_test.cc", + "brillo/streams/openssl_stream_bio_test.cc", + "brillo/streams/stream_test.cc", + "brillo/streams/stream_utils_test.cc", + "brillo/strings/string_utils_test.cc", + "brillo/timezone/tzif_parser_test.cc", + "brillo/unittest_utils.cc", + "brillo/url_utils_test.cc", + "brillo/value_conversion_test.cc", + "testrunner.cc", + ] + if (use.dbus) { + sources += [ + "brillo/any_internal_impl_test.cc", + "brillo/any_test.cc", + "brillo/dbus/async_event_sequencer_test.cc", + "brillo/dbus/data_serialization_test.cc", + "brillo/dbus/dbus_method_invoker_test.cc", + "brillo/dbus/dbus_object_test.cc", + "brillo/dbus/dbus_param_reader_test.cc", + "brillo/dbus/dbus_param_writer_test.cc", + "brillo/dbus/dbus_signal_handler_test.cc", + "brillo/dbus/exported_object_manager_test.cc", + "brillo/dbus/exported_property_set_test.cc", + "brillo/http/http_proxy_test.cc", + "brillo/type_name_undecorate_test.cc", + "brillo/variant_dictionary_test.cc", + ] + } + if (use.device_mapper) { + sources += [ "brillo/blkdev_utils/device_mapper_test.cc" ] + } + } + + executable("libinstallattributes_tests") { + configs += [ + "//common-mk:test", + ":target_defaults", + ] + deps = [ + ":libinstallattributes", + "../common-mk/external_dependencies:install_attributes-proto", + "../common-mk/testrunner:testrunner", + ] + sources = [ "install_attributes/tests/libinstallattributes_test.cc" ] + } + + executable("libpolicy_tests") { + configs += [ + "//common-mk:test", + ":target_defaults", + ] + deps = [ + ":libinstallattributes", + ":libpolicy", + "../common-mk/external_dependencies:install_attributes-proto", + "../common-mk/external_dependencies:policy-protos", + "../common-mk/testrunner:testrunner", + ] + sources = [ + "install_attributes/mock_install_attributes_reader.cc", + "policy/tests/device_policy_impl_test.cc", + "policy/tests/libpolicy_test.cc", + "policy/tests/policy_util_test.cc", + "policy/tests/resilient_policy_util_test.cc", + ] + } +} + +if (use.fuzzer) { + executable("libbrillo_data_encoding_fuzzer") { + sources = [ "brillo/data_encoding_fuzzer.cc" ] + + configs += [ "//common-mk/common_fuzzer:common_fuzzer" ] + + pkg_deps = [ "libchrome-${libbase_ver}" ] + + include_dirs = [ "../libbrillo" ] + + deps = [ ":libbrillo-core" ] + } + + executable("libbrillo_dbus_data_serialization_fuzzer") { + sources = [ "brillo/dbus/data_serialization_fuzzer.cc" ] + + configs += [ "//common-mk/common_fuzzer:common_fuzzer" ] + + pkg_deps = [ "libchrome-${libbase_ver}" ] + + include_dirs = [ "../libbrillo" ] + + deps = [ ":libbrillo-core" ] + } + + executable("libbrillo_http_form_data_fuzzer") { + sources = [ "brillo/http/http_form_data_fuzzer.cc" ] + + configs += [ "//common-mk/common_fuzzer:common_fuzzer" ] + + pkg_deps = [ "libchrome-${libbase_ver}" ] + + include_dirs = [ "../libbrillo" ] + + deps = [ + ":libbrillo-http", + ":libbrillo-streams", + ] + } +} + +copy("libinstallattributes-includes") { + sources = [ "install_attributes/libinstallattributes.h" ] + outputs = + [ "${root_gen_dir}/include/install_attributes/{{source_file_part}}" ] +} + +copy("libpolicy-includes") { + sources = [ + "policy/device_policy.h", + "policy/device_policy_impl.h", + "policy/libpolicy.h", + "policy/mock_device_policy.h", + "policy/mock_libpolicy.h", + "policy/policy_util.h", + "policy/resilient_policy_util.h", + ] + outputs = [ "${root_gen_dir}/include/policy/{{source_file_part}}" ] +} diff --git a/METADATA b/METADATA new file mode 100644 index 0000000..d97975c --- /dev/null +++ b/METADATA @@ -0,0 +1,3 @@ +third_party { + license_type: NOTICE +} @@ -1,13 +1,9 @@ set noparent # Android owners -avakulenko@google.com -dpursell@google.com senj@google.com -stevefung@google.com # Chromium owners -benchan@google.com -derat@google.com vapier@google.com ejcaruso@google.com +hidehiko@google.com diff --git a/README.md b/README.md new file mode 100644 index 0000000..118e3f1 --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +# libbrillo: platform utility library + +libbrillo is a shared library meant to hold common utility code that we deem +useful for platform projects. +It supplements the functionality provided by libbase/libchrome since that +project, by design, only holds functionality that Chromium (the browser) needs. +As a result, this tends to be more OS-centric code. + +## AOSP Usage + +This project is also used by [Update Engine] which is maintained in AOSP. +However, AOSP doesn't use this codebase directly, it maintains its own +[libbrillo fork](https://android.googlesource.com/platform/external/libbrillo/). + +To help keep the projects in sync, we have a gsubtree set up on our GoB: +https://chromium.googlesource.com/chromiumos/platform2/libbrillo/ + +This allows AOSP to cherry pick or merge changes directly back into their fork. + +[Update Engine]: https://android.googlesource.com/platform/system/update_engine/ diff --git a/brillo/any.cc b/brillo/any.cc index f84badf..b5ac84f 100644 --- a/brillo/any.cc +++ b/brillo/any.cc @@ -5,6 +5,7 @@ #include <brillo/any.h> #include <algorithm> +#include <utility> namespace brillo { diff --git a/brillo/any.h b/brillo/any.h index 51016b5..d41dd4a 100644 --- a/brillo/any.h +++ b/brillo/any.h @@ -18,7 +18,7 @@ // use helper functions std::ref() and std::cref() to create non-const and // const references respectively. In such a case, the type of contained data // will be std::reference_wrapper<T>. See 'References' unit tests in -// any_unittest.cc for examples. +// any_test.cc for examples. #ifndef LIBBRILLO_BRILLO_ANY_H_ #define LIBBRILLO_BRILLO_ANY_H_ @@ -26,6 +26,8 @@ #include <brillo/any_internal_impl.h> #include <algorithm> +#include <string> +#include <utility> #include <brillo/brillo_export.h> #include <brillo/type_name_undecorate.h> @@ -189,7 +191,7 @@ class BRILLO_EXPORT Any final { // (an appropriate specialization of AppendValueToWriter<T>() is available). // Returns false if the Any is empty or if there is no serialization method // defined for the contained data. - void AppendToDBusMessageWriter(dbus::MessageWriter* writer) const; + void AppendToDBusMessageWriter(::dbus::MessageWriter* writer) const; private: // Returns a pointer to a static buffer containing type tag (sort of a type diff --git a/brillo/any_internal_impl.h b/brillo/any_internal_impl.h index 9309f5d..f4114e6 100644 --- a/brillo/any_internal_impl.h +++ b/brillo/any_internal_impl.h @@ -154,7 +154,7 @@ struct Data { // Gets the contained integral value as an integer. virtual intmax_t GetAsInteger() const = 0; // Writes the contained value to the D-Bus message buffer. - virtual void AppendToDBusMessage(dbus::MessageWriter* writer) const = 0; + virtual void AppendToDBusMessage(::dbus::MessageWriter* writer) const = 0; // Compares if the two data containers have objects of the same value. virtual bool CompareEqual(const Data* other_data) const = 0; }; @@ -180,19 +180,19 @@ struct TypedData : public Data { return int_val; } - template<typename U> + template <typename U> static typename std::enable_if<dbus_utils::IsTypeSupported<U>::value>::type - AppendValueHelper(dbus::MessageWriter* writer, const U& value) { + AppendValueHelper(::dbus::MessageWriter* writer, const U& value) { brillo::dbus_utils::AppendValueToWriterAsVariant(writer, value); } - template<typename U> + template <typename U> static typename std::enable_if<!dbus_utils::IsTypeSupported<U>::value>::type - AppendValueHelper(dbus::MessageWriter* /* writer */, const U& /* value */) { + AppendValueHelper(::dbus::MessageWriter* /* writer */, const U& /* value */) { LOG(FATAL) << "Type '" << GetUndecoratedTypeName<U>() << "' is not supported by D-Bus"; } - void AppendToDBusMessage(dbus::MessageWriter* writer) const override { + void AppendToDBusMessage(::dbus::MessageWriter* writer) const override { return AppendValueHelper(writer, value_); } diff --git a/brillo/any_internal_impl_unittest.cc b/brillo/any_internal_impl_test.cc index 6f7f512..6f7f512 100644 --- a/brillo/any_internal_impl_unittest.cc +++ b/brillo/any_internal_impl_test.cc diff --git a/brillo/any_unittest.cc b/brillo/any_test.cc index db89884..936235e 100644 --- a/brillo/any_unittest.cc +++ b/brillo/any_test.cc @@ -5,6 +5,7 @@ #include <algorithm> #include <functional> #include <string> +#include <utility> #include <vector> #include <brillo/any.h> diff --git a/brillo/array_utils.h b/brillo/array_utils.h new file mode 100644 index 0000000..d180d35 --- /dev/null +++ b/brillo/array_utils.h @@ -0,0 +1,26 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_ARRAY_UTILS_H_ +#define LIBBRILLO_BRILLO_ARRAY_UTILS_H_ + +#include <array> +#include <utility> + +namespace brillo { + +// Create a std::array from a set of values without manually specifying the +// size of the array. Note that unlike the make_array likely to make its way +// into C++20, this function always requires the user to specify ElementType. +// This is done so that users are not surprised by the element type of resulting +// arrays when std::common_type is used. +template <typename ElementType, typename... T> +constexpr auto make_array(T&&... values) { + return std::array<ElementType, sizeof...(T)>{ + static_cast<ElementType>(std::forward<T>(values))...}; +} + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_ARRAY_UTILS_H_ diff --git a/brillo/asan.h b/brillo/asan.h index 9a73202..d29932a 100644 --- a/brillo/asan.h +++ b/brillo/asan.h @@ -17,5 +17,4 @@ #define BRILLO_DISABLE_ASAN #endif -#endif - +#endif // LIBBRILLO_BRILLO_ASAN_H_ diff --git a/brillo/asynchronous_signal_handler.cc b/brillo/asynchronous_signal_handler.cc index b8ec529..38b1787 100644 --- a/brillo/asynchronous_signal_handler.cc +++ b/brillo/asynchronous_signal_handler.cc @@ -11,49 +11,41 @@ #include <base/bind.h> #include <base/files/file_util.h> #include <base/logging.h> -#include <base/message_loop/message_loop.h> -#include <base/posix/eintr_wrapper.h> - -namespace { -const int kInvalidDescriptor = -1; -} // namespace namespace brillo { -AsynchronousSignalHandler::AsynchronousSignalHandler() - : descriptor_(kInvalidDescriptor) { +AsynchronousSignalHandler::AsynchronousSignalHandler() { CHECK_EQ(sigemptyset(&signal_mask_), 0) << "Failed to initialize signal mask"; CHECK_EQ(sigemptyset(&saved_signal_mask_), 0) << "Failed to initialize signal mask"; } AsynchronousSignalHandler::~AsynchronousSignalHandler() { - if (descriptor_ != kInvalidDescriptor) { - MessageLoop::current()->CancelTask(fd_watcher_task_); + fd_watcher_ = nullptr; - if (IGNORE_EINTR(close(descriptor_)) != 0) - PLOG(WARNING) << "Failed to close file descriptor"; + if (!descriptor_.is_valid()) + return; - descriptor_ = kInvalidDescriptor; - CHECK_EQ(0, sigprocmask(SIG_SETMASK, &saved_signal_mask_, nullptr)); - } + // Close FD before restoring sigprocmask. + descriptor_.reset(); + CHECK_EQ(0, sigprocmask(SIG_SETMASK, &saved_signal_mask_, nullptr)); } void AsynchronousSignalHandler::Init() { - CHECK_EQ(kInvalidDescriptor, descriptor_); + // Making sure it is not yet initialized. + CHECK(!descriptor_.is_valid()); + + // Set sigprocmask before creating signalfd. CHECK_EQ(0, sigprocmask(SIG_BLOCK, &signal_mask_, &saved_signal_mask_)); - descriptor_ = - signalfd(descriptor_, &signal_mask_, SFD_CLOEXEC | SFD_NONBLOCK); - CHECK_NE(kInvalidDescriptor, descriptor_); - fd_watcher_task_ = MessageLoop::current()->WatchFileDescriptor( - FROM_HERE, - descriptor_, - MessageLoop::WatchMode::kWatchRead, - true, - base::Bind(&AsynchronousSignalHandler::OnFileCanReadWithoutBlocking, - base::Unretained(this))); - CHECK(fd_watcher_task_ != MessageLoop::kTaskIdNull) - << "Watching shutdown pipe failed."; + + // Creating signalfd, and start watching it. + descriptor_.reset(signalfd(-1, &signal_mask_, SFD_CLOEXEC | SFD_NONBLOCK)); + CHECK(descriptor_.is_valid()); + fd_watcher_ = base::FileDescriptorWatcher::WatchReadable( + descriptor_.get(), + base::BindRepeating(&AsynchronousSignalHandler::OnReadable, + base::Unretained(this))); + CHECK(fd_watcher_) << "Watching signalfd failed."; } void AsynchronousSignalHandler::RegisterHandler(int signal, @@ -65,15 +57,16 @@ void AsynchronousSignalHandler::RegisterHandler(int signal, void AsynchronousSignalHandler::UnregisterHandler(int signal) { Callbacks::iterator callback_it = registered_callbacks_.find(signal); - if (callback_it != registered_callbacks_.end()) { - registered_callbacks_.erase(callback_it); - ResetSignal(signal); - } + if (callback_it == registered_callbacks_.end()) + return; + registered_callbacks_.erase(callback_it); + CHECK_EQ(0, sigdelset(&signal_mask_, signal)); + UpdateSignals(); } -void AsynchronousSignalHandler::OnFileCanReadWithoutBlocking() { +void AsynchronousSignalHandler::OnReadable() { struct signalfd_siginfo info; - while (base::ReadFromFD(descriptor_, + while (base::ReadFromFD(descriptor_.get(), reinterpret_cast<char*>(&info), sizeof(info))) { int signal = info.ssi_signo; Callbacks::iterator callback_it = registered_callbacks_.find(signal); @@ -85,24 +78,29 @@ void AsynchronousSignalHandler::OnFileCanReadWithoutBlocking() { } const SignalHandler& callback = callback_it->second; bool must_unregister = callback.Run(info); - if (must_unregister) { + if (must_unregister) UnregisterHandler(signal); - } } } -void AsynchronousSignalHandler::ResetSignal(int signal) { - CHECK_EQ(0, sigdelset(&signal_mask_, signal)); - UpdateSignals(); -} - void AsynchronousSignalHandler::UpdateSignals() { - if (descriptor_ != kInvalidDescriptor) { - CHECK_EQ(0, sigprocmask(SIG_SETMASK, &saved_signal_mask_, nullptr)); - CHECK_EQ(0, sigprocmask(SIG_BLOCK, &signal_mask_, nullptr)); - CHECK_EQ(descriptor_, - signalfd(descriptor_, &signal_mask_, SFD_CLOEXEC | SFD_NONBLOCK)); + if (!descriptor_.is_valid()) + return; + sigset_t mask; +#ifdef __ANDROID__ + CHECK_EQ(0, sigemptyset(&mask)); + for (size_t i = 0; i < NSIG; ++i) { + if (sigismember(&signal_mask_, i) == 1 || sigismember(&saved_signal_mask_, i) == 1) { + CHECK_EQ(0, sigaddset(&mask, i)); + } } +#else + CHECK_EQ(0, sigorset(&mask, &signal_mask_, &saved_signal_mask_)); +#endif + CHECK_EQ(0, sigprocmask(SIG_SETMASK, &mask, nullptr)); + CHECK_EQ( + descriptor_.get(), + signalfd(descriptor_.get(), &signal_mask_, SFD_CLOEXEC | SFD_NONBLOCK)); } } // namespace brillo diff --git a/brillo/asynchronous_signal_handler.h b/brillo/asynchronous_signal_handler.h index ceae1ff..4b0edce 100644 --- a/brillo/asynchronous_signal_handler.h +++ b/brillo/asynchronous_signal_handler.h @@ -5,18 +5,16 @@ #ifndef LIBBRILLO_BRILLO_ASYNCHRONOUS_SIGNAL_HANDLER_H_ #define LIBBRILLO_BRILLO_ASYNCHRONOUS_SIGNAL_HANDLER_H_ -#include <signal.h> #include <sys/signalfd.h> #include <map> +#include <memory> #include <base/callback.h> -#include <base/compiler_specific.h> -#include <base/macros.h> -#include <base/message_loop/message_loop.h> +#include <base/files/file_descriptor_watcher_posix.h> +#include <base/files/scoped_file.h> #include <brillo/asynchronous_signal_handler_interface.h> #include <brillo/brillo_export.h> -#include <brillo/message_loops/message_loop.h> namespace brillo { // Sets up signal handlers for registered signals, and converts signal receipt @@ -25,10 +23,14 @@ namespace brillo { class BRILLO_EXPORT AsynchronousSignalHandler final : public AsynchronousSignalHandlerInterface { public: + using AsynchronousSignalHandlerInterface::SignalHandler; + AsynchronousSignalHandler(); ~AsynchronousSignalHandler() override; - using AsynchronousSignalHandlerInterface::SignalHandler; + AsynchronousSignalHandler(const AsynchronousSignalHandler&) = delete; + AsynchronousSignalHandler& + operator=(const AsynchronousSignalHandler&) = delete; // Initialize the handler. void Init(); @@ -40,17 +42,20 @@ class BRILLO_EXPORT AsynchronousSignalHandler final : private: // Called from the main loop when we can read from |descriptor_|, indicated // that a signal was processed. - void OnFileCanReadWithoutBlocking(); + void OnReadable(); - // Controller used to manage watching of signalling pipe. - MessageLoop::TaskId fd_watcher_task_{MessageLoop::kTaskIdNull}; + // Updates the set of signals that this handler listens to. + BRILLO_PRIVATE void UpdateSignals(); - // The registered callbacks. - typedef std::map<int, SignalHandler> Callbacks; + // Map from signal to its registered callback. + using Callbacks = std::map<int, SignalHandler>; Callbacks registered_callbacks_; // File descriptor for accepting signals indicated by |signal_mask_|. - int descriptor_; + base::ScopedFD descriptor_; + + // Controller used to manage watching of signalling pipe. + std::unique_ptr<base::FileDescriptorWatcher::Controller> fd_watcher_; // A set of signals to be handled after the dispatcher is running. sigset_t signal_mask_; @@ -58,15 +63,6 @@ class BRILLO_EXPORT AsynchronousSignalHandler final : // A copy of the signal mask before the dispatcher starts, which will be // used to restore to the original state when the dispatcher stops. sigset_t saved_signal_mask_; - - // Resets the given signal to its default behavior. Doesn't touch - // |registered_callbacks_|. - BRILLO_PRIVATE void ResetSignal(int signal); - - // Updates the set of signals that this handler listens to. - BRILLO_PRIVATE void UpdateSignals(); - - DISALLOW_COPY_AND_ASSIGN(AsynchronousSignalHandler); }; } // namespace brillo diff --git a/brillo/asynchronous_signal_handler_interface.h b/brillo/asynchronous_signal_handler_interface.h index ef0012d..7bae444 100644 --- a/brillo/asynchronous_signal_handler_interface.h +++ b/brillo/asynchronous_signal_handler_interface.h @@ -20,7 +20,8 @@ class BRILLO_EXPORT AsynchronousSignalHandlerInterface { virtual ~AsynchronousSignalHandlerInterface() = default; // The callback called when a signal is received. - using SignalHandler = base::Callback<bool(const struct signalfd_siginfo&)>; + using SignalHandler = + base::RepeatingCallback<bool(const struct signalfd_siginfo&)>; // Register a new handler for the given |signal|, replacing any previously // registered handler. |callback| will be called on the thread the diff --git a/brillo/asynchronous_signal_handler_unittest.cc b/brillo/asynchronous_signal_handler_test.cc index ec3b061..2211b9c 100644 --- a/brillo/asynchronous_signal_handler_unittest.cc +++ b/brillo/asynchronous_signal_handler_test.cc @@ -113,6 +113,8 @@ TEST_F(AsynchronousSignalHandlerTest, CheckMultipleSignal) { } } +// TODO(crbug/1011829): This test is flaky. +#if 0 TEST_F(AsynchronousSignalHandlerTest, CheckChld) { handler_.RegisterHandler( SIGCHLD, @@ -134,5 +136,6 @@ TEST_F(AsynchronousSignalHandlerTest, CheckChld) { EXPECT_EQ(static_cast<int>(CLD_EXITED), infos_[0].ssi_code); EXPECT_EQ(EXIT_SUCCESS, infos_[0].ssi_status); } +#endif } // namespace brillo diff --git a/brillo/backoff_entry_test.cc b/brillo/backoff_entry_test.cc new file mode 100644 index 0000000..6a95bc0 --- /dev/null +++ b/brillo/backoff_entry_test.cc @@ -0,0 +1,311 @@ +// Copyright 2015 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/backoff_entry.h> +#include <gtest/gtest.h> + +using base::TimeDelta; +using base::TimeTicks; + +namespace brillo { + +BackoffEntry::Policy base_policy = { 0, 1000, 2.0, 0.0, 20000, 2000, false }; + +class TestBackoffEntry : public BackoffEntry { + public: + explicit TestBackoffEntry(const Policy* const policy) + : BackoffEntry(policy), + now_(TimeTicks()) { + // Work around initialization in constructor not picking up + // fake time. + SetCustomReleaseTime(TimeTicks()); + } + + ~TestBackoffEntry() override {} + + TimeTicks ImplGetTimeNow() const override { return now_; } + + void set_now(const TimeTicks& now) { + now_ = now; + } + + private: + TimeTicks now_; + + DISALLOW_COPY_AND_ASSIGN(TestBackoffEntry); +}; + +TEST(BackoffEntryTest, BaseTest) { + TestBackoffEntry entry(&base_policy); + EXPECT_FALSE(entry.ShouldRejectRequest()); + EXPECT_EQ(TimeDelta(), entry.GetTimeUntilRelease()); + + entry.InformOfRequest(false); + EXPECT_TRUE(entry.ShouldRejectRequest()); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); +} + +TEST(BackoffEntryTest, CanDiscardNeverExpires) { + BackoffEntry::Policy never_expires_policy = base_policy; + never_expires_policy.entry_lifetime_ms = -1; + TestBackoffEntry never_expires(&never_expires_policy); + EXPECT_FALSE(never_expires.CanDiscard()); + never_expires.set_now(TimeTicks() + TimeDelta::FromDays(100)); + EXPECT_FALSE(never_expires.CanDiscard()); +} + +TEST(BackoffEntryTest, CanDiscard) { + TestBackoffEntry entry(&base_policy); + // Because lifetime is non-zero, we shouldn't be able to discard yet. + EXPECT_FALSE(entry.CanDiscard()); + + // Test the "being used" case. + entry.InformOfRequest(false); + EXPECT_FALSE(entry.CanDiscard()); + + // Test the case where there are errors but we can time out. + entry.set_now( + entry.GetReleaseTime() + TimeDelta::FromMilliseconds(1)); + EXPECT_FALSE(entry.CanDiscard()); + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds( + base_policy.maximum_backoff_ms + 1)); + EXPECT_TRUE(entry.CanDiscard()); + + // Test the final case (no errors, dependent only on specified lifetime). + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds( + base_policy.entry_lifetime_ms - 1)); + entry.InformOfRequest(true); + EXPECT_FALSE(entry.CanDiscard()); + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds( + base_policy.entry_lifetime_ms)); + EXPECT_TRUE(entry.CanDiscard()); +} + +TEST(BackoffEntryTest, CanDiscardAlwaysDelay) { + BackoffEntry::Policy always_delay_policy = base_policy; + always_delay_policy.always_use_initial_delay = true; + always_delay_policy.entry_lifetime_ms = 0; + + TestBackoffEntry entry(&always_delay_policy); + + // Because lifetime is non-zero, we shouldn't be able to discard yet. + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds(2000)); + EXPECT_TRUE(entry.CanDiscard()); + + // Even with no failures, we wait until the delay before we allow discard. + entry.InformOfRequest(true); + EXPECT_FALSE(entry.CanDiscard()); + + // Wait until the delay expires, and we can discard the entry again. + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds(1000)); + EXPECT_TRUE(entry.CanDiscard()); +} + +TEST(BackoffEntryTest, CanDiscardNotStored) { + BackoffEntry::Policy no_store_policy = base_policy; + no_store_policy.entry_lifetime_ms = 0; + TestBackoffEntry not_stored(&no_store_policy); + EXPECT_TRUE(not_stored.CanDiscard()); +} + +TEST(BackoffEntryTest, ShouldIgnoreFirstTwo) { + BackoffEntry::Policy lenient_policy = base_policy; + lenient_policy.num_errors_to_ignore = 2; + + BackoffEntry entry(&lenient_policy); + + entry.InformOfRequest(false); + EXPECT_FALSE(entry.ShouldRejectRequest()); + + entry.InformOfRequest(false); + EXPECT_FALSE(entry.ShouldRejectRequest()); + + entry.InformOfRequest(false); + EXPECT_TRUE(entry.ShouldRejectRequest()); +} + +TEST(BackoffEntryTest, ReleaseTimeCalculation) { + TestBackoffEntry entry(&base_policy); + + // With zero errors, should return "now". + TimeTicks result = entry.GetReleaseTime(); + EXPECT_EQ(entry.ImplGetTimeNow(), result); + + // 1 error. + entry.InformOfRequest(false); + result = entry.GetReleaseTime(); + EXPECT_EQ(entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(1000), result); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); + + // 2 errors. + entry.InformOfRequest(false); + result = entry.GetReleaseTime(); + EXPECT_EQ(entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(2000), result); + EXPECT_EQ(TimeDelta::FromMilliseconds(2000), entry.GetTimeUntilRelease()); + + // 3 errors. + entry.InformOfRequest(false); + result = entry.GetReleaseTime(); + EXPECT_EQ(entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(4000), result); + EXPECT_EQ(TimeDelta::FromMilliseconds(4000), entry.GetTimeUntilRelease()); + + // 6 errors (to check it doesn't pass maximum). + entry.InformOfRequest(false); + entry.InformOfRequest(false); + entry.InformOfRequest(false); + result = entry.GetReleaseTime(); + EXPECT_EQ( + entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(20000), result); +} + +TEST(BackoffEntryTest, ReleaseTimeCalculationAlwaysDelay) { + BackoffEntry::Policy always_delay_policy = base_policy; + always_delay_policy.always_use_initial_delay = true; + always_delay_policy.num_errors_to_ignore = 2; + + TestBackoffEntry entry(&always_delay_policy); + + // With previous requests, should return "now". + TimeTicks result = entry.GetReleaseTime(); + EXPECT_EQ(TimeDelta(), entry.GetTimeUntilRelease()); + + // 1 error. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); + + // 2 errors. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); + + // 3 errors, exponential backoff starts. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(2000), entry.GetTimeUntilRelease()); + + // 4 errors. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(4000), entry.GetTimeUntilRelease()); + + // 8 errors (to check it doesn't pass maximum). + entry.InformOfRequest(false); + entry.InformOfRequest(false); + entry.InformOfRequest(false); + entry.InformOfRequest(false); + result = entry.GetReleaseTime(); + EXPECT_EQ(TimeDelta::FromMilliseconds(20000), entry.GetTimeUntilRelease()); +} + +TEST(BackoffEntryTest, ReleaseTimeCalculationWithJitter) { + for (int i = 0; i < 10; ++i) { + BackoffEntry::Policy jittery_policy = base_policy; + jittery_policy.jitter_factor = 0.2; + + TestBackoffEntry entry(&jittery_policy); + + entry.InformOfRequest(false); + entry.InformOfRequest(false); + entry.InformOfRequest(false); + TimeTicks result = entry.GetReleaseTime(); + EXPECT_LE( + entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(3200), result); + EXPECT_GE( + entry.ImplGetTimeNow() + TimeDelta::FromMilliseconds(4000), result); + } +} + +TEST(BackoffEntryTest, FailureThenSuccess) { + TestBackoffEntry entry(&base_policy); + + // Failure count 1, establishes horizon. + entry.InformOfRequest(false); + TimeTicks release_time = entry.GetReleaseTime(); + EXPECT_EQ(TimeTicks() + TimeDelta::FromMilliseconds(1000), release_time); + + // Success, failure count 0, should not advance past + // the horizon that was already set. + entry.set_now(release_time - TimeDelta::FromMilliseconds(200)); + entry.InformOfRequest(true); + EXPECT_EQ(release_time, entry.GetReleaseTime()); + + // Failure, failure count 1. + entry.InformOfRequest(false); + EXPECT_EQ(release_time + TimeDelta::FromMilliseconds(800), + entry.GetReleaseTime()); +} + +TEST(BackoffEntryTest, FailureThenSuccessAlwaysDelay) { + BackoffEntry::Policy always_delay_policy = base_policy; + always_delay_policy.always_use_initial_delay = true; + always_delay_policy.num_errors_to_ignore = 1; + + TestBackoffEntry entry(&always_delay_policy); + + // Failure count 1. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); + + // Failure count 2. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(2000), entry.GetTimeUntilRelease()); + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds(2000)); + + // Success. We should go back to the original delay. + entry.InformOfRequest(true); + EXPECT_EQ(TimeDelta::FromMilliseconds(1000), entry.GetTimeUntilRelease()); + + // Failure count reaches 2 again. We should increase the delay once more. + entry.InformOfRequest(false); + EXPECT_EQ(TimeDelta::FromMilliseconds(2000), entry.GetTimeUntilRelease()); + entry.set_now(entry.GetReleaseTime() + TimeDelta::FromMilliseconds(2000)); +} + +TEST(BackoffEntryTest, RetainCustomHorizon) { + TestBackoffEntry custom(&base_policy); + TimeTicks custom_horizon = TimeTicks() + TimeDelta::FromDays(3); + custom.SetCustomReleaseTime(custom_horizon); + custom.InformOfRequest(false); + custom.InformOfRequest(true); + custom.set_now(TimeTicks() + TimeDelta::FromDays(2)); + custom.InformOfRequest(false); + custom.InformOfRequest(true); + EXPECT_EQ(custom_horizon, custom.GetReleaseTime()); + + // Now check that once we are at or past the custom horizon, + // we get normal behavior. + custom.set_now(TimeTicks() + TimeDelta::FromDays(3)); + custom.InformOfRequest(false); + EXPECT_EQ( + TimeTicks() + TimeDelta::FromDays(3) + TimeDelta::FromMilliseconds(1000), + custom.GetReleaseTime()); +} + +TEST(BackoffEntryTest, RetainCustomHorizonWhenInitialErrorsIgnored) { + // Regression test for a bug discovered during code review. + BackoffEntry::Policy lenient_policy = base_policy; + lenient_policy.num_errors_to_ignore = 1; + TestBackoffEntry custom(&lenient_policy); + TimeTicks custom_horizon = TimeTicks() + TimeDelta::FromDays(3); + custom.SetCustomReleaseTime(custom_horizon); + custom.InformOfRequest(false); // This must not reset the horizon. + EXPECT_EQ(custom_horizon, custom.GetReleaseTime()); +} + +TEST(BackoffEntryTest, OverflowProtection) { + BackoffEntry::Policy large_multiply_policy = base_policy; + large_multiply_policy.multiply_factor = 256; + TestBackoffEntry custom(&large_multiply_policy); + + // Trigger enough failures such that more than 11 bits of exponent are used + // to represent the exponential backoff intermediate values. Given a multiply + // factor of 256 (2^8), 129 iterations is enough: 2^(8*(129-1)) = 2^1024. + for (int i = 0; i < 129; ++i) { + custom.set_now(custom.ImplGetTimeNow() + custom.GetTimeUntilRelease()); + custom.InformOfRequest(false); + ASSERT_TRUE(custom.ShouldRejectRequest()); + } + + // Max delay should still be respected. + EXPECT_EQ(20000, custom.GetTimeUntilRelease().InMilliseconds()); +} + +} // namespace brillo diff --git a/brillo/bind_lambda.h b/brillo/bind_lambda.h deleted file mode 100644 index 50ac095..0000000 --- a/brillo/bind_lambda.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2014 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef LIBBRILLO_BRILLO_BIND_LAMBDA_H_ -#define LIBBRILLO_BRILLO_BIND_LAMBDA_H_ - -#include <base/bind.h> - -//////////////////////////////////////////////////////////////////////////////// -// This file is an extension to base/bind_internal.h and adds a RunnableAdapter -// class specialization that wraps a functor (including lambda objects), so -// they can be used in base::Callback/base::Bind constructs. -// By including this file you will gain the ability to write expressions like: -// base::Callback<int(int)> callback = base::Bind([](int value) { -// return value * value; -// }); -//////////////////////////////////////////////////////////////////////////////// -namespace base { -namespace internal { - -// LambdaAdapter is a helper class that specializes on different function call -// signatures and provides the RunType and Run() method required by -// RunnableAdapter<> class. -template <typename Lambda, typename Sig> -class LambdaAdapter; - -// R(...) -template <typename Lambda, typename R, typename... Args> -class LambdaAdapter<Lambda, R(Lambda::*)(Args... args)> { - public: - typedef R(RunType)(Args...); - explicit LambdaAdapter(Lambda lambda) : lambda_(lambda) {} - R Run(Args... args) { return lambda_(std::forward<Args>(args)...); } - - private: - Lambda lambda_; -}; - -// R(...) const -template <typename Lambda, typename R, typename... Args> -class LambdaAdapter<Lambda, R(Lambda::*)(Args... args) const> { - public: - typedef R(RunType)(Args...); - explicit LambdaAdapter(Lambda lambda) : lambda_(lambda) {} - R Run(Args... args) { return lambda_(std::forward<Args>(args)...); } - - private: - Lambda lambda_; -}; - -template <typename Lambda> -class RunnableAdapter - : public LambdaAdapter<Lambda, decltype(&Lambda::operator())> { - public: - explicit RunnableAdapter(Lambda lambda) - : LambdaAdapter<Lambda, decltype(&Lambda::operator())>(lambda) {} -}; - -} // namespace internal -} // namespace base - -#endif // LIBBRILLO_BRILLO_BIND_LAMBDA_H_ diff --git a/brillo/binder_watcher.cc b/brillo/binder_watcher.cc index 9752204..51b0f59 100644 --- a/brillo/binder_watcher.cc +++ b/brillo/binder_watcher.cc @@ -33,24 +33,11 @@ void OnBinderReadReady() { namespace brillo { -BinderWatcher::BinderWatcher(MessageLoop* message_loop) - : message_loop_(message_loop) {} +BinderWatcher::BinderWatcher() = default; -BinderWatcher::BinderWatcher() : message_loop_(nullptr) {} - -BinderWatcher::~BinderWatcher() { - if (task_id_ != MessageLoop::kTaskIdNull) - message_loop_->CancelTask(task_id_); -} +BinderWatcher::~BinderWatcher() = default; bool BinderWatcher::Init() { - if (!message_loop_) - message_loop_ = MessageLoop::current(); - if (!message_loop_) { - LOG(ERROR) << "Must initialize a brillo::MessageLoop to use BinderWatcher"; - return false; - } - int binder_fd = -1; ProcessState::self()->setThreadPoolMaxThreadCount(0); IPCThreadState::self()->disableBackgroundScheduling(true); @@ -66,13 +53,10 @@ bool BinderWatcher::Init() { } VLOG(1) << "Got binder FD " << binder_fd; - task_id_ = message_loop_->WatchFileDescriptor( - FROM_HERE, + watcher_ = base::FileDescriptorWatcher::WatchReadable( binder_fd, - MessageLoop::kWatchRead, - true /* persistent */, - base::Bind(&OnBinderReadReady)); - if (task_id_ == MessageLoop::kTaskIdNull) { + base::BindRepeating(&OnBinderReadReady)); + if (!watcher_) { LOG(ERROR) << "Failed to watch binder FD"; return false; } diff --git a/brillo/binder_watcher.h b/brillo/binder_watcher.h index ece999d..d7af50e 100644 --- a/brillo/binder_watcher.h +++ b/brillo/binder_watcher.h @@ -17,8 +17,10 @@ #ifndef LIBBRILLO_BRILLO_BINDER_WATCHER_H_ #define LIBBRILLO_BRILLO_BINDER_WATCHER_H_ +#include <memory> + +#include <base/files/file_descriptor_watcher_posix.h> #include <base/macros.h> -#include <brillo/message_loops/message_loop.h> namespace brillo { @@ -26,9 +28,6 @@ namespace brillo { // make the message loop watch for binder events and pass them to libbinder. class BinderWatcher final { public: - // Construct the BinderWatcher using the passed |message_loop| if not null or - // the current MessageLoop otherwise. - explicit BinderWatcher(MessageLoop* message_loop); BinderWatcher(); ~BinderWatcher(); @@ -36,8 +35,7 @@ class BinderWatcher final { bool Init(); private: - MessageLoop::TaskId task_id_{MessageLoop::kTaskIdNull}; - MessageLoop* message_loop_; + std::unique_ptr<base::FileDescriptorWatcher::Controller> watcher_; DISALLOW_COPY_AND_ASSIGN(BinderWatcher); }; diff --git a/brillo/blkdev_utils/device_mapper.cc b/brillo/blkdev_utils/device_mapper.cc new file mode 100644 index 0000000..726cd94 --- /dev/null +++ b/brillo/blkdev_utils/device_mapper.cc @@ -0,0 +1,240 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/device_mapper.h> + +#include <libdevmapper.h> +#include <algorithm> +#include <utility> + +#include <base/files/file_util.h> +#include <base/strings/string_number_conversions.h> +#include <base/strings/string_tokenizer.h> +#include <base/strings/stringprintf.h> +#include <brillo/blkdev_utils/device_mapper_task.h> +#include <brillo/secure_blob.h> + +namespace brillo { + +// Use a tokenizer to parse string data stored in SecureBlob. +// The tokenizer does not store internal state so it should be +// okay to use with SecureBlobs. +// DO NOT USE .toker() as that leaks contents of the SecureBlob. +using SecureBlobTokenizer = + base::StringTokenizerT<std::string, SecureBlob::const_iterator>; + +DevmapperTable::DevmapperTable(uint64_t start, + uint64_t size, + const std::string& type, + const SecureBlob& parameters) + : start_(start), size_(size), type_(type), parameters_(parameters) {} + +SecureBlob DevmapperTable::ToSecureBlob() { + SecureBlob table_blob(base::StringPrintf("%" PRIu64 " %" PRIu64 " %s ", + start_, size_, type_.c_str())); + + return SecureBlob::Combine(table_blob, parameters_); +} + +DevmapperTable DevmapperTable::CreateTableFromSecureBlob( + const SecureBlob& table) { + uint64_t start, size; + std::string type; + DevmapperTable invalid_table(0, 0, "", SecureBlob()); + + SecureBlobTokenizer tokenizer(table.begin(), table.end(), " "); + + // First parameter is start. + if (!tokenizer.GetNext() || + !base::StringToUint64( + std::string(tokenizer.token_begin(), tokenizer.token_end()), &start)) + return invalid_table; + + // Second parameter is size of the dm device. + if (!tokenizer.GetNext() || + !base::StringToUint64( + std::string(tokenizer.token_begin(), tokenizer.token_end()), &size)) + return invalid_table; + + // Third parameter is type of dm device. + if (!tokenizer.GetNext()) + return invalid_table; + + type = std::string(tokenizer.token_begin(), tokenizer.token_end()); + + // The remaining string is the parameters. + if (!tokenizer.GetNext()) + return invalid_table; + + // The remaining part is the parameters passed to the device. + SecureBlob target = SecureBlob(tokenizer.token_begin(), table.end()); + + return DevmapperTable(start, size, type, target); +} + +SecureBlob DevmapperTable::CryptGetKey() { + SecureBlobTokenizer tokenizer(parameters_.begin(), parameters_.end(), " "); + + // First field is the cipher. + if (!tokenizer.GetNext()) + return SecureBlob(); + + // The key is stored in the second field. + if (!tokenizer.GetNext()) + return SecureBlob(); + + SecureBlob hex_key(tokenizer.token_begin(), tokenizer.token_end()); + + SecureBlob key = SecureHexToSecureBlob(hex_key); + + if (key.empty()) { + LOG(ERROR) << "CryptExtractKey: HexStringToBytes failed"; + return SecureBlob(); + } + + return key; +} + +// In order to not leak the encryption key to non-SecureBlob managed memory, +// create the parameter blobs in three parts and combine. +SecureBlob DevmapperTable::CryptCreateParameters( + const std::string& cipher, + const SecureBlob& encryption_key, + const int iv_offset, + const base::FilePath& device, + int device_offset, + bool allow_discard) { + // First field is the cipher. + SecureBlob parameter_parts[3]; + + parameter_parts[0] = SecureBlob(cipher + " "); + parameter_parts[1] = SecureBlobToSecureHex(encryption_key); + parameter_parts[2] = SecureBlob(base::StringPrintf( + " %d %s %d%s", iv_offset, device.value().c_str(), device_offset, + (allow_discard ? " 1 allow_discards" : ""))); + + SecureBlob parameters; + for (auto param_part : parameter_parts) + parameters = SecureBlob::Combine(parameters, param_part); + + return parameters; +} + +std::unique_ptr<DevmapperTask> CreateDevmapperTask(int type) { + return std::make_unique<DevmapperTaskImpl>(type); +} + +DeviceMapper::DeviceMapper() { + dm_task_factory_ = base::Bind(&CreateDevmapperTask); +} + +DeviceMapper::DeviceMapper(const DevmapperTaskFactory& factory) + : dm_task_factory_(factory) {} + +bool DeviceMapper::Setup(const std::string& name, const DevmapperTable& table) { + auto task = dm_task_factory_.Run(DM_DEVICE_CREATE); + + if (!task->SetName(name)) { + LOG(ERROR) << "Setup: SetName failed."; + return false; + } + + if (!task->AddTarget(table.GetStart(), table.GetSize(), table.GetType(), + table.GetParameters())) { + LOG(ERROR) << "Setup: AddTarget failed"; + return false; + } + + if (!task->Run(true /* udev sync */)) { + LOG(ERROR) << "Setup: Run failed."; + return false; + } + + return true; +} + +bool DeviceMapper::Remove(const std::string& name) { + auto task = dm_task_factory_.Run(DM_DEVICE_REMOVE); + + if (!task->SetName(name)) { + LOG(ERROR) << "Remove: SetName failed."; + return false; + } + + if (!task->Run(true /* udev_sync */)) { + LOG(ERROR) << "Remove: Teardown failed."; + return false; + } + + return true; +} + +DevmapperTable DeviceMapper::GetTable(const std::string& name) { + auto task = dm_task_factory_.Run(DM_DEVICE_TABLE); + uint64_t start, size; + std::string type; + SecureBlob parameters; + + if (!task->SetName(name)) { + LOG(ERROR) << "GetTable: SetName failed."; + return DevmapperTable(0, 0, "", SecureBlob()); + } + + if (!task->Run()) { + LOG(ERROR) << "GetTable: Run failed."; + return DevmapperTable(0, 0, "", SecureBlob()); + } + + task->GetNextTarget(&start, &size, &type, ¶meters); + + return DevmapperTable(start, size, type, parameters); +} + +bool DeviceMapper::WipeTable(const std::string& name) { + auto size_task = dm_task_factory_.Run(DM_DEVICE_TABLE); + + if (!size_task->SetName(name)) { + LOG(ERROR) << "WipeTable: SetName failed."; + return false; + } + + if (!size_task->Run()) { + LOG(ERROR) << "WipeTable: RunTask failed."; + return false; + } + + // Arguments for fetching dm target. + bool ret = false; + uint64_t start = 0, size = 0, total_size = 0; + std::string type; + SecureBlob parameters; + + // Get maximum size of the device to be wiped. + do { + ret = size_task->GetNextTarget(&start, &size, &type, ¶meters); + total_size = std::max(start + size, total_size); + } while (ret); + + // Setup wipe task. + auto wipe_task = dm_task_factory_.Run(DM_DEVICE_RELOAD); + + if (!wipe_task->SetName(name)) { + LOG(ERROR) << "WipeTable: SetName failed."; + return false; + } + + if (!wipe_task->AddTarget(0, total_size, "error", SecureBlob())) { + LOG(ERROR) << "WipeTable: AddTarget failed."; + return false; + } + + if (!wipe_task->Run()) { + LOG(ERROR) << "WipeTable: RunTask failed."; + return false; + } + + return true; +} + +} // namespace brillo diff --git a/brillo/blkdev_utils/device_mapper.h b/brillo/blkdev_utils/device_mapper.h new file mode 100644 index 0000000..478b30a --- /dev/null +++ b/brillo/blkdev_utils/device_mapper.h @@ -0,0 +1,116 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_H_ +#define LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_H_ + +#include <functional> +#include <memory> +#include <string> + +#include <base/bind.h> +#include <base/callback.h> +#include <base/files/file_path.h> +#include <brillo/blkdev_utils/device_mapper_task.h> + +namespace brillo { + +// DevmapperTable manages device parameters. Contains helper +// functions to parse results from dmsetup. Since the table parameters +// may contain sensitive data eg. dm-crypt keys, we use SecureBlobs for +// the table parameters and as the table output format. + +class BRILLO_EXPORT DevmapperTable { + public: + // Create table from table parameters. + // Useful for setting up devices. + DevmapperTable(uint64_t start, + uint64_t size, + const std::string& type, + const SecureBlob& parameters); + + ~DevmapperTable() = default; + + // Returns the table as a SecureBlob. + SecureBlob ToSecureBlob(); + + // Getters for table components. + uint64_t GetStart() const { return start_; } + uint64_t GetSize() const { return size_; } + std::string GetType() const { return type_; } + SecureBlob GetParameters() const { return parameters_; } + + // Create table from table blob. + // Useful for parsing output from dmsetup. + // Using a static function to surface errors in parsing the blob. + static DevmapperTable CreateTableFromSecureBlob(const SecureBlob& table); + + // dm-crypt specific functions: + // ---------------------------- + // Extract key from (crypt) table. + SecureBlob CryptGetKey(); + + // Create crypt parameters . + // Useful for parsing output from dmsetup. + // Using a static function to surface errors in parsing the blob. + static SecureBlob CryptCreateParameters(const std::string& cipher, + const SecureBlob& encryption_key, + const int iv_offset, + const base::FilePath& device, + int device_offset, + bool allow_discard); + + private: + const uint64_t start_; + const uint64_t size_; + const std::string type_; + const SecureBlob parameters_; +}; + +// DevmapperTask is an abstract class so we wrap it in a unique_ptr. +using DevmapperTaskFactory = + base::Callback<std::unique_ptr<DevmapperTask>(int)>; + +// DeviceMapper handles the creation and removal of dm devices. +class BRILLO_EXPORT DeviceMapper { + public: + // Default constructor: sets up real devmapper devices. + DeviceMapper(); + + // Set a non-default dm task factory. + explicit DeviceMapper(const DevmapperTaskFactory& factory); + + // Default destructor. + ~DeviceMapper() = default; + + // Sets up device with table on /dev/mapper/<name>. + // Parameters + // name - Name of the devmapper device. + // table - Table for the devmapper device. + bool Setup(const std::string& name, const DevmapperTable& table); + + // Removes device. + // Parameters + // name - Name of the devmapper device. + bool Remove(const std::string& device); + + // Returns table for device. + // Parameters + // name - Name of the devmapper device. + DevmapperTable GetTable(const std::string& name); + + // Clears table for device. + // Parameters + // name - Name of the devmapper device. + bool WipeTable(const std::string& name); + + private: + // Devmapper task factory. + DevmapperTaskFactory dm_task_factory_; + DISALLOW_COPY_AND_ASSIGN(DeviceMapper); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_H_ diff --git a/brillo/blkdev_utils/device_mapper_fake.cc b/brillo/blkdev_utils/device_mapper_fake.cc new file mode 100644 index 0000000..8126960 --- /dev/null +++ b/brillo/blkdev_utils/device_mapper_fake.cc @@ -0,0 +1,112 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/device_mapper_fake.h> + +#include <memory> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +namespace brillo { +namespace fake { + +namespace { + +// Stub DmTask runs into a map for easy reference. +bool StubDmRunTask(DmTask* task, bool udev_sync) { + std::string dev_name = task->name; + std::string params; + int type = task->type; + static auto& dm_target_map_ = + *new std::unordered_map<std::string, std::vector<DmTarget>>(); + + switch (type) { + case DM_DEVICE_CREATE: + CHECK_EQ(udev_sync, true); + if (dm_target_map_.find(dev_name) != dm_target_map_.end()) + return false; + dm_target_map_.insert(std::make_pair(dev_name, task->targets)); + break; + case DM_DEVICE_REMOVE: + CHECK_EQ(udev_sync, true); + if (dm_target_map_.find(dev_name) == dm_target_map_.end()) + return false; + dm_target_map_.erase(dev_name); + break; + case DM_DEVICE_TABLE: + CHECK_EQ(udev_sync, false); + if (dm_target_map_.find(dev_name) == dm_target_map_.end()) + return false; + task->targets = dm_target_map_[dev_name]; + break; + case DM_DEVICE_RELOAD: + CHECK_EQ(udev_sync, false); + if (dm_target_map_.find(dev_name) == dm_target_map_.end()) + return false; + dm_target_map_.erase(dev_name); + dm_target_map_.insert(std::make_pair(dev_name, task->targets)); + break; + default: + return false; + } + return true; +} + +std::unique_ptr<DmTask> DmTaskCreate(int type) { + auto t = std::make_unique<DmTask>(); + t->type = type; + return t; +} + +} // namespace + +FakeDevmapperTask::FakeDevmapperTask(int type) : task_(DmTaskCreate(type)) {} + +bool FakeDevmapperTask::SetName(const std::string& name) { + task_->name = std::string(name); + return true; +} + +bool FakeDevmapperTask::AddTarget(uint64_t start, + uint64_t sectors, + const std::string& type, + const SecureBlob& parameters) { + DmTarget dmt; + dmt.start = start; + dmt.size = sectors; + dmt.type = type; + dmt.parameters = parameters; + task_->targets.push_back(dmt); + return true; +} + +bool FakeDevmapperTask::GetNextTarget(uint64_t* start, + uint64_t* sectors, + std::string* type, + SecureBlob* parameters) { + if (task_->targets.empty()) + return false; + + DmTarget dmt = task_->targets[0]; + *start = dmt.start; + *sectors = dmt.size; + *type = dmt.type; + *parameters = dmt.parameters; + task_->targets.erase(task_->targets.begin()); + + return !task_->targets.empty(); +} + +bool FakeDevmapperTask::Run(bool udev_sync) { + return StubDmRunTask(task_.get(), udev_sync); +} + +std::unique_ptr<DevmapperTask> CreateDevmapperTask(int type) { + return std::make_unique<FakeDevmapperTask>(type); +} + +} // namespace fake +} // namespace brillo diff --git a/brillo/blkdev_utils/device_mapper_fake.h b/brillo/blkdev_utils/device_mapper_fake.h new file mode 100644 index 0000000..bc4f28c --- /dev/null +++ b/brillo/blkdev_utils/device_mapper_fake.h @@ -0,0 +1,65 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_FAKE_H_ +#define LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_FAKE_H_ + +#include <memory> +#include <string> +#include <vector> + +#include <base/files/file_path.h> +#include <brillo/blkdev_utils/device_mapper.h> +#include <brillo/blkdev_utils/device_mapper_fake.h> +#include <brillo/blkdev_utils/device_mapper_task.h> +#include <brillo/secure_blob.h> + +namespace brillo { +namespace fake { + +// Fake implementation of dm_task primitives. +// ------------------------------------------ +// dm_task is an opaque type in libdevmapper so we +// define a minimal struct for DmTask and DmTarget +// to avoid linking in libdevmapper. +struct DmTarget { + uint64_t start; + uint64_t size; + std::string type; + SecureBlob parameters; +}; + +struct DmTask { + int type; + std::string name; + std::vector<DmTarget> targets; +}; + +// Fake task factory: creates fake tasks that +// stub task info into a map. +std::unique_ptr<DevmapperTask> CreateDevmapperTask(int type); + +class FakeDevmapperTask : public brillo::DevmapperTask { + public: + explicit FakeDevmapperTask(int type); + ~FakeDevmapperTask() override = default; + bool SetName(const std::string& name) override; + bool AddTarget(uint64_t start, + uint64_t sectors, + const std::string& target, + const SecureBlob& parameters) override; + bool GetNextTarget(uint64_t* start, + uint64_t* sectors, + std::string* target, + SecureBlob* parameters) override; + bool Run(bool udev_sync = true) override; + + private: + std::unique_ptr<DmTask> task_; +}; + +} // namespace fake +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_FAKE_H_ diff --git a/brillo/blkdev_utils/device_mapper_task.cc b/brillo/blkdev_utils/device_mapper_task.cc new file mode 100644 index 0000000..f2cbadd --- /dev/null +++ b/brillo/blkdev_utils/device_mapper_task.cc @@ -0,0 +1,95 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/device_mapper_task.h> + +#include <libdevmapper.h> +#include <string> +#include <utility> + +#include <brillo/blkdev_utils/device_mapper.h> + +namespace brillo { + +DevmapperTaskImpl::DevmapperTaskImpl(int type) + : task_(DmTaskPtr(dm_task_create(type), &dm_task_destroy)) {} + +bool DevmapperTaskImpl::SetName(const std::string& name) { + if (!task_ || !dm_task_set_name(task_.get(), name.c_str())) { + LOG(ERROR) << "SetName failed"; + return false; + } + return true; +} + +bool DevmapperTaskImpl::AddTarget(uint64_t start, + uint64_t length, + const std::string& type, + const SecureBlob& parameters) { + // Strings stored in SecureBlob don't end with '\0'. Unfortunately, + // this causes accesses beyond the allocated storage space if any + // of the functions expecting a c-string get passed a SecureBlob.data(). + // Temporarily, assign to a string. + // TODO(sarthakkukreti): Evaluate creation of a SecureCString to keep + // string data safe. + std::string parameters_str = parameters.to_string(); + if (!task_ || + !dm_task_add_target(task_.get(), start, length, type.c_str(), + parameters_str.c_str())) { + LOG(ERROR) << "AddTarget failed"; + return false; + } + // Clear the string. + parameters_str.clear(); + return true; +} + +bool DevmapperTaskImpl::GetNextTarget(uint64_t* start, + uint64_t* length, + std::string* type, + SecureBlob* parameters) { + if (!task_) { + LOG(ERROR) << "GetNextTarget: invalid task."; + return false; + } + + char *type_cstr, *parameters_cstr; + next_target_ = dm_get_next_target(task_.get(), next_target_, start, length, + &type_cstr, ¶meters_cstr); + + if (type_cstr) + *type = std::string(type_cstr); + if (parameters_cstr) { + SecureBlob parameters_blob(parameters_cstr); + memset(parameters_cstr, 0, parameters_blob.size()); + *parameters = std::move(parameters_blob); + } + + return (next_target_ != nullptr); +} + +bool DevmapperTaskImpl::Run(bool udev_sync) { + uint32_t cookie = 0; + + if (!task_) { + LOG(ERROR) << "Invalid task."; + return false; + } + + if (udev_sync && !dm_task_set_cookie(task_.get(), &cookie, 0)) { + LOG(ERROR) << "dm_task_set_cookie failed"; + return false; + } + + if (!dm_task_run(task_.get())) { + LOG(ERROR) << "dm_task_run failed"; + return false; + } + + // Make sure the node exists before continuing. + // TODO(sarthakkukreti): move to dm_udev_wait_immediate() on uprevving lvm2. + return udev_sync ? (dm_udev_wait(cookie) != 0) : true; +} + +} // namespace brillo diff --git a/brillo/blkdev_utils/device_mapper_task.h b/brillo/blkdev_utils/device_mapper_task.h new file mode 100644 index 0000000..f8e45d4 --- /dev/null +++ b/brillo/blkdev_utils/device_mapper_task.h @@ -0,0 +1,101 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_TASK_H_ +#define LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_TASK_H_ + +#include <libdevmapper.h> +#include <memory> +#include <string> + +#include <brillo/secure_blob.h> + +namespace brillo { + +using DmTaskPtr = std::unique_ptr<dm_task, void (*)(dm_task*)>; + +// Abstract class to manage DM devices. +// This class implements the bare minimum set of functions +// required to create/remove DM devices. DevmapperTask is the equivalent +// of a command to the device mapper to set/get targets associated with a +// logical DM device, but omits, for now, finer-grained commands. +// A target represents a segment of a DM device. +// +// The abstract class is strictly based on the dm_task_* functions +// from libdevmapper, but the interface provides sufficient flexibility +// for other implementations (eg. invoking dmsetup) or testing facades. +// +// The task type enum is defined in libdevmapper.h: for simplicity, the same +// enum types are reused in fake implementations of DevmapperTask. +// The following task types have been tested with DeviceMapper functions: +// - DM_DEVICE_CREATE: used in DeviceMapper::Setup. +// - DM_DEVICE_REMOVE: used in DeviceMapper::Remove. +// - DM_DEVICE_TABLE: used in DeviceMapper::GetTable and +// DeviceMapper::WipeTable. +// - DM_DEVICE_RELOAD: used in DeviceMapper::WipeTable. +class DevmapperTask { + public: + virtual ~DevmapperTask() = default; + // Sets device name for the command. + virtual bool SetName(const std::string& name) = 0; + + // Adds a target to the command. Should be followed by a Run(); + // Parameters: + // start: start of target in device. + // sectors: number of sectors in the target. + // type: type of the target. + // parameters: target parameters. + virtual bool AddTarget(uint64_t start, + uint64_t sectors, + const std::string& type, + const SecureBlob& parameters) = 0; + // Gets the next target from the command. + // Returns true while another target exists. + // If no target exist for the device, GetNextTarget sets all + // parameters to 0 and returns false. + // + // Parameters: + // start: start of target in device. + // sectors: number of sectors in the target. + // type: type of the target. + // parameters: target parameters. + virtual bool GetNextTarget(uint64_t* start, + uint64_t* sectors, + std::string* type, + SecureBlob* parameters) = 0; + // Run the task. + // Returns true if the task succeeded. + // + // Parameters: + // udev_sync: Enable/Disable udev_synchronization. Defaults to false. + // Enable only for tasks that create/remove/rename files to + // prevent both udevd and libdevmapper from attempting to + // add or remove files. + virtual bool Run(bool udev_sync = false) = 0; +}; + +// Libdevmapper implementation for DevmapperTask. +class DevmapperTaskImpl : public DevmapperTask { + public: + explicit DevmapperTaskImpl(int type); + ~DevmapperTaskImpl() override = default; + bool SetName(const std::string& name) override; + bool AddTarget(uint64_t start, + uint64_t sectors, + const std::string& target, + const SecureBlob& parameters) override; + bool GetNextTarget(uint64_t* start, + uint64_t* sectors, + std::string* target, + SecureBlob* parameters) override; + bool Run(bool udev_sync = true) override; + + private: + DmTaskPtr task_; + void* next_target_ = nullptr; +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_BLKDEV_UTILS_DEVICE_MAPPER_TASK_H_ diff --git a/brillo/blkdev_utils/device_mapper_test.cc b/brillo/blkdev_utils/device_mapper_test.cc new file mode 100644 index 0000000..ab19092 --- /dev/null +++ b/brillo/blkdev_utils/device_mapper_test.cc @@ -0,0 +1,143 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <libdevmapper.h> + +#include <base/files/file_util.h> +#include <base/memory/ptr_util.h> +#include <base/strings/string_number_conversions.h> +#include <base/strings/string_split.h> +#include <brillo/blkdev_utils/device_mapper_fake.h> +#include <gtest/gtest.h> + +namespace brillo { + +TEST(DevmapperTableTest, CreateTableFromSecureBlobTest) { + SecureBlob crypt_table_str("0 100 crypt"); + + DevmapperTable dm_table = + DevmapperTable::CreateTableFromSecureBlob(crypt_table_str); + EXPECT_EQ(DevmapperTable(0, 0, "", SecureBlob()).ToSecureBlob(), + dm_table.ToSecureBlob()); +} + +TEST(DevmapperTableTest, CryptCreateParametersTest) { + base::FilePath device("/some/random/filepath"); + + SecureBlob secret; + SecureBlob::HexStringToSecureBlob("0123456789ABCDEF", &secret); + + SecureBlob crypt_parameters = DevmapperTable::CryptCreateParameters( + "aes-cbc-essiv:sha256", secret, 0, device, 0, true); + + DevmapperTable crypt_table(0, 100, "crypt", crypt_parameters); + + SecureBlob crypt_table_str( + "0 100 crypt aes-cbc-essiv:sha256 " + "0123456789ABCDEF 0 /some/random/filepath 0 1 " + "allow_discards"); + + EXPECT_EQ(crypt_table.ToSecureBlob().to_string(), + crypt_table_str.to_string()); +} + +TEST(DevmapperTableTest, CryptCreateTableFromSecureBlobTest) { + base::FilePath device("/some/random/filepath"); + + SecureBlob secret; + SecureBlob::HexStringToSecureBlob("0123456789ABCDEF", &secret); + + SecureBlob crypt_parameters = DevmapperTable::CryptCreateParameters( + "aes-cbc-essiv:sha256", secret, 0, device, 0, true); + + DevmapperTable crypt_table(0, 100, "crypt", crypt_parameters); + + SecureBlob crypt_table_str( + "0 100 crypt aes-cbc-essiv:sha256 " + "0123456789ABCDEF 0 /some/random/filepath 0 1 " + "allow_discards"); + + DevmapperTable parsed_blob_table = + DevmapperTable::CreateTableFromSecureBlob(crypt_table_str); + + EXPECT_EQ(crypt_table.ToSecureBlob(), parsed_blob_table.ToSecureBlob()); +} + +TEST(DevmapperTableTest, CryptGetKeyTest) { + SecureBlob secret; + SecureBlob::HexStringToSecureBlob("0123456789ABCDEF", &secret); + SecureBlob crypt_table_str( + "0 100 crypt aes-cbc-essiv:sha256 " + "0123456789ABCDEF 0 /some/random/filepath 0 1 " + "allow_discards"); + + DevmapperTable dm_table = + DevmapperTable::CreateTableFromSecureBlob(crypt_table_str); + + EXPECT_EQ(secret, dm_table.CryptGetKey()); +} + +TEST(DevmapperTableTest, MalformedCryptTableTest) { + SecureBlob secret; + SecureBlob::HexStringToSecureBlob("0123456789ABCDEF", &secret); + // Pass malformed crypt table string. + SecureBlob crypt_table_str( + "0 100 crypt ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ" + "ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ" + "ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ" + "ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ"); + + DevmapperTable dm_table = + DevmapperTable::CreateTableFromSecureBlob(crypt_table_str); + + EXPECT_EQ(SecureBlob(), dm_table.CryptGetKey()); +} + +TEST(DevmapperTableTest, GetterTest) { + SecureBlob verity_table( + "0 40 verity payload=/dev/loop6 hashtree=/dev/loop6 " + "hashstart=40 alg=sha256 root_hexdigest=" + "01234567 " + "salt=89ABCDEF " + "error_behavior=eio"); + + DevmapperTable dm_table = + DevmapperTable::CreateTableFromSecureBlob(verity_table); + + EXPECT_EQ(dm_table.GetStart(), 0); + EXPECT_EQ(dm_table.GetSize(), 40); + EXPECT_EQ(dm_table.GetType(), "verity"); + EXPECT_EQ(dm_table.GetParameters(), + SecureBlob("payload=/dev/loop6 hashtree=/dev/loop6 " + "hashstart=40 alg=sha256 root_hexdigest=01234567 " + "salt=89ABCDEF error_behavior=eio")); +} + +TEST(DevmapperTest, FakeTaskConformance) { + SecureBlob secret; + SecureBlob::HexStringToSecureBlob("0123456789ABCDEF", &secret); + SecureBlob crypt_table_str( + "0 100 crypt aes-cbc-essiv:sha256 " + "0123456789ABCDEF 0 /some/random/filepath 0 1 " + "allow_discards"); + + DevmapperTable dm_table = + DevmapperTable::CreateTableFromSecureBlob(crypt_table_str); + + EXPECT_EQ(secret, dm_table.CryptGetKey()); + DeviceMapper dm(base::Bind(&fake::CreateDevmapperTask)); + + // Add device. + EXPECT_TRUE(dm.Setup("abcd", dm_table)); + EXPECT_FALSE(dm.Setup("abcd", dm_table)); + DevmapperTable table = dm.GetTable("abcd"); + // Expect tables to be the same. + EXPECT_EQ(table.ToSecureBlob(), dm_table.ToSecureBlob()); + // Expect key to match. + EXPECT_EQ(table.CryptGetKey(), secret); + EXPECT_TRUE(dm.Remove("abcd")); + EXPECT_FALSE(dm.Remove("abcd")); +} + +} // namespace brillo diff --git a/brillo/blkdev_utils/loop_device.cc b/brillo/blkdev_utils/loop_device.cc new file mode 100644 index 0000000..2b2219d --- /dev/null +++ b/brillo/blkdev_utils/loop_device.cc @@ -0,0 +1,271 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/loop_device.h> + +#include <fcntl.h> +#include <linux/major.h> +#include <sys/ioctl.h> +#include <sys/types.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include <base/files/file_enumerator.h> +#include <base/files/file_util.h> +#include <base/files/scoped_file.h> +#include <base/posix/eintr_wrapper.h> +#include <base/strings/string_number_conversions.h> +#include <base/strings/string_split.h> +#include <base/strings/string_util.h> +#include <base/strings/stringprintf.h> + +namespace brillo { + +namespace { + +constexpr char kLoopControl[] = "/dev/loop-control"; +constexpr char kSysBlockPath[] = "/sys/block"; +// File containing device id in /sys/block/loopX/. +constexpr char kDeviceIdPath[] = "dev"; +constexpr char kLoopBackingFile[] = "loop/backing_file"; +constexpr int kLoopDeviceIoctlFlags = O_RDWR | O_NOFOLLOW | O_CLOEXEC; +constexpr int kLoopControlIoctlFlags = O_RDONLY | O_NOFOLLOW | O_CLOEXEC; + +// ioctl runner for LoopDevice and LoopDeviceManager +int LoopDeviceIoctl(const base::FilePath& device, + int type, + uint64_t arg, + int open_flag) { + base::ScopedFD device_fd( + HANDLE_EINTR(open(device.value().c_str(), open_flag))); + + if (!device_fd.is_valid()) { + PLOG(ERROR) << "Unable to open loop device"; + return -EINVAL; + } + + int rc = ioctl(device_fd.get(), type, arg); + + if (rc < 0) + PLOG(ERROR) << "ioctl failed."; + + return rc; +} + +// Parse the device number for a valid /sys/block/loopX path +// or symlink to such a path. +// Returns -1 if invalid. +int GetDeviceNumber(const base::FilePath& sys_block_loopdev_path) { + std::string device_string; + int device_number = -1; + + base::FilePath device_file = sys_block_loopdev_path.Append(kDeviceIdPath); + + if (!base::ReadFileToString(device_file, &device_string)) + return -1; + + std::vector<std::string> device_ids = base::SplitString( + device_string, ":", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY); + + if (device_ids.size() != 2 || + device_ids[0] != base::NumberToString(LOOP_MAJOR)) + return -1; + + base::StringToInt(device_ids[1], &device_number); + return device_number; +} + +// For a validated loop device path, return the backing file path. +// Note that a pre-populated loop device path would return an empty +// backing file. +base::FilePath GetBackingFile(const base::FilePath& loopdev_path) { + // Backing file contains path to associated source for loop devices. + base::FilePath backing_file = loopdev_path.Append(kLoopBackingFile); + std::string backing_file_content; + // If the backing file doesn't exist, it's not an attached loop device. + if (!base::ReadFileToString(backing_file, &backing_file_content)) + return base::FilePath(); + base::FilePath backing_file_path( + base::TrimWhitespaceASCII(backing_file_content, base::TRIM_ALL)); + + return backing_file_path; +} + +base::FilePath CreateDevicePath(int device_number) { + return base::FilePath(base::StringPrintf("/dev/loop%d", device_number)); +} + +} // namespace + +LoopDevice::LoopDevice(int device_number, + const base::FilePath& backing_file, + const LoopIoctl& ioctl_runner) + : device_number_(device_number), + backing_file_(backing_file), + loop_ioctl_(ioctl_runner) {} + +bool LoopDevice::SetStatus(struct loop_info64 info) { + if (loop_ioctl_.Run(GetDevicePath(), LOOP_SET_STATUS64, + reinterpret_cast<uint64_t>(&info), + kLoopDeviceIoctlFlags) < 0) { + LOG(ERROR) << "ioctl(LOOP_SET_STATUS64) failed"; + return false; + } + return true; +} + +bool LoopDevice::GetStatus(struct loop_info64* info) { + if (loop_ioctl_.Run(GetDevicePath(), LOOP_GET_STATUS64, + reinterpret_cast<uint64_t>(info), + kLoopDeviceIoctlFlags) < 0) { + LOG(ERROR) << "ioctl(LOOP_GET_STATUS64) failed"; + return false; + } + return true; +} + +bool LoopDevice::SetName(const std::string& name) { + struct loop_info64 info; + + memset(&info, 0, sizeof(info)); + strncpy(reinterpret_cast<char*>(info.lo_file_name), name.c_str(), + LO_NAME_SIZE); + return SetStatus(info); +} + +bool LoopDevice::Detach() { + if (loop_ioctl_.Run(GetDevicePath(), LOOP_CLR_FD, 0, kLoopDeviceIoctlFlags) != + 0) { + LOG(ERROR) << "ioctl(LOOP_CLR_FD) failed"; + return false; + } + + return true; +} + +base::FilePath LoopDevice::GetDevicePath() { + return CreateDevicePath(device_number_); +} + +bool LoopDevice::IsValid() { + return device_number_ >= 0; +} + +LoopDeviceManager::LoopDeviceManager() + : loop_ioctl_(base::Bind(&LoopDeviceIoctl)) {} + +LoopDeviceManager::LoopDeviceManager(LoopIoctl ioctl_runner) + : loop_ioctl_(ioctl_runner) {} + +std::unique_ptr<LoopDevice> LoopDeviceManager::AttachDeviceToFile( + const base::FilePath& backing_file) { + int device_number = -1; + while (true) { + device_number = + loop_ioctl_.Run(base::FilePath(kLoopControl), LOOP_CTL_GET_FREE, 0, + kLoopControlIoctlFlags); + + if (device_number < 0) { + LOG(ERROR) << "ioctl(LOOP_CTL_GET_FREE) failed"; + return CreateLoopDevice(-1, base::FilePath()); + } + + base::ScopedFD backing_file_fd( + HANDLE_EINTR(open(backing_file.value().c_str(), O_RDWR))); + + if (!backing_file_fd.is_valid()) { + LOG(ERROR) << "Failed to open backing file."; + return CreateLoopDevice(-1, base::FilePath()); + } + + base::FilePath device_path = CreateDevicePath(device_number); + + if (loop_ioctl_.Run(device_path, LOOP_SET_FD, backing_file_fd.get(), + kLoopDeviceIoctlFlags) == 0) + break; + + if (errno != EBUSY) { + LOG(ERROR) << "ioctl(LOOP_SET_FD) failed"; + return CreateLoopDevice(-1, base::FilePath()); + } + } + // All steps of setting up the loop device succeeded. + return CreateLoopDevice(device_number, backing_file); +} + +std::vector<std::unique_ptr<LoopDevice>> +LoopDeviceManager::GetAttachedDevices() { + return SearchLoopDevicePaths(); +} + +std::unique_ptr<LoopDevice> LoopDeviceManager::GetAttachedDeviceByNumber( + int device_number) { + auto devices = SearchLoopDevicePaths(device_number); + + if (devices.empty()) + return CreateLoopDevice(-1, base::FilePath()); + + return std::move(devices[0]); +} + +std::unique_ptr<LoopDevice> LoopDeviceManager::GetAttachedDeviceByName( + const std::string& name) { + std::vector<std::unique_ptr<LoopDevice>> devices = GetAttachedDevices(); + + for (auto& attached_device : devices) { + struct loop_info64 device_info; + + if (!attached_device->GetStatus(&device_info)) { + LOG(ERROR) << "GetStatus failed"; + continue; + } + + if (strcmp(reinterpret_cast<char*>(device_info.lo_file_name), + name.c_str()) == 0) + return std::move(attached_device); + } + + return CreateLoopDevice(-1, base::FilePath()); +} + +// virtual +std::vector<std::unique_ptr<LoopDevice>> +LoopDeviceManager::SearchLoopDevicePaths(int device_number) { + std::vector<std::unique_ptr<LoopDevice>> devices; + base::FilePath rootdir(kSysBlockPath); + + if (device_number != -1) { + auto loopdev_path = + rootdir.Append(base::StringPrintf("loop%d", device_number)); + if (base::PathExists(loopdev_path)) + devices.push_back( + CreateLoopDevice(device_number, GetBackingFile(loopdev_path))); + } else { + // Read /sys/block to discover all loop devices. + base::FileEnumerator loopdev_enum( + rootdir, false /*recursive*/, + base::FileEnumerator::FILES | base::FileEnumerator::SHOW_SYM_LINKS, + "loop*"); + + for (auto loopdev = loopdev_enum.Next(); !loopdev.empty(); + loopdev = loopdev_enum.Next()) { + int dev_number = GetDeviceNumber(loopdev); + if (dev_number != -1) + devices.push_back( + CreateLoopDevice(dev_number, GetBackingFile(loopdev))); + } + } + return devices; +} + +std::unique_ptr<LoopDevice> LoopDeviceManager::CreateLoopDevice( + int device_number, const base::FilePath& backing_file) { + return std::make_unique<LoopDevice>(device_number, backing_file, loop_ioctl_); +} + +} // namespace brillo diff --git a/brillo/blkdev_utils/loop_device.h b/brillo/blkdev_utils/loop_device.h new file mode 100644 index 0000000..aba19cc --- /dev/null +++ b/brillo/blkdev_utils/loop_device.h @@ -0,0 +1,117 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_H_ +#define LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_H_ + +#include <linux/loop.h> +#include <memory> +#include <string> +#include <vector> + +#include <base/bind.h> +#include <base/callback.h> +#include <base/files/file_path.h> +#include <brillo/secure_blob.h> + +namespace brillo { + +// Forward declaration used by LoopDevice. +class LoopDeviceManager; + +using LoopIoctl = + base::Callback<int(const base::FilePath&, int, uint64_t, int)>; + +// LoopDevice provides an interface to attached loop devices. +// In order to simplify handling of loop devices, there +// is no inherent modifiable state associated within objects: +// the device number and backing file are consts. +// The intent here is for no class to create a LoopDevice +// directly; instead use LoopDeviceManager to get devices. +class BRILLO_EXPORT LoopDevice { + public: + // Create a loop device with a ioctl runner. + // Parameters + // device_number - loop device number. + // backing_file - backing file for the device. + // ioctl_runner - function to run loop ioctls. + LoopDevice(int device_number, + const base::FilePath& backing_file, + const LoopIoctl& ioctl_runner); + ~LoopDevice() = default; + + // Set device status. + // Parameters + // info - struct containing status. + bool SetStatus(struct loop_info64 info); + // Get device status. + // Parameters + // info - struct to populate. + bool GetStatus(struct loop_info64* info); + // Set device name. + // Parameters + // name - device name + bool SetName(const std::string& name); + // Detach device. + bool Detach(); + // Check if device is valid; + bool IsValid(); + + // Getters for device parameters. + base::FilePath GetBackingFilePath() { return backing_file_; } + base::FilePath GetDevicePath(); + + private: + const int device_number_; + const base::FilePath backing_file_; + // Ioctl runner. + LoopIoctl loop_ioctl_; +}; + +// Loop Device Manager handles requests for creating or fetching +// existing loop devices. If creation/fetch fails, the loop device +// manager returns nullptr. +class BRILLO_EXPORT LoopDeviceManager { + public: + LoopDeviceManager(); + // Create a loop device manager with a non-default ioctl runner. + // Parameters + // ioctl_runner - base::Callback to run ioctls. + explicit LoopDeviceManager(LoopIoctl ioctl_runner); + virtual ~LoopDeviceManager() = default; + + // Allocates a loop device and attaches it to a backing file. + // Parameters + // backing_file - file to attach device to. + virtual std::unique_ptr<LoopDevice> AttachDeviceToFile( + const base::FilePath& backing_file); + + // Fetches all attached loop devices. + std::vector<std::unique_ptr<LoopDevice>> GetAttachedDevices(); + + // Fetches a loop device by device number. + std::unique_ptr<LoopDevice> GetAttachedDeviceByNumber(int device_number); + + // Fetches a device number by name. + std::unique_ptr<LoopDevice> GetAttachedDeviceByName(const std::string& name); + + private: + // Search for loop devices by device number; if no device number is given, + // default to searaching and returning all loop devices. + virtual std::vector<std::unique_ptr<LoopDevice>> SearchLoopDevicePaths( + int device_number = -1); + // Create loop device with current ioctl runner. + // Parameters + // device_number - device number. + // backing_file - path to backing file. + std::unique_ptr<LoopDevice> CreateLoopDevice( + int device_number, const base::FilePath& backing_file); + + LoopIoctl loop_ioctl_; + DISALLOW_COPY_AND_ASSIGN(LoopDeviceManager); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_H_ diff --git a/brillo/blkdev_utils/loop_device_fake.cc b/brillo/blkdev_utils/loop_device_fake.cc new file mode 100644 index 0000000..a181aad --- /dev/null +++ b/brillo/blkdev_utils/loop_device_fake.cc @@ -0,0 +1,148 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/loop_device_fake.h> + +#include <linux/loop.h> +#include <memory> +#include <string> +#include <vector> + +#include <base/strings/string_number_conversions.h> +#include <base/strings/string_split.h> +#include <base/strings/string_util.h> +#include <base/strings/stringprintf.h> +#include <brillo/blkdev_utils/loop_device.h> + +// Not a loop ioctl: we only use this to get the backing file from +// the stubbed function. All loop device ioctls start with 0x4c. +#define LOOP_GET_DEV 0x4cff + +namespace brillo { +namespace fake { + +namespace { + +int ParseLoopDeviceNumber(const base::FilePath& device_path) { + int device_number; + std::string path_string = device_path.value(); + return base::StartsWith(path_string, "/dev/loop", + base::CompareCase::SENSITIVE) && + base::StringToInt(path_string.substr(9), &device_number) + ? device_number + : -1; +} + +base::FilePath GetLoopDevicePath(int device_number) { + return base::FilePath(base::StringPrintf("/dev/loop%d", device_number)); +} + +int StubIoctlRunner(const base::FilePath& path, + int type, + uint64_t arg, + int flag) { + int device_number = ParseLoopDeviceNumber(path); + struct loop_info64* info; + struct LoopDev* device; + static std::vector<struct LoopDev>& loop_device_vector = + *new std::vector<struct LoopDev>(); + + switch (type) { + case LOOP_GET_STATUS64: + if (loop_device_vector.size() <= device_number || + loop_device_vector[device_number].valid == false) + return -1; + info = reinterpret_cast<struct loop_info64*>(arg); + memcpy(info, &loop_device_vector[device_number].info, + sizeof(struct loop_info64)); + return 0; + case LOOP_SET_STATUS64: + if (loop_device_vector.size() <= device_number || + loop_device_vector[device_number].valid == false) + return -1; + info = reinterpret_cast<struct loop_info64*>(arg); + memcpy(&loop_device_vector[device_number].info, info, + sizeof(struct loop_info64)); + return 0; + case LOOP_CLR_FD: + if (loop_device_vector.size() <= device_number || + loop_device_vector[device_number].valid == false) + return -1; + loop_device_vector[device_number].valid = false; + return 0; + case LOOP_CTL_GET_FREE: + device_number = loop_device_vector.size(); + loop_device_vector.push_back({true, base::FilePath(), {0}}); + return device_number; + // Instead of passing the fd here, we pass the FilePath of the backing + // file. + case LOOP_SET_FD: + if (loop_device_vector.size() <= device_number) + return -1; + loop_device_vector[device_number].backing_file = + *reinterpret_cast<const base::FilePath*>(arg); + return 0; + // Not a loop ioctl; Only used for conveniently checking the + // validity of the loop devices. + case LOOP_GET_DEV: + if (device_number >= loop_device_vector.size()) + return -1; + device = reinterpret_cast<struct LoopDev*>(arg); + device->valid = loop_device_vector[device_number].valid; + device->backing_file = loop_device_vector[device_number].backing_file; + memset(&(device->info), 0, sizeof(struct loop_info64)); + return 0; + default: + return -1; + } +} + +} // namespace + +FakeLoopDeviceManager::FakeLoopDeviceManager() + : LoopDeviceManager(base::Bind(&StubIoctlRunner)) {} + +std::unique_ptr<LoopDevice> FakeLoopDeviceManager::AttachDeviceToFile( + const base::FilePath& backing_file) { + int device_number = StubIoctlRunner(base::FilePath("/dev/loop-control"), + LOOP_CTL_GET_FREE, 0, 0); + + if (StubIoctlRunner(GetLoopDevicePath(device_number), LOOP_SET_FD, + reinterpret_cast<uint64_t>(&backing_file), 0) < 0) + return std::make_unique<LoopDevice>(-1, base::FilePath(), + base::Bind(&StubIoctlRunner)); + + return std::make_unique<LoopDevice>(device_number, backing_file, + base::Bind(&StubIoctlRunner)); +} + +std::vector<std::unique_ptr<LoopDevice>> +FakeLoopDeviceManager::SearchLoopDevicePaths(int device_number) { + std::vector<std::unique_ptr<LoopDevice>> devices; + struct LoopDev device; + + if (device_number != -1) { + if (StubIoctlRunner(GetLoopDevicePath(device_number), LOOP_GET_DEV, + reinterpret_cast<uint64_t>(&device), 0) < 0) + return devices; + + if (device.valid) + devices.push_back(std::make_unique<LoopDevice>( + device_number, device.backing_file, base::Bind(&StubIoctlRunner))); + return devices; + } + + int i = 0; + while (StubIoctlRunner(GetLoopDevicePath(i), LOOP_GET_DEV, + reinterpret_cast<uint64_t>(&device), 0) == 0) { + if (device.valid) + devices.push_back(std::make_unique<LoopDevice>( + i, device.backing_file, base::Bind(&StubIoctlRunner))); + i++; + } + return devices; +} + +} // namespace fake +} // namespace brillo diff --git a/brillo/blkdev_utils/loop_device_fake.h b/brillo/blkdev_utils/loop_device_fake.h new file mode 100644 index 0000000..751aa96 --- /dev/null +++ b/brillo/blkdev_utils/loop_device_fake.h @@ -0,0 +1,37 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_FAKE_H_ +#define LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_FAKE_H_ + +#include <memory> +#include <vector> + +#include <brillo/blkdev_utils/loop_device.h> + +namespace brillo { +namespace fake { + +struct LoopDev { + bool valid; + base::FilePath backing_file; + struct loop_info64 info; +}; + +class BRILLO_EXPORT FakeLoopDeviceManager : public brillo::LoopDeviceManager { + public: + FakeLoopDeviceManager(); + ~FakeLoopDeviceManager() override = default; + std::unique_ptr<LoopDevice> AttachDeviceToFile( + const base::FilePath& backing_file) override; + + private: + std::vector<std::unique_ptr<LoopDevice>> SearchLoopDevicePaths( + int device_number = -1) override; +}; + +} // namespace fake +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_BLKDEV_UTILS_LOOP_DEVICE_FAKE_H_ diff --git a/brillo/blkdev_utils/loop_device_test.cc b/brillo/blkdev_utils/loop_device_test.cc new file mode 100644 index 0000000..920ad68 --- /dev/null +++ b/brillo/blkdev_utils/loop_device_test.cc @@ -0,0 +1,57 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/blkdev_utils/loop_device_fake.h> + +#include <base/files/file_util.h> +#include <gtest/gtest.h> + +namespace brillo { + +TEST(LoopDeviceTest, GeneralTest) { + base::FilePath loop_backing_file; + base::CreateTemporaryFile(&loop_backing_file); + fake::FakeLoopDeviceManager loop_manager; + + // Create a new device + std::unique_ptr<LoopDevice> device = + loop_manager.AttachDeviceToFile(loop_backing_file); + std::unique_ptr<LoopDevice> device1 = + loop_manager.AttachDeviceToFile(loop_backing_file); + std::unique_ptr<LoopDevice> device2 = + loop_manager.AttachDeviceToFile(loop_backing_file); + + EXPECT_TRUE(device->IsValid()); + EXPECT_TRUE(device1->IsValid()); + EXPECT_TRUE(device2->IsValid()); + + std::vector<std::unique_ptr<LoopDevice>> attached_devices = + loop_manager.GetAttachedDevices(); + + // Expect 3 devices + EXPECT_EQ(attached_devices.size(), 3); + + device2->SetName("Loopy"); + + std::unique_ptr<LoopDevice> device1_copy = + loop_manager.GetAttachedDeviceByNumber(1); + EXPECT_TRUE(device1_copy->IsValid()); + EXPECT_EQ(device1->GetDevicePath(), device1_copy->GetDevicePath()); + EXPECT_EQ(device1->GetBackingFilePath(), device1_copy->GetBackingFilePath()); + + std::unique_ptr<LoopDevice> device2_copy = + loop_manager.GetAttachedDeviceByName("Loopy"); + EXPECT_TRUE(device2_copy->IsValid()); + EXPECT_EQ(device2->GetDevicePath(), device2_copy->GetDevicePath()); + EXPECT_EQ(device2->GetBackingFilePath(), device2_copy->GetBackingFilePath()); + + // Check double detach + EXPECT_TRUE(device->Detach()); + EXPECT_TRUE(device1->Detach()); + EXPECT_FALSE(device1_copy->Detach()); + EXPECT_TRUE(device2->Detach()); + EXPECT_FALSE(device2_copy->Detach()); +} + +} // namespace brillo diff --git a/brillo/cryptohome.cc b/brillo/cryptohome.cc index 88e4739..a82356e 100644 --- a/brillo/cryptohome.cc +++ b/brillo/cryptohome.cc @@ -41,7 +41,7 @@ static char g_system_salt_path[PATH_MAX] = "/home/.shadow/salt"; static std::string* salt = nullptr; -static bool EnsureSystemSaltIsLoaded() { +bool EnsureSystemSaltIsLoaded() { if (salt && !salt->empty()) return true; FilePath salt_path(g_system_salt_path); diff --git a/brillo/cryptohome.h b/brillo/cryptohome.h index 798d3a0..a9d5927 100644 --- a/brillo/cryptohome.h +++ b/brillo/cryptohome.h @@ -74,6 +74,9 @@ BRILLO_EXPORT void SetSystemSalt(std::string* salt); // Returns the system salt. BRILLO_EXPORT std::string* GetSystemSalt(); +// Ensures the system salt is loaded in the memory. +BRILLO_EXPORT bool EnsureSystemSaltIsLoaded(); + } // namespace home } // namespace cryptohome } // namespace brillo diff --git a/brillo/daemons/daemon.cc b/brillo/daemons/daemon.cc index 1b3d6d2..82de826 100644 --- a/brillo/daemons/daemon.cc +++ b/brillo/daemons/daemon.cc @@ -4,6 +4,7 @@ #include <brillo/daemons/daemon.h> +#include <signal.h> #include <sysexits.h> #include <time.h> @@ -14,7 +15,7 @@ namespace brillo { -Daemon::Daemon() : exit_code_{EX_OK} { +Daemon::Daemon() : exit_code_{EX_OK}, exiting_(false) { message_loop_.SetAsCurrent(); } @@ -27,7 +28,7 @@ int Daemon::Run() { return exit_code; message_loop_.PostTask( - base::Bind(&Daemon::OnEventLoopStartedTask, base::Unretained(this))); + base::BindOnce(&Daemon::OnEventLoopStartedTask, base::Unretained(this))); message_loop_.Run(); OnShutdown(&exit_code_); @@ -85,15 +86,27 @@ bool Daemon::OnRestart() { } bool Daemon::Shutdown(const signalfd_siginfo& /* info */) { - Quit(); - return true; // Unregister the signal handler. + // Only respond to the first call. + if (!exiting_) { + exiting_ = true; + Quit(); + } + // Always return false, to avoid unregistering the signal handler. We might + // receive multiple successive signals, and we don't want to take the default + // response (termination) while we're still tearing down. + return false; } bool Daemon::Restart(const signalfd_siginfo& /* info */) { - if (OnRestart()) - return false; // Keep listening to the signal. - Quit(); - return true; // Unregister the signal handler. + if (!exiting_ && !OnRestart()) { + // Only Quit() once. + exiting_ = true; + Quit(); + } + // Always return false, to avoid unregistering the signal handler. We might + // receive multiple successive signals, and we don't want to take the default + // response (termination) while we're still tearing down. + return false; } void Daemon::OnEventLoopStartedTask() { diff --git a/brillo/daemons/daemon.h b/brillo/daemons/daemon.h index a16e04a..499b609 100644 --- a/brillo/daemons/daemon.h +++ b/brillo/daemons/daemon.h @@ -114,6 +114,8 @@ class BRILLO_EXPORT Daemon : public AsynchronousSignalHandlerInterface { AsynchronousSignalHandler async_signal_handler_; // Process exit code specified in QuitWithExitCode() method call. int exit_code_; + // Daemon is in the process of exiting. + bool exiting_; DISALLOW_COPY_AND_ASSIGN(Daemon); }; diff --git a/brillo/daemons/dbus_daemon.h b/brillo/daemons/dbus_daemon.h index 25ce306..2017e7f 100644 --- a/brillo/daemons/dbus_daemon.h +++ b/brillo/daemons/dbus_daemon.h @@ -37,7 +37,7 @@ class BRILLO_EXPORT DBusDaemon : public Daemon { // A reference to the |dbus_connection_| bus object often used by derived // classes. - scoped_refptr<dbus::Bus> bus_; + scoped_refptr<::dbus::Bus> bus_; private: DBusConnection dbus_connection_; @@ -59,7 +59,7 @@ class BRILLO_EXPORT DBusServiceDaemon : public DBusDaemon { // not created and is not available as part of the D-Bus service. explicit DBusServiceDaemon(const std::string& service_name); DBusServiceDaemon(const std::string& service_name, - const dbus::ObjectPath& object_manager_path); + const ::dbus::ObjectPath& object_manager_path); DBusServiceDaemon(const std::string& service_name, base::StringPiece object_manager_path); @@ -76,7 +76,7 @@ class BRILLO_EXPORT DBusServiceDaemon : public DBusDaemon { dbus_utils::AsyncEventSequencer* sequencer); std::string service_name_; - dbus::ObjectPath object_manager_path_; + ::dbus::ObjectPath object_manager_path_; std::unique_ptr<dbus_utils::ExportedObjectManager> object_manager_; private: diff --git a/brillo/data_encoding_fuzzer.cc b/brillo/data_encoding_fuzzer.cc new file mode 100644 index 0000000..8d5d41e --- /dev/null +++ b/brillo/data_encoding_fuzzer.cc @@ -0,0 +1,72 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <cstddef> +#include <cstdint> +#include <cstdio> + +#include <brillo/data_encoding.h> + +#include <base/logging.h> +#include <fuzzer/FuzzedDataProvider.h> + +namespace { +constexpr int kMaxStringLength = 256; +constexpr int kMaxParamsSize = 8; + +void FuzzUrlEncodeDecode(FuzzedDataProvider* provider) { + brillo::data_encoding::UrlEncode( + provider->ConsumeRandomLengthString(kMaxStringLength).c_str(), + provider->ConsumeBool()); + + brillo::data_encoding::UrlDecode( + provider->ConsumeRandomLengthString(kMaxStringLength).c_str()); +} + +void FuzzWebParamsEncodeDecode(FuzzedDataProvider* provider) { + brillo::data_encoding::WebParamList param_list; + const auto num_params = provider->ConsumeIntegralInRange(0, kMaxParamsSize); + for (auto i = 0; i < num_params; i++) { + param_list.push_back(std::pair<std::string, std::string>( + provider->ConsumeRandomLengthString(kMaxStringLength), + provider->ConsumeRandomLengthString(kMaxStringLength))); + } + brillo::data_encoding::WebParamsEncode(param_list, provider->ConsumeBool()); + + brillo::data_encoding::WebParamsDecode( + provider->ConsumeRandomLengthString(kMaxStringLength)); +} + +void FuzzBase64EncodeDecode(FuzzedDataProvider* provider) { + brillo::data_encoding::Base64Encode( + provider->ConsumeRandomLengthString(kMaxStringLength)); + brillo::Blob output; + brillo::data_encoding::Base64Decode( + provider->ConsumeRandomLengthString(kMaxStringLength), &output); +} + +bool IgnoreLogging(int, const char*, int, size_t, const std::string&) { + return true; +} + +} // namespace + +class Environment { + public: + Environment() { + // Disable logging. Normally this would be done with logging::SetMinLogLevel + // but that doesn't work for brillo::Error because it's not using the + // LOG(ERROR) macro which is where the actual log level check occurs. + logging::SetLogMessageHandler(&IgnoreLogging); + } +}; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + static Environment env; + FuzzedDataProvider data_provider(data, size); + FuzzUrlEncodeDecode(&data_provider); + FuzzWebParamsEncodeDecode(&data_provider); + FuzzBase64EncodeDecode(&data_provider); + return 0; +} diff --git a/brillo/data_encoding_unittest.cc b/brillo/data_encoding_test.cc index cb73da6..cb73da6 100644 --- a/brillo/data_encoding_unittest.cc +++ b/brillo/data_encoding_test.cc diff --git a/brillo/dbus/async_event_sequencer.cc b/brillo/dbus/async_event_sequencer.cc index 8861e21..5cdf36d 100644 --- a/brillo/dbus/async_event_sequencer.cc +++ b/brillo/dbus/async_event_sequencer.cc @@ -4,6 +4,9 @@ #include <brillo/dbus/async_event_sequencer.h> +#include <base/bind.h> +#include <base/callback.h> + namespace brillo { namespace dbus_utils { diff --git a/brillo/dbus/async_event_sequencer.h b/brillo/dbus/async_event_sequencer.h index c817b55..cc532e6 100644 --- a/brillo/dbus/async_event_sequencer.h +++ b/brillo/dbus/async_event_sequencer.h @@ -9,7 +9,7 @@ #include <string> #include <vector> -#include <base/bind.h> +#include <base/callback_forward.h> #include <base/macros.h> #include <base/memory/ref_counted.h> #include <brillo/brillo_export.h> diff --git a/brillo/dbus/async_event_sequencer_unittest.cc b/brillo/dbus/async_event_sequencer_test.cc index 5f4c0e2..1026afe 100644 --- a/brillo/dbus/async_event_sequencer_unittest.cc +++ b/brillo/dbus/async_event_sequencer_test.cc @@ -4,6 +4,7 @@ #include <brillo/dbus/async_event_sequencer.h> +#include <base/bind.h> #include <base/bind_helpers.h> #include <gmock/gmock.h> #include <gtest/gtest.h> @@ -22,7 +23,7 @@ const char kTestMethod2[] = "TestMethod2"; class AsyncEventSequencerTest : public ::testing::Test { public: - MOCK_METHOD1(HandleCompletion, void(bool all_succeeded)); + MOCK_METHOD(void, HandleCompletion, (bool)); void SetUp() { aec_ = new AsyncEventSequencer(); diff --git a/brillo/dbus/data_serialization.cc b/brillo/dbus/data_serialization.cc index 4cae471..5f47d67 100644 --- a/brillo/dbus/data_serialization.cc +++ b/brillo/dbus/data_serialization.cc @@ -232,6 +232,9 @@ bool PopArrayValueFromReader(dbus::MessageReader* reader, else if (signature == "a(uu)") return PopTypedArrayFromReader< std::tuple<uint32_t, uint32_t>>(reader, value); + else if (signature == "a(ubay)") + return PopTypedArrayFromReader< + std::tuple<uint32_t, bool, std::vector<uint8_t>>>(reader, value); // When a use case for particular array signature is found, feel free // to add handing for it here. @@ -256,6 +259,9 @@ bool PopStructValueFromReader(dbus::MessageReader* reader, else if (signature == "(uu)") return PopTypedValueFromReader<std::tuple<uint32_t, uint32_t>>(reader, value); + else if (signature == "(ua{sv})") + return PopTypedValueFromReader< + std::tuple<uint32_t, brillo::VariantDictionary>>(reader, value); // When a use case for particular struct signature is found, feel free // to add handing for it here. @@ -314,7 +320,6 @@ bool PopValueFromReader(dbus::MessageReader* reader, brillo::Any* value) { LOG(FATAL) << "Unknown D-Bus data type: " << variant_reader.GetDataType(); return false; } - return true; } } // namespace dbus_utils diff --git a/brillo/dbus/data_serialization.h b/brillo/dbus/data_serialization.h index 1600919..a4f49c1 100644 --- a/brillo/dbus/data_serialization.h +++ b/brillo/dbus/data_serialization.h @@ -49,7 +49,7 @@ // - static void Write(dbus::MessageWriter* writer, const CustomType& value); // - static bool Read(dbus::MessageReader* reader, CustomType* value); // See an example in DBusUtils.CustomStruct unit test in -// brillo/dbus/data_serialization_unittest.cc. +// brillo/dbus/data_serialization_test.cc. #include <map> #include <memory> @@ -125,16 +125,16 @@ struct IsTypeSupported<> : public std::false_type {}; // Write the |value| of type T to D-Bus message. // Explicitly delete the overloads for scalar types that are not supported by // D-Bus. -void AppendValueToWriter(dbus::MessageWriter* writer, char value) = delete; -void AppendValueToWriter(dbus::MessageWriter* writer, float value) = delete; +void AppendValueToWriter(::dbus::MessageWriter* writer, char value) = delete; +void AppendValueToWriter(::dbus::MessageWriter* writer, float value) = delete; //---------------------------------------------------------------------------- // PopValueFromReader<T>(dbus::MessageWriter* writer, T* value) // Reads the |value| of type T from D-Bus message. // Explicitly delete the overloads for scalar types that are not supported by // D-Bus. -void PopValueFromReader(dbus::MessageReader* reader, char* value) = delete; -void PopValueFromReader(dbus::MessageReader* reader, float* value) = delete; +void PopValueFromReader(::dbus::MessageReader* reader, char* value) = delete; +void PopValueFromReader(::dbus::MessageReader* reader, float* value) = delete; //---------------------------------------------------------------------------- // Get D-Bus data signature from C++ data types. @@ -153,9 +153,9 @@ namespace details { // into the Variant and updates the |*reader_ref| with the transient // |variant_reader| MessageReader instance passed in. // Returns false if it fails to descend into the Variant. -inline bool DescendIntoVariantIfPresent(dbus::MessageReader** reader_ref, - dbus::MessageReader* variant_reader) { - if ((*reader_ref)->GetDataType() != dbus::Message::VARIANT) +inline bool DescendIntoVariantIfPresent(::dbus::MessageReader** reader_ref, + ::dbus::MessageReader* variant_reader) { + if ((*reader_ref)->GetDataType() != ::dbus::Message::VARIANT) return true; if (!(*reader_ref)->PopVariant(variant_reader)) return false; @@ -187,198 +187,198 @@ inline std::string GetDBusDictEntryType() { // DBusType<T> for various C++ types that can be serialized over D-Bus. // bool ----------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - bool value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - bool* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + bool value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + bool* value); template<> struct DBusType<bool> { inline static std::string GetSignature() { return DBUS_TYPE_BOOLEAN_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, bool value) { + inline static void Write(::dbus::MessageWriter* writer, bool value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, bool* value) { + inline static bool Read(::dbus::MessageReader* reader, bool* value) { return PopValueFromReader(reader, value); } }; // uint8_t -------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - uint8_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - uint8_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + uint8_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + uint8_t* value); template<> struct DBusType<uint8_t> { inline static std::string GetSignature() { return DBUS_TYPE_BYTE_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, uint8_t value) { + inline static void Write(::dbus::MessageWriter* writer, uint8_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, uint8_t* value) { + inline static bool Read(::dbus::MessageReader* reader, uint8_t* value) { return PopValueFromReader(reader, value); } }; // int16_t -------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - int16_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - int16_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + int16_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + int16_t* value); template<> struct DBusType<int16_t> { inline static std::string GetSignature() { return DBUS_TYPE_INT16_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, int16_t value) { + inline static void Write(::dbus::MessageWriter* writer, int16_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, int16_t* value) { + inline static bool Read(::dbus::MessageReader* reader, int16_t* value) { return PopValueFromReader(reader, value); } }; // uint16_t ------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - uint16_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - uint16_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + uint16_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + uint16_t* value); template<> struct DBusType<uint16_t> { inline static std::string GetSignature() { return DBUS_TYPE_UINT16_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, uint16_t value) { + inline static void Write(::dbus::MessageWriter* writer, uint16_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, uint16_t* value) { + inline static bool Read(::dbus::MessageReader* reader, uint16_t* value) { return PopValueFromReader(reader, value); } }; // int32_t -------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - int32_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - int32_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + int32_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + int32_t* value); template<> struct DBusType<int32_t> { inline static std::string GetSignature() { return DBUS_TYPE_INT32_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, int32_t value) { + inline static void Write(::dbus::MessageWriter* writer, int32_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, int32_t* value) { + inline static bool Read(::dbus::MessageReader* reader, int32_t* value) { return PopValueFromReader(reader, value); } }; // uint32_t ------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - uint32_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - uint32_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + uint32_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + uint32_t* value); template<> struct DBusType<uint32_t> { inline static std::string GetSignature() { return DBUS_TYPE_UINT32_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, uint32_t value) { + inline static void Write(::dbus::MessageWriter* writer, uint32_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, uint32_t* value) { + inline static bool Read(::dbus::MessageReader* reader, uint32_t* value) { return PopValueFromReader(reader, value); } }; // int64_t -------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - int64_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - int64_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + int64_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + int64_t* value); template<> struct DBusType<int64_t> { inline static std::string GetSignature() { return DBUS_TYPE_INT64_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, int64_t value) { + inline static void Write(::dbus::MessageWriter* writer, int64_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, int64_t* value) { + inline static bool Read(::dbus::MessageReader* reader, int64_t* value) { return PopValueFromReader(reader, value); } }; // uint64_t ------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - uint64_t value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - uint64_t* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + uint64_t value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + uint64_t* value); template<> struct DBusType<uint64_t> { inline static std::string GetSignature() { return DBUS_TYPE_UINT64_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, uint64_t value) { + inline static void Write(::dbus::MessageWriter* writer, uint64_t value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, uint64_t* value) { + inline static bool Read(::dbus::MessageReader* reader, uint64_t* value) { return PopValueFromReader(reader, value); } }; // double --------------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - double value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - double* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + double value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + double* value); template<> struct DBusType<double> { inline static std::string GetSignature() { return DBUS_TYPE_DOUBLE_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, double value) { + inline static void Write(::dbus::MessageWriter* writer, double value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, double* value) { + inline static bool Read(::dbus::MessageReader* reader, double* value) { return PopValueFromReader(reader, value); } }; // std::string ---------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - const std::string& value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - std::string* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + const std::string& value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + std::string* value); template<> struct DBusType<std::string> { inline static std::string GetSignature() { return DBUS_TYPE_STRING_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const std::string& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, std::string* value) { + inline static bool Read(::dbus::MessageReader* reader, std::string* value) { return PopValueFromReader(reader, value); } }; // const char* -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - const char* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + const char* value); template<> struct DBusType<const char*> { inline static std::string GetSignature() { return DBUS_TYPE_STRING_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, const char* value) { + inline static void Write(::dbus::MessageWriter* writer, const char* value) { AppendValueToWriter(writer, value); } }; @@ -389,44 +389,44 @@ struct DBusType<const char[]> { inline static std::string GetSignature() { return DBUS_TYPE_STRING_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, const char* value) { + inline static void Write(::dbus::MessageWriter* writer, const char* value) { AppendValueToWriter(writer, value); } }; // dbus::ObjectPath ----------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - const dbus::ObjectPath& value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - dbus::ObjectPath* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + const ::dbus::ObjectPath& value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + ::dbus::ObjectPath* value); -template<> -struct DBusType<dbus::ObjectPath> { +template <> +struct DBusType<::dbus::ObjectPath> { inline static std::string GetSignature() { return DBUS_TYPE_OBJECT_PATH_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, - const dbus::ObjectPath& value) { + inline static void Write(::dbus::MessageWriter* writer, + const ::dbus::ObjectPath& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, - dbus::ObjectPath* value) { + inline static bool Read(::dbus::MessageReader* reader, + ::dbus::ObjectPath* value) { return PopValueFromReader(reader, value); } }; // brillo::dbus_utils::FileDescriptor/base::ScopedFD -------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - const FileDescriptor& value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - base::ScopedFD* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + const FileDescriptor& value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + base::ScopedFD* value); template<> struct DBusType<FileDescriptor> { inline static std::string GetSignature() { return DBUS_TYPE_UNIX_FD_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const FileDescriptor& value) { AppendValueToWriter(writer, value); } @@ -437,38 +437,37 @@ struct DBusType<base::ScopedFD> { inline static std::string GetSignature() { return DBUS_TYPE_UNIX_FD_AS_STRING; } - inline static bool Read(dbus::MessageReader* reader, + inline static bool Read(::dbus::MessageReader* reader, base::ScopedFD* value) { return PopValueFromReader(reader, value); } }; // brillo::Any -------------------------------------------------------------- -BRILLO_EXPORT void AppendValueToWriter(dbus::MessageWriter* writer, - const brillo::Any& value); -BRILLO_EXPORT bool PopValueFromReader(dbus::MessageReader* reader, - brillo::Any* value); +BRILLO_EXPORT void AppendValueToWriter(::dbus::MessageWriter* writer, + const brillo::Any& value); +BRILLO_EXPORT bool PopValueFromReader(::dbus::MessageReader* reader, + brillo::Any* value); template<> struct DBusType<brillo::Any> { inline static std::string GetSignature() { return DBUS_TYPE_VARIANT_AS_STRING; } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const brillo::Any& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, brillo::Any* value) { + inline static bool Read(::dbus::MessageReader* reader, brillo::Any* value) { return PopValueFromReader(reader, value); } }; // std::vector = D-Bus ARRAY. ------------------------------------------------- -template<typename T, typename ALLOC> +template <typename T, typename ALLOC> typename std::enable_if<IsTypeSupported<T>::value>::type AppendValueToWriter( - dbus::MessageWriter* writer, - const std::vector<T, ALLOC>& value) { - dbus::MessageWriter array_writer(nullptr); + ::dbus::MessageWriter* writer, const std::vector<T, ALLOC>& value) { + ::dbus::MessageWriter array_writer(nullptr); writer->OpenArray(GetDBusSignature<T>(), &array_writer); for (const auto& element : value) { // Use DBusType<T>::Write() instead of AppendValueToWriter() to delay @@ -479,11 +478,12 @@ typename std::enable_if<IsTypeSupported<T>::value>::type AppendValueToWriter( writer->CloseContainer(&array_writer); } -template<typename T, typename ALLOC> +template <typename T, typename ALLOC> typename std::enable_if<IsTypeSupported<T>::value, bool>::type -PopValueFromReader(dbus::MessageReader* reader, std::vector<T, ALLOC>* value) { - dbus::MessageReader variant_reader(nullptr); - dbus::MessageReader array_reader(nullptr); +PopValueFromReader(::dbus::MessageReader* reader, + std::vector<T, ALLOC>* value) { + ::dbus::MessageReader variant_reader(nullptr); + ::dbus::MessageReader array_reader(nullptr); if (!details::DescendIntoVariantIfPresent(&reader, &variant_reader) || !reader->PopArray(&array_reader)) return false; @@ -510,11 +510,11 @@ struct DBusArrayType { inline static std::string GetSignature() { return GetArrayDBusSignature(GetDBusSignature<T>()); } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const std::vector<T, ALLOC>& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, + inline static bool Read(::dbus::MessageReader* reader, std::vector<T, ALLOC>* value) { return PopValueFromReader(reader, value); } @@ -562,11 +562,10 @@ inline std::string GetStructDBusSignature() { DBUS_STRUCT_END_CHAR_AS_STRING; } -template<typename U, typename V> +template <typename U, typename V> typename std::enable_if<IsTypeSupported<U, V>::value>::type AppendValueToWriter( - dbus::MessageWriter* writer, - const std::pair<U, V>& value) { - dbus::MessageWriter struct_writer(nullptr); + ::dbus::MessageWriter* writer, const std::pair<U, V>& value) { + ::dbus::MessageWriter struct_writer(nullptr); writer->OpenStruct(&struct_writer); // Use DBusType<T>::Write() instead of AppendValueToWriter() to delay // binding to AppendValueToWriter() to the point of instantiation of this @@ -576,11 +575,11 @@ typename std::enable_if<IsTypeSupported<U, V>::value>::type AppendValueToWriter( writer->CloseContainer(&struct_writer); } -template<typename U, typename V> +template <typename U, typename V> typename std::enable_if<IsTypeSupported<U, V>::value, bool>::type -PopValueFromReader(dbus::MessageReader* reader, std::pair<U, V>* value) { - dbus::MessageReader variant_reader(nullptr); - dbus::MessageReader struct_reader(nullptr); +PopValueFromReader(::dbus::MessageReader* reader, std::pair<U, V>* value) { + ::dbus::MessageReader variant_reader(nullptr); + ::dbus::MessageReader struct_reader(nullptr); if (!details::DescendIntoVariantIfPresent(&reader, &variant_reader) || !reader->PopStruct(&struct_reader)) return false; @@ -602,11 +601,12 @@ struct DBusPairType { inline static std::string GetSignature() { return GetStructDBusSignature<U, V>(); } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const std::pair<U, V>& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, std::pair<U, V>* value) { + inline static bool Read(::dbus::MessageReader* reader, + std::pair<U, V>* value) { return PopValueFromReader(reader, value); } }; @@ -636,7 +636,7 @@ struct TupleIterator { using ValueType = typename std::tuple_element<I, Tuple>::type; // Write the tuple element at index I to D-Bus message. - static void Write(dbus::MessageWriter* writer, const Tuple& value) { + static void Write(::dbus::MessageWriter* writer, const Tuple& value) { // Use DBusType<T>::Write() instead of AppendValueToWriter() to delay // binding to AppendValueToWriter() to the point of instantiation of this // template. @@ -645,7 +645,7 @@ struct TupleIterator { } // Read the tuple element at index I from D-Bus message. - static bool Read(dbus::MessageReader* reader, Tuple* value) { + static bool Read(::dbus::MessageReader* reader, Tuple* value) { // Use DBusType<T>::Read() instead of PopValueFromReader() to delay // binding to PopValueFromReader() to the point of instantiation of this // template. @@ -658,29 +658,29 @@ struct TupleIterator { template<size_t N, typename... T> struct TupleIterator<N, N, T...> { using Tuple = std::tuple<T...>; - static void Write(dbus::MessageWriter* /* writer */, + static void Write(::dbus::MessageWriter* /* writer */, const Tuple& /* value */) {} - static bool Read(dbus::MessageReader* /* reader */, - Tuple* /* value */) { return true; } + static bool Read(::dbus::MessageReader* /* reader */, Tuple* /* value */) { + return true; + } }; } // namespace details -template<typename... T> +template <typename... T> typename std::enable_if<IsTypeSupported<T...>::value>::type AppendValueToWriter( - dbus::MessageWriter* writer, - const std::tuple<T...>& value) { - dbus::MessageWriter struct_writer(nullptr); + ::dbus::MessageWriter* writer, const std::tuple<T...>& value) { + ::dbus::MessageWriter struct_writer(nullptr); writer->OpenStruct(&struct_writer); details::TupleIterator<0, sizeof...(T), T...>::Write(&struct_writer, value); writer->CloseContainer(&struct_writer); } -template<typename... T> +template <typename... T> typename std::enable_if<IsTypeSupported<T...>::value, bool>::type -PopValueFromReader(dbus::MessageReader* reader, std::tuple<T...>* value) { - dbus::MessageReader variant_reader(nullptr); - dbus::MessageReader struct_reader(nullptr); +PopValueFromReader(::dbus::MessageReader* reader, std::tuple<T...>* value) { + ::dbus::MessageReader variant_reader(nullptr); + ::dbus::MessageReader struct_reader(nullptr); if (!details::DescendIntoVariantIfPresent(&reader, &variant_reader) || !reader->PopStruct(&struct_reader)) return false; @@ -699,11 +699,11 @@ struct DBusTupleType { inline static std::string GetSignature() { return GetStructDBusSignature<T...>(); } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const std::tuple<T...>& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, + inline static bool Read(::dbus::MessageReader* reader, std::tuple<T...>* value) { return PopValueFromReader(reader, value); } @@ -720,14 +720,14 @@ struct DBusType<std::tuple<T...>> : public details::DBusTupleType<IsTypeSupported<T...>::value, T...> {}; // std::map = D-Bus ARRAY of DICT_ENTRY. -------------------------------------- -template<typename KEY, typename VALUE, typename PRED, typename ALLOC> +template <typename KEY, typename VALUE, typename PRED, typename ALLOC> typename std::enable_if<IsTypeSupported<KEY, VALUE>::value>::type -AppendValueToWriter(dbus::MessageWriter* writer, +AppendValueToWriter(::dbus::MessageWriter* writer, const std::map<KEY, VALUE, PRED, ALLOC>& value) { - dbus::MessageWriter dict_writer(nullptr); + ::dbus::MessageWriter dict_writer(nullptr); writer->OpenArray(details::GetDBusDictEntryType<KEY, VALUE>(), &dict_writer); for (const auto& pair : value) { - dbus::MessageWriter entry_writer(nullptr); + ::dbus::MessageWriter entry_writer(nullptr); dict_writer.OpenDictEntry(&entry_writer); // Use DBusType<T>::Write() instead of AppendValueToWriter() to delay // binding to AppendValueToWriter() to the point of instantiation of this @@ -739,18 +739,18 @@ AppendValueToWriter(dbus::MessageWriter* writer, writer->CloseContainer(&dict_writer); } -template<typename KEY, typename VALUE, typename PRED, typename ALLOC> +template <typename KEY, typename VALUE, typename PRED, typename ALLOC> typename std::enable_if<IsTypeSupported<KEY, VALUE>::value, bool>::type -PopValueFromReader(dbus::MessageReader* reader, +PopValueFromReader(::dbus::MessageReader* reader, std::map<KEY, VALUE, PRED, ALLOC>* value) { - dbus::MessageReader variant_reader(nullptr); - dbus::MessageReader array_reader(nullptr); + ::dbus::MessageReader variant_reader(nullptr); + ::dbus::MessageReader array_reader(nullptr); if (!details::DescendIntoVariantIfPresent(&reader, &variant_reader) || !reader->PopArray(&array_reader)) return false; value->clear(); while (array_reader.HasMoreData()) { - dbus::MessageReader dict_entry_reader(nullptr); + ::dbus::MessageReader dict_entry_reader(nullptr); if (!array_reader.PopDictEntry(&dict_entry_reader)) return false; KEY key; @@ -782,11 +782,11 @@ struct DBusMapType { inline static std::string GetSignature() { return GetArrayDBusSignature(GetDBusDictEntryType<KEY, VALUE>()); } - inline static void Write(dbus::MessageWriter* writer, + inline static void Write(::dbus::MessageWriter* writer, const std::map<KEY, VALUE, PRED, ALLOC>& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, + inline static bool Read(::dbus::MessageReader* reader, std::map<KEY, VALUE, PRED, ALLOC>* value) { return PopValueFromReader(reader, value); } @@ -807,12 +807,12 @@ struct DBusType<std::map<KEY, VALUE, PRED, ALLOC>> ALLOC> {}; // google::protobuf::MessageLite = D-Bus ARRAY of BYTE ------------------------ -inline void AppendValueToWriter(dbus::MessageWriter* writer, +inline void AppendValueToWriter(::dbus::MessageWriter* writer, const google::protobuf::MessageLite& value) { writer->AppendProtoAsArrayOfBytes(value); } -inline bool PopValueFromReader(dbus::MessageReader* reader, +inline bool PopValueFromReader(::dbus::MessageReader* reader, google::protobuf::MessageLite* value) { return reader->PopArrayOfBytesAsProto(value); } @@ -835,23 +835,23 @@ struct DBusType<T, typename std::enable_if<is_protobuf<T>::value>::type> { inline static std::string GetSignature() { return GetDBusSignature<std::vector<uint8_t>>(); } - inline static void Write(dbus::MessageWriter* writer, const T& value) { + inline static void Write(::dbus::MessageWriter* writer, const T& value) { AppendValueToWriter(writer, value); } - inline static bool Read(dbus::MessageReader* reader, T* value) { + inline static bool Read(::dbus::MessageReader* reader, T* value) { return PopValueFromReader(reader, value); } }; //---------------------------------------------------------------------------- -// AppendValueToWriterAsVariant<T>(dbus::MessageWriter* writer, const T& value) -// Write the |value| of type T to D-Bus message as a VARIANT. -// This overload is provided only if T is supported by D-Bus. -template<typename T> +// AppendValueToWriterAsVariant<T>(::dbus::MessageWriter* writer, const T& +// value) Write the |value| of type T to D-Bus message as a VARIANT. This +// overload is provided only if T is supported by D-Bus. +template <typename T> typename std::enable_if<IsTypeSupported<T>::value>::type -AppendValueToWriterAsVariant(dbus::MessageWriter* writer, const T& value) { +AppendValueToWriterAsVariant(::dbus::MessageWriter* writer, const T& value) { std::string data_type = GetDBusSignature<T>(); - dbus::MessageWriter variant_writer(nullptr); + ::dbus::MessageWriter variant_writer(nullptr); writer->OpenVariant(data_type, &variant_writer); // Use DBusType<T>::Write() instead of AppendValueToWriter() to delay // binding to AppendValueToWriter() to the point of instantiation of this @@ -862,13 +862,13 @@ AppendValueToWriterAsVariant(dbus::MessageWriter* writer, const T& value) { // Special case: do not allow to write a Variant containing a Variant. // Just redirect to normal AppendValueToWriter(). -inline void AppendValueToWriterAsVariant(dbus::MessageWriter* writer, +inline void AppendValueToWriterAsVariant(::dbus::MessageWriter* writer, const brillo::Any& value) { return AppendValueToWriter(writer, value); } //---------------------------------------------------------------------------- -// PopVariantValueFromReader<T>(dbus::MessageWriter* writer, T* value) +// PopVariantValueFromReader<T>(::dbus::MessageWriter* writer, T* value) // Reads a Variant containing the |value| of type T from D-Bus message. // Note that the generic PopValueFromReader<T>(...) can do this too. // This method is provided for two reasons: @@ -876,10 +876,10 @@ inline void AppendValueToWriterAsVariant(dbus::MessageWriter* writer, // 2. To be used when it is important to assert that the data was sent // specifically as a Variant. // This overload is provided only if T is supported by D-Bus. -template<typename T> +template <typename T> typename std::enable_if<IsTypeSupported<T>::value, bool>::type -PopVariantValueFromReader(dbus::MessageReader* reader, T* value) { - dbus::MessageReader variant_reader(nullptr); +PopVariantValueFromReader(::dbus::MessageReader* reader, T* value) { + ::dbus::MessageReader variant_reader(nullptr); if (!reader->PopVariant(&variant_reader)) return false; // Use DBusType<T>::Read() instead of PopValueFromReader() to delay @@ -889,7 +889,8 @@ PopVariantValueFromReader(dbus::MessageReader* reader, T* value) { } // Special handling of request to read a Variant of Variant. -inline bool PopVariantValueFromReader(dbus::MessageReader* reader, Any* value) { +inline bool PopVariantValueFromReader(::dbus::MessageReader* reader, + Any* value) { return PopValueFromReader(reader, value); } diff --git a/brillo/dbus/data_serialization_fuzzer.cc b/brillo/dbus/data_serialization_fuzzer.cc new file mode 100644 index 0000000..dd576a1 --- /dev/null +++ b/brillo/dbus/data_serialization_fuzzer.cc @@ -0,0 +1,334 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <map> +#include <string> +#include <utility> +#include <vector> + +#include <base/logging.h> +#include <base/strings/string_util.h> +#include <brillo/dbus/data_serialization.h> +#include <dbus/string_util.h> +#include <fuzzer/FuzzedDataProvider.h> + +namespace { +constexpr int kRandomMaxContainerSize = 8; +constexpr int kRandomMaxDataLength = 128; + +typedef enum DataType { + kUint8 = 0, + kUint16, + kUint32, + kUint64, + kInt16, + kInt32, + kInt64, + kBool, + kDouble, + kString, + kObjectPath, + // A couple vector types. + kVectorInt16, + kVectorString, + // A couple pair types. + kPairBoolInt64, + kPairUint32String, + // A couple tuple types. + kTupleUint16StringBool, + kTupleDoubleInt32ObjectPath, + // A couple map types. + kMapInt32String, + kMapDoubleBool, + kMaxValue = kMapDoubleBool, +} DataType; + +template <typename T> +void AppendValue(dbus::MessageWriter* writer, bool variant, const T& value) { + if (variant) + brillo::dbus_utils::AppendValueToWriterAsVariant(writer, value); + else + brillo::dbus_utils::AppendValueToWriter(writer, value); +} + +template <typename T> +void GenerateIntAndAppendValue(FuzzedDataProvider* data_provider, + dbus::MessageWriter* writer, + bool variant) { + AppendValue(writer, variant, data_provider->ConsumeIntegral<T>()); +} + +template <typename T> +void PopValue(dbus::MessageReader* reader, bool variant, T* value) { + if (variant) + brillo::dbus_utils::PopVariantValueFromReader(reader, value); + else + brillo::dbus_utils::PopValueFromReader(reader, value); +} + +std::string GenerateValidUTF8(FuzzedDataProvider* data_provider) { + // >= 0x80 + // Generates a random string and returns it if it is valid UTF8, if it is not + // then it will strip it down to all the 7-bit ASCII chars and just return + // that string. + std::string str = + data_provider->ConsumeRandomLengthString(kRandomMaxDataLength); + if (base::IsStringUTF8(str)) + return str; + for (auto it = str.begin(); it != str.end(); it++) { + if (static_cast<uint8_t>(*it) >= 0x80) { + // Might be invalid, remove it. + it = str.erase(it); + it--; + } + } + return str; +} + +} // namespace + +class Environment { + public: + Environment() { + // Disable logging. + logging::SetMinLogLevel(logging::LOG_FATAL); + } +}; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + static Environment env; + FuzzedDataProvider data_provider(data, size); + // Consume a random fraction of our data writing random things to a D-Bus + // message, and then consume the remaining data reading randomly from that + // same D-Bus message. Given the templated nature of these functions and that + // they support essentially an infinite amount of types, we are constraining + // this to a fixed set of types defined above. + std::unique_ptr<dbus::Response> message = dbus::Response::CreateEmpty(); + dbus::MessageWriter writer(message.get()); + + int bytes_left_for_read = + static_cast<int>(data_provider.ConsumeProbability<float>() * size); + while (data_provider.remaining_bytes() > bytes_left_for_read) { + DataType curr_type = data_provider.ConsumeEnum<DataType>(); + bool variant = data_provider.ConsumeBool(); + switch (curr_type) { + case kUint8: + GenerateIntAndAppendValue<uint8_t>(&data_provider, &writer, variant); + break; + case kUint16: + GenerateIntAndAppendValue<uint16_t>(&data_provider, &writer, variant); + break; + case kUint32: + GenerateIntAndAppendValue<uint32_t>(&data_provider, &writer, variant); + break; + case kUint64: + GenerateIntAndAppendValue<uint64_t>(&data_provider, &writer, variant); + break; + case kInt16: + GenerateIntAndAppendValue<int16_t>(&data_provider, &writer, variant); + break; + case kInt32: + GenerateIntAndAppendValue<int32_t>(&data_provider, &writer, variant); + break; + case kInt64: + GenerateIntAndAppendValue<int64_t>(&data_provider, &writer, variant); + break; + case kBool: + AppendValue(&writer, variant, data_provider.ConsumeBool()); + break; + case kDouble: + AppendValue(&writer, variant, + data_provider.ConsumeProbability<double>()); + break; + case kString: + AppendValue(&writer, variant, GenerateValidUTF8(&data_provider)); + break; + case kObjectPath: { + std::string object_path = + data_provider.ConsumeRandomLengthString(kRandomMaxDataLength); + // If this isn't valid we'll hit a CHECK failure. + if (dbus::IsValidObjectPath(object_path)) + AppendValue(&writer, variant, dbus::ObjectPath(object_path)); + break; + } + case kVectorInt16: { + int vec_size = data_provider.ConsumeIntegralInRange<int>( + 0, kRandomMaxContainerSize); + std::vector<int16_t> vec(vec_size); + for (int i = 0; i < vec_size; i++) + vec[i] = data_provider.ConsumeIntegral<int16_t>(); + AppendValue(&writer, variant, vec); + break; + } + case kVectorString: { + int vec_size = data_provider.ConsumeIntegralInRange<int>( + 0, kRandomMaxContainerSize); + std::vector<std::string> vec(vec_size); + for (int i = 0; i < vec_size; i++) + vec[i] = GenerateValidUTF8(&data_provider); + AppendValue(&writer, variant, vec); + break; + } + case kPairBoolInt64: + AppendValue( + &writer, variant, + std::pair<bool, int64_t>{data_provider.ConsumeBool(), + data_provider.ConsumeIntegral<int64_t>()}); + break; + case kPairUint32String: + AppendValue(&writer, variant, + std::pair<uint32_t, std::string>{ + data_provider.ConsumeIntegral<uint32_t>(), + GenerateValidUTF8(&data_provider)}); + break; + case kTupleUint16StringBool: + AppendValue(&writer, variant, + std::tuple<uint32_t, std::string, bool>{ + data_provider.ConsumeIntegral<uint32_t>(), + GenerateValidUTF8(&data_provider), + data_provider.ConsumeBool()}); + break; + case kTupleDoubleInt32ObjectPath: { + std::string object_path = + data_provider.ConsumeRandomLengthString(kRandomMaxDataLength); + // If this isn't valid we'll hit a CHECK failure. + if (dbus::IsValidObjectPath(object_path)) { + AppendValue(&writer, variant, + std::tuple<double, int32_t, dbus::ObjectPath>{ + data_provider.ConsumeProbability<double>(), + data_provider.ConsumeIntegral<int32_t>(), + dbus::ObjectPath(object_path)}); + } + break; + } + case kMapInt32String: { + int map_size = data_provider.ConsumeIntegralInRange<int>( + 0, kRandomMaxContainerSize); + std::map<int32_t, std::string> map; + for (int i = 0; i < map_size; i++) + map[data_provider.ConsumeIntegral<int32_t>()] = + GenerateValidUTF8(&data_provider); + AppendValue(&writer, variant, map); + break; + } + case kMapDoubleBool: { + int map_size = data_provider.ConsumeIntegralInRange<int>( + 0, kRandomMaxContainerSize); + std::map<double, bool> map; + for (int i = 0; i < map_size; i++) + map[data_provider.ConsumeProbability<double>()] = + data_provider.ConsumeBool(); + AppendValue(&writer, variant, map); + break; + } + } + } + + dbus::MessageReader reader(message.get()); + while (data_provider.remaining_bytes()) { + DataType curr_type = data_provider.ConsumeEnum<DataType>(); + bool variant = data_provider.ConsumeBool(); + switch (curr_type) { + case kUint8: { + uint8_t value; + PopValue(&reader, variant, &value); + break; + } + case kUint16: { + uint16_t value; + PopValue(&reader, variant, &value); + break; + } + case kUint32: { + uint32_t value; + PopValue(&reader, variant, &value); + break; + } + case kUint64: { + uint64_t value; + PopValue(&reader, variant, &value); + break; + } + case kInt16: { + int16_t value; + PopValue(&reader, variant, &value); + break; + } + case kInt32: { + int32_t value; + PopValue(&reader, variant, &value); + break; + } + case kInt64: { + int64_t value; + PopValue(&reader, variant, &value); + break; + } + case kBool: { + bool value; + PopValue(&reader, variant, &value); + break; + } + case kDouble: { + double value; + PopValue(&reader, variant, &value); + break; + } + case kString: { + std::string value; + PopValue(&reader, variant, &value); + break; + } + case kObjectPath: { + dbus::ObjectPath value; + PopValue(&reader, variant, &value); + break; + } + case kVectorInt16: { + std::vector<int16_t> value; + PopValue(&reader, variant, &value); + break; + } + case kVectorString: { + std::vector<std::string> value; + PopValue(&reader, variant, &value); + break; + } + case kPairBoolInt64: { + std::pair<bool, int64_t> value; + PopValue(&reader, variant, &value); + break; + } + case kPairUint32String: { + std::pair<uint32_t, std::string> value; + PopValue(&reader, variant, &value); + break; + } + case kTupleUint16StringBool: { + std::tuple<uint16_t, std::string, bool> value; + break; + } + case kTupleDoubleInt32ObjectPath: { + std::tuple<double, int32_t, dbus::ObjectPath> value; + PopValue(&reader, variant, &value); + break; + } + case kMapInt32String: { + std::map<int32_t, std::string> value; + PopValue(&reader, variant, &value); + break; + } + case kMapDoubleBool: { + std::map<double, bool> value; + PopValue(&reader, variant, &value); + break; + } + } + } + + return 0; +} diff --git a/brillo/dbus/data_serialization_unittest.cc b/brillo/dbus/data_serialization_test.cc index c7d5e0f..7e68af5 100644 --- a/brillo/dbus/data_serialization_unittest.cc +++ b/brillo/dbus/data_serialization_test.cc @@ -5,6 +5,7 @@ #include <brillo/dbus/data_serialization.h> #include <limits> +#include <tuple> #include <base/files/scoped_file.h> #include <brillo/variant_dictionary.h> @@ -473,19 +474,28 @@ TEST(DBusUtils, ArraysAsVariant) { std::vector<double> dbl_array_empty{}; std::map<std::string, std::string> dict_ss{{"k1", "v1"}, {"k2", "v2"}}; VariantDictionary dict_sv{{"k1", 1}, {"k2", "v2"}}; + using ComplexStructArray = + std::vector<std::tuple<uint32_t, bool, std::vector<uint8_t>>>; + ComplexStructArray complex_struct_array{ + {123, true, {0xaa, 0xbb, 0xcc}}, + {456, false, {0xdd}}, + {789, false, {}}, + }; AppendValueToWriterAsVariant(&writer, int_array); AppendValueToWriterAsVariant(&writer, str_array); AppendValueToWriterAsVariant(&writer, dbl_array_empty); AppendValueToWriterAsVariant(&writer, dict_ss); AppendValueToWriterAsVariant(&writer, dict_sv); + AppendValueToWriterAsVariant(&writer, complex_struct_array); - EXPECT_EQ("vvvvv", message->GetSignature()); + EXPECT_EQ("vvvvvv", message->GetSignature()); Any int_array_out; Any str_array_out; Any dbl_array_out; Any dict_ss_out; Any dict_sv_out; + Any complex_struct_array_out; MessageReader reader(message.get()); EXPECT_TRUE(PopValueFromReader(&reader, &int_array_out)); @@ -493,6 +503,7 @@ TEST(DBusUtils, ArraysAsVariant) { EXPECT_TRUE(PopValueFromReader(&reader, &dbl_array_out)); EXPECT_TRUE(PopValueFromReader(&reader, &dict_ss_out)); EXPECT_TRUE(PopValueFromReader(&reader, &dict_sv_out)); + EXPECT_TRUE(PopValueFromReader(&reader, &complex_struct_array_out)); EXPECT_FALSE(reader.HasMoreData()); EXPECT_EQ(int_array, int_array_out.Get<std::vector<int>>()); @@ -503,6 +514,35 @@ TEST(DBusUtils, ArraysAsVariant) { dict_sv_out.Get<VariantDictionary>().at("k1").Get<int>()); EXPECT_EQ(dict_sv["k2"].Get<const char*>(), dict_sv_out.Get<VariantDictionary>().at("k2").Get<std::string>()); + EXPECT_EQ(complex_struct_array, + complex_struct_array_out.Get<ComplexStructArray>()); +} + +TEST(DBusUtils, StructsAsVariant) { + std::unique_ptr<Response> message = Response::CreateEmpty(); + MessageWriter writer(message.get()); + VariantDictionary dict_sv{{"k1", 1}, {"k2", "v2"}}; + std::tuple<uint32_t, VariantDictionary> u32_dict_sv_struct = + std::make_tuple(1, dict_sv); + AppendValueToWriterAsVariant(&writer, u32_dict_sv_struct); + + EXPECT_EQ("v", message->GetSignature()); + + Any u32_dict_sv_struct_out_any; + + MessageReader reader(message.get()); + EXPECT_TRUE(PopValueFromReader(&reader, &u32_dict_sv_struct_out_any)); + EXPECT_FALSE(reader.HasMoreData()); + + auto u32_dict_sv_struct_out = + u32_dict_sv_struct_out_any.Get<std::tuple<uint32_t, VariantDictionary>>(); + EXPECT_EQ(std::get<0>(u32_dict_sv_struct), + std::get<0>(u32_dict_sv_struct_out)); + VariantDictionary dict_sv_out = std::get<1>(u32_dict_sv_struct_out); + EXPECT_EQ(dict_sv.size(), dict_sv_out.size()); + EXPECT_EQ(dict_sv["k1"].Get<int>(), dict_sv_out["k1"].Get<int>()); + EXPECT_EQ(dict_sv["k2"].Get<const char*>(), + dict_sv_out["k2"].Get<std::string>()); } TEST(DBusUtils, VariantDictionary) { diff --git a/brillo/dbus/dbus_connection.cc b/brillo/dbus/dbus_connection.cc index b60cf44..2773316 100644 --- a/brillo/dbus/dbus_connection.cc +++ b/brillo/dbus/dbus_connection.cc @@ -4,15 +4,6 @@ #include <brillo/dbus/dbus_connection.h> -#include <sysexits.h> - -#include <base/bind.h> -#include <brillo/dbus/async_event_sequencer.h> -#include <brillo/dbus/exported_object_manager.h> - -using brillo::dbus_utils::AsyncEventSequencer; -using brillo::dbus_utils::ExportedObjectManager; - namespace brillo { DBusConnection::DBusConnection() { diff --git a/brillo/dbus/dbus_connection.h b/brillo/dbus/dbus_connection.h index aecf434..5f08ef7 100644 --- a/brillo/dbus/dbus_connection.h +++ b/brillo/dbus/dbus_connection.h @@ -21,15 +21,15 @@ class BRILLO_EXPORT DBusConnection final { // Instantiates dbus::Bus and establishes a D-Bus connection. Returns a // reference to the connected bus, or an empty pointer in case of error. - scoped_refptr<dbus::Bus> Connect(); + scoped_refptr<::dbus::Bus> Connect(); // Instantiates dbus::Bus and tries to establish a D-Bus connection for up to // |timeout|. If the connection can't be established after the timeout, fails // returning an empty pointer. - scoped_refptr<dbus::Bus> ConnectWithTimeout(base::TimeDelta timeout); + scoped_refptr<::dbus::Bus> ConnectWithTimeout(base::TimeDelta timeout); private: - scoped_refptr<dbus::Bus> bus_; + scoped_refptr<::dbus::Bus> bus_; private: DISALLOW_COPY_AND_ASSIGN(DBusConnection); @@ -37,4 +37,4 @@ class BRILLO_EXPORT DBusConnection final { } // namespace brillo -#endif // LIBBRILLO_BRILLO_DAEMONS_DBUS_DAEMON_H_ +#endif // LIBBRILLO_BRILLO_DBUS_DBUS_CONNECTION_H_ diff --git a/brillo/dbus/dbus_method_invoker.h b/brillo/dbus/dbus_method_invoker.h index f8b6990..08f5781 100644 --- a/brillo/dbus/dbus_method_invoker.h +++ b/brillo/dbus/dbus_method_invoker.h @@ -65,6 +65,7 @@ #include <memory> #include <string> #include <tuple> +#include <utility> #include <base/bind.h> #include <base/files/scoped_file.h> @@ -91,19 +92,19 @@ namespace dbus_utils { // [dbus/dbus.h]). // Returns a dbus::Response object on success. On failure, returns nullptr and // fills in additional error details into the |error| object. -template<typename... Args> -inline std::unique_ptr<dbus::Response> CallMethodAndBlockWithTimeout( +template <typename... Args> +inline std::unique_ptr<::dbus::Response> CallMethodAndBlockWithTimeout( int timeout_ms, - dbus::ObjectProxy* object, + ::dbus::ObjectProxy* object, const std::string& interface_name, const std::string& method_name, ErrorPtr* error, const Args&... args) { - dbus::MethodCall method_call(interface_name, method_name); + ::dbus::MethodCall method_call(interface_name, method_name); // Add method arguments to the message buffer. - dbus::MessageWriter writer(&method_call); + ::dbus::MessageWriter writer(&method_call); DBusParamWriter::Append(&writer, args...); - dbus::ScopedDBusError dbus_error; + ::dbus::ScopedDBusError dbus_error; auto response = object->CallMethodAndBlockWithErrorDetails( &method_call, timeout_ms, &dbus_error); if (!response) { @@ -127,19 +128,16 @@ inline std::unique_ptr<dbus::Response> CallMethodAndBlockWithTimeout( } // Same as CallMethodAndBlockWithTimeout() but uses a default timeout value. -template<typename... Args> -inline std::unique_ptr<dbus::Response> CallMethodAndBlock( - dbus::ObjectProxy* object, +template <typename... Args> +inline std::unique_ptr<::dbus::Response> CallMethodAndBlock( + ::dbus::ObjectProxy* object, const std::string& interface_name, const std::string& method_name, ErrorPtr* error, const Args&... args) { - return CallMethodAndBlockWithTimeout(dbus::ObjectProxy::TIMEOUT_USE_DEFAULT, - object, - interface_name, - method_name, - error, - args...); + return CallMethodAndBlockWithTimeout(::dbus::ObjectProxy::TIMEOUT_USE_DEFAULT, + object, interface_name, method_name, + error, args...); } namespace internal { @@ -169,9 +167,9 @@ inline FileDescriptor HackMove(const FileDescriptor& val) { // Extracts the parameters of |ResultTypes...| types from the message reader // and puts the values in the resulting |tuple|. Returns false on error and // provides additional error details in |error| object. -template<typename... ResultTypes> +template <typename... ResultTypes> inline bool ExtractMessageParametersAsTuple( - dbus::MessageReader* reader, + ::dbus::MessageReader* reader, ErrorPtr* error, std::tuple<ResultTypes...>* val_tuple) { auto callback = [val_tuple](const ResultTypes&... params) { @@ -182,9 +180,9 @@ inline bool ExtractMessageParametersAsTuple( } // Overload of ExtractMessageParametersAsTuple to handle reference types in // tuples created with std::tie(). -template<typename... ResultTypes> +template <typename... ResultTypes> inline bool ExtractMessageParametersAsTuple( - dbus::MessageReader* reader, + ::dbus::MessageReader* reader, ErrorPtr* error, std::tuple<ResultTypes&...>* ref_tuple) { auto callback = [ref_tuple](const ResultTypes&... params) { @@ -207,8 +205,8 @@ inline bool ExtractMessageParametersAsTuple( // if (ExtractMessageParameters(reader, &error, &data1, &data2)) { ... } // // The above example extracts an Int32 and a String from D-Bus message buffer. -template<typename... ResultTypes> -inline bool ExtractMessageParameters(dbus::MessageReader* reader, +template <typename... ResultTypes> +inline bool ExtractMessageParameters(::dbus::MessageReader* reader, ErrorPtr* error, ResultTypes*... results) { auto ref_tuple = std::tie(*results...); @@ -225,14 +223,14 @@ inline bool ExtractMessageParameters(dbus::MessageReader* reader, // any return values. Just do not specify any output |results|. In this case, // ExtractMethodCallResults() will verify that the method didn't return any // data in the |message|. -template<typename... ResultTypes> -inline bool ExtractMethodCallResults(dbus::Message* message, +template <typename... ResultTypes> +inline bool ExtractMethodCallResults(::dbus::Message* message, ErrorPtr* error, ResultTypes*... results) { CHECK(message) << "Unable to extract parameters from a NULL message."; - dbus::MessageReader reader(message); - if (message->GetMessageType() == dbus::Message::MESSAGE_ERROR) { + ::dbus::MessageReader reader(message); + if (message->GetMessageType() == ::dbus::Message::MESSAGE_ERROR) { std::string error_message; if (ExtractMessageParameters(&reader, error, &error_message)) AddDBusError(error, message->GetErrorName(), error_message); @@ -249,24 +247,24 @@ using AsyncErrorCallback = base::Callback<void(Error* error)>; // A helper function that translates dbus::ErrorResponse response // from D-Bus into brillo::Error* and invokes the |callback|. void BRILLO_EXPORT TranslateErrorResponse(const AsyncErrorCallback& callback, - dbus::ErrorResponse* resp); + ::dbus::ErrorResponse* resp); // A helper function that translates dbus::Response from D-Bus into // a list of C++ values passed as parameters to |success_callback|. If the // response message doesn't have the correct number of parameters, or they // are of wrong types, an error is sent to |error_callback|. -template<typename... OutArgs> +template <typename... OutArgs> void TranslateSuccessResponse( const base::Callback<void(OutArgs...)>& success_callback, const AsyncErrorCallback& error_callback, - dbus::Response* resp) { + ::dbus::Response* resp) { auto callback = [&success_callback](const OutArgs&... params) { if (!success_callback.is_null()) { success_callback.Run(params...); } }; ErrorPtr error; - dbus::MessageReader reader(resp); + ::dbus::MessageReader reader(resp); if (!DBusParamReader<false, OutArgs...>::Invoke(callback, &reader, &error) && !error_callback.is_null()) { error_callback.Run(error.get()); @@ -283,43 +281,40 @@ void TranslateSuccessResponse( // a problem invoking a method (e.g. object or method doesn't exist). // If the response is not received within |timeout_ms|, an error callback is // called with DBUS_ERROR_NO_REPLY error code. -template<typename... InArgs, typename... OutArgs> +template <typename... InArgs, typename... OutArgs> inline void CallMethodWithTimeout( int timeout_ms, - dbus::ObjectProxy* object, + ::dbus::ObjectProxy* object, const std::string& interface_name, const std::string& method_name, const base::Callback<void(OutArgs...)>& success_callback, const AsyncErrorCallback& error_callback, const InArgs&... params) { - dbus::MethodCall method_call(interface_name, method_name); - dbus::MessageWriter writer(&method_call); + ::dbus::MethodCall method_call(interface_name, method_name); + ::dbus::MessageWriter writer(&method_call); DBusParamWriter::Append(&writer, params...); - dbus::ObjectProxy::ErrorCallback dbus_error_callback = + ::dbus::ObjectProxy::ErrorCallback dbus_error_callback = base::Bind(&TranslateErrorResponse, error_callback); - dbus::ObjectProxy::ResponseCallback dbus_success_callback = base::Bind( + ::dbus::ObjectProxy::ResponseCallback dbus_success_callback = base::Bind( &TranslateSuccessResponse<OutArgs...>, success_callback, error_callback); - object->CallMethodWithErrorCallback( - &method_call, timeout_ms, dbus_success_callback, dbus_error_callback); + object->CallMethodWithErrorCallback(&method_call, timeout_ms, + std::move(dbus_success_callback), + std::move(dbus_error_callback)); } // Same as CallMethodWithTimeout() but uses a default timeout value. -template<typename... InArgs, typename... OutArgs> -inline void CallMethod(dbus::ObjectProxy* object, +template <typename... InArgs, typename... OutArgs> +inline void CallMethod(::dbus::ObjectProxy* object, const std::string& interface_name, const std::string& method_name, const base::Callback<void(OutArgs...)>& success_callback, const AsyncErrorCallback& error_callback, const InArgs&... params) { - return CallMethodWithTimeout(dbus::ObjectProxy::TIMEOUT_USE_DEFAULT, - object, - interface_name, - method_name, - success_callback, - error_callback, - params...); + return CallMethodWithTimeout(::dbus::ObjectProxy::TIMEOUT_USE_DEFAULT, object, + interface_name, method_name, success_callback, + error_callback, params...); } } // namespace dbus_utils diff --git a/brillo/dbus/dbus_method_invoker_unittest.cc b/brillo/dbus/dbus_method_invoker_test.cc index 34f4c6f..c0c681b 100644 --- a/brillo/dbus/dbus_method_invoker_unittest.cc +++ b/brillo/dbus/dbus_method_invoker_test.cc @@ -6,8 +6,7 @@ #include <string> -#include <base/files/scoped_file.h> -#include <brillo/bind_lambda.h> +#include <base/bind.h> #include <dbus/mock_bus.h> #include <dbus/mock_object_proxy.h> #include <dbus/scoped_dbus_error.h> @@ -85,15 +84,15 @@ class DBusMethodInvokerTest : public testing::Test { .WillRepeatedly(Return(mock_object_proxy_.get())); int def_timeout_ms = dbus::ObjectProxy::TIMEOUT_USE_DEFAULT; EXPECT_CALL(*mock_object_proxy_, - MockCallMethodAndBlockWithErrorDetails(_, def_timeout_ms, _)) + CallMethodAndBlockWithErrorDetails(_, def_timeout_ms, _)) .WillRepeatedly(Invoke(this, &DBusMethodInvokerTest::CreateResponse)); } void TearDown() override { bus_ = nullptr; } - Response* CreateResponse(dbus::MethodCall* method_call, - int /* timeout_ms */, - dbus::ScopedDBusError* dbus_error) { + std::unique_ptr<Response> CreateResponse(dbus::MethodCall* method_call, + int /* timeout_ms */, + dbus::ScopedDBusError* dbus_error) { if (method_call->GetInterface() == kTestInterface) { if (method_call->GetMember() == kTestMethod1) { MessageReader reader(method_call); @@ -104,12 +103,12 @@ class DBusMethodInvokerTest : public testing::Test { auto response = Response::CreateEmpty(); MessageWriter writer(response.get()); writer.AppendString(std::to_string(v1 + v2)); - return response.release(); + return response; } } else if (method_call->GetMember() == kTestMethod2) { method_call->SetSerial(123); dbus_set_error(dbus_error->get(), "org.MyError", "My error message"); - return nullptr; + return std::unique_ptr<dbus::Response>(); } else if (method_call->GetMember() == kTestMethod3) { MessageReader reader(method_call); dbus_utils_test::TestMessage msg; @@ -117,7 +116,7 @@ class DBusMethodInvokerTest : public testing::Test { auto response = Response::CreateEmpty(); MessageWriter writer(response.get()); AppendValueToWriter(&writer, msg); - return response.release(); + return response; } } else if (method_call->GetMember() == kTestMethod4) { method_call->SetSerial(123); @@ -127,13 +126,13 @@ class DBusMethodInvokerTest : public testing::Test { auto response = Response::CreateEmpty(); MessageWriter writer(response.get()); writer.AppendFileDescriptor(fd.get()); - return response.release(); + return response; } } } LOG(ERROR) << "Unexpected method call: " << method_call->ToString(); - return nullptr; + return std::unique_ptr<dbus::Response>(); } std::string CallTestMethod(int v1, int v2) { @@ -244,7 +243,7 @@ class AsyncDBusMethodInvokerTest : public testing::Test { .WillRepeatedly(Return(mock_object_proxy_.get())); int def_timeout_ms = dbus::ObjectProxy::TIMEOUT_USE_DEFAULT; EXPECT_CALL(*mock_object_proxy_, - CallMethodWithErrorCallback(_, def_timeout_ms, _, _)) + DoCallMethodWithErrorCallback(_, def_timeout_ms, _, _)) .WillRepeatedly(Invoke(this, &AsyncDBusMethodInvokerTest::HandleCall)); } @@ -252,8 +251,8 @@ class AsyncDBusMethodInvokerTest : public testing::Test { void HandleCall(dbus::MethodCall* method_call, int /* timeout_ms */, - dbus::ObjectProxy::ResponseCallback success_callback, - dbus::ObjectProxy::ErrorCallback error_callback) { + dbus::ObjectProxy::ResponseCallback* success_callback, + dbus::ObjectProxy::ErrorCallback* error_callback) { if (method_call->GetInterface() == kTestInterface) { if (method_call->GetMember() == kTestMethod1) { MessageReader reader(method_call); @@ -264,14 +263,14 @@ class AsyncDBusMethodInvokerTest : public testing::Test { auto response = Response::CreateEmpty(); MessageWriter writer(response.get()); writer.AppendString(std::to_string(v1 + v2)); - success_callback.Run(response.get()); + std::move(*success_callback).Run(response.get()); } return; } else if (method_call->GetMember() == kTestMethod2) { method_call->SetSerial(123); auto error_response = dbus::ErrorResponse::FromMethodCall( method_call, "org.MyError", "My error message"); - error_callback.Run(error_response.get()); + std::move(*error_callback).Run(error_response.get()); return; } } diff --git a/brillo/dbus/dbus_method_response.h b/brillo/dbus/dbus_method_response.h index 289f11e..15df602 100644 --- a/brillo/dbus/dbus_method_response.h +++ b/brillo/dbus/dbus_method_response.h @@ -5,8 +5,12 @@ #ifndef LIBBRILLO_BRILLO_DBUS_DBUS_METHOD_RESPONSE_H_ #define LIBBRILLO_BRILLO_DBUS_DBUS_METHOD_RESPONSE_H_ +#include <memory> #include <string> +#include <utility> +#include <base/bind.h> +#include <base/location.h> #include <base/macros.h> #include <brillo/brillo_export.h> #include <brillo/dbus/dbus_param_writer.h> @@ -20,14 +24,25 @@ class Error; namespace dbus_utils { -using ResponseSender = dbus::ExportedObject::ResponseSender; +using ResponseSender = ::dbus::ExportedObject::ResponseSender; // DBusMethodResponseBase is a helper class used with asynchronous D-Bus method // handlers to encapsulate the information needed to send the method call // response when it is available. class BRILLO_EXPORT DBusMethodResponseBase { public: - DBusMethodResponseBase(dbus::MethodCall* method_call, ResponseSender sender); + DBusMethodResponseBase(::dbus::MethodCall* method_call, + ResponseSender sender); + DBusMethodResponseBase(DBusMethodResponseBase&& other) + : sender_(std::exchange( + other.sender_, + base::Bind([](std::unique_ptr<dbus::Response> response) { + LOG(DFATAL) + << "Empty DBusMethodResponseBase attempts to send a response"; + }))), + method_call_(std::exchange(other.method_call_, nullptr)) {} + DBusMethodResponseBase& operator=(DBusMethodResponseBase&& other) = delete; + virtual ~DBusMethodResponseBase(); // Sends an error response. Marshals the |error| object over D-Bus. @@ -36,20 +51,20 @@ class BRILLO_EXPORT DBusMethodResponseBase { // For error is from other domains, the full error information (domain, error // code, error message) is encoded into the D-Bus error message and returned // to the caller as "org.freedesktop.DBus.Failed". - void ReplyWithError(const brillo::Error* error); + virtual void ReplyWithError(const brillo::Error* error); // Constructs brillo::Error object from the parameters specified and send // the error information over D-Bus using the method above. - void ReplyWithError(const base::Location& location, - const std::string& error_domain, - const std::string& error_code, - const std::string& error_message); + virtual void ReplyWithError(const base::Location& location, + const std::string& error_domain, + const std::string& error_code, + const std::string& error_message); // Sends a raw D-Bus response message. - void SendRawResponse(std::unique_ptr<dbus::Response> response); + void SendRawResponse(std::unique_ptr<::dbus::Response> response); // Creates a custom response object for the current method call. - std::unique_ptr<dbus::Response> CreateCustomResponse() const; + std::unique_ptr<::dbus::Response> CreateCustomResponse() const; // Checks if the response has been sent already. bool IsResponseSent() const; @@ -67,9 +82,7 @@ class BRILLO_EXPORT DBusMethodResponseBase { // in the bound parameter list in the Callback). We set it to nullptr after // the method call response has been sent to ensure we can't possibly try // to send a response again somehow. - dbus::MethodCall* method_call_; - - DISALLOW_COPY_AND_ASSIGN(DBusMethodResponseBase); + ::dbus::MethodCall* method_call_; }; // DBusMethodResponse is an explicitly-typed version of DBusMethodResponse. @@ -83,10 +96,10 @@ class DBusMethodResponse : public DBusMethodResponseBase { // Sends the a successful response. |return_values| can contain a list // of return values to be sent to the caller. - inline void Return(const Types&... return_values) { + virtual void Return(const Types&... return_values) { CheckCanSendResponse(); auto response = CreateCustomResponse(); - dbus::MessageWriter writer(response.get()); + ::dbus::MessageWriter writer(response.get()); DBusParamWriter::Append(&writer, return_values...); SendRawResponse(std::move(response)); } diff --git a/brillo/dbus/dbus_object.cc b/brillo/dbus/dbus_object.cc index 512cd6f..502af3e 100644 --- a/brillo/dbus/dbus_object.cc +++ b/brillo/dbus/dbus_object.cc @@ -4,9 +4,12 @@ #include <brillo/dbus/dbus_object.h> +#include <memory> +#include <utility> #include <vector> #include <base/bind.h> +#include <base/bind_helpers.h> #include <base/logging.h> #include <brillo/dbus/async_event_sequencer.h> #include <brillo/dbus/exported_object_manager.h> @@ -37,8 +40,9 @@ void SetupDefaultPropertyHandlers(DBusInterface* prop_interface, DBusInterface::DBusInterface(DBusObject* dbus_object, const std::string& interface_name) - : dbus_object_(dbus_object), interface_name_(interface_name) { -} + : dbus_object_(dbus_object), + interface_name_(interface_name), + release_interface_cb_(base::DoNothing()) {} void DBusInterface::AddProperty(const std::string& property_name, ExportedPropertyBase* prop_base) { @@ -115,6 +119,50 @@ void DBusInterface::ExportAndBlock( } } +void DBusInterface::UnexportAsync( + ExportedObjectManager* object_manager, + dbus::ExportedObject* exported_object, + const dbus::ObjectPath& object_path, + const AsyncEventSequencer::CompletionAction& completion_callback) { + VLOG(1) << "Unexporting D-Bus interface " << interface_name_ << " for " + << object_path.value(); + + // Release the interface. + release_interface_cb_.RunAndReset(); + + // Unexport all method handlers. + scoped_refptr<AsyncEventSequencer> sequencer(new AsyncEventSequencer()); + for (const auto& pair : handlers_) { + std::string method_name = pair.first; + VLOG(1) << "Unexporting method: " << interface_name_ << "." << method_name; + std::string export_error = "Failed unexporting " + method_name + " method"; + auto export_handler = sequencer->GetExportHandler( + interface_name_, method_name, export_error, true); + exported_object->UnexportMethod(interface_name_, method_name, + export_handler); + } + + sequencer->OnAllTasksCompletedCall({completion_callback}); +} + +void DBusInterface::UnexportAndBlock(ExportedObjectManager* object_manager, + dbus::ExportedObject* exported_object, + const dbus::ObjectPath& object_path) { + VLOG(1) << "Unexporting D-Bus interface " << interface_name_ << " for " + << object_path.value(); + + // Release the interface. + release_interface_cb_.RunAndReset(); + + // Unexport all method handlers. + for (const auto& pair : handlers_) { + std::string method_name = pair.first; + VLOG(1) << "Unexporting method: " << interface_name_ << "." << method_name; + if (!exported_object->UnexportMethodAndBlock(interface_name_, method_name)) + LOG(FATAL) << "Failed unexporting " << method_name << " method"; + } +} + void DBusInterface::ClaimInterface( base::WeakPtr<ExportedObjectManager> object_manager, const dbus::ObjectPath& object_path, @@ -125,6 +173,7 @@ void DBusInterface::ClaimInterface( return; } object_manager->ClaimInterface(object_path, interface_name_, writer); + release_interface_cb_.RunAndReset(); release_interface_cb_.ReplaceClosure( base::Bind(&ExportedObjectManager::ReleaseInterface, object_manager, object_path, interface_name_)); @@ -234,6 +283,25 @@ void DBusObject::ExportInterfaceAsync( object_path_, completion_callback); } +void DBusObject::ExportInterfaceAndBlock(const std::string& interface_name) { + AddOrGetInterface(interface_name) + ->ExportAndBlock(object_manager_.get(), bus_.get(), exported_object_, + object_path_); +} + +void DBusObject::UnexportInterfaceAsync( + const std::string& interface_name, + const AsyncEventSequencer::CompletionAction& completion_callback) { + AddOrGetInterface(interface_name) + ->UnexportAsync(object_manager_.get(), exported_object_, object_path_, + completion_callback); +} + +void DBusObject::UnexportInterfaceAndBlock(const std::string& interface_name) { + AddOrGetInterface(interface_name) + ->UnexportAndBlock(object_manager_.get(), exported_object_, object_path_); +} + void DBusObject::RegisterAsync( const AsyncEventSequencer::CompletionAction& completion_callback) { VLOG(1) << "Registering D-Bus object '" << object_path_.value() << "'."; diff --git a/brillo/dbus/dbus_object.h b/brillo/dbus/dbus_object.h index 61c954f..6ab0b23 100644 --- a/brillo/dbus/dbus_object.h +++ b/brillo/dbus/dbus_object.h @@ -45,7 +45,8 @@ class MyDbusObject { void Method3(std::unique_ptr<DBusMethodResponse<int_32>> response, const std::string& message) { if (message.empty()) { - response->ReplyWithError(brillo::errors::dbus::kDomain, + response->ReplyWithError(FROM_HERE, + brillo::errors::dbus::kDomain, DBUS_ERROR_INVALID_ARGS, "Message string cannot be empty"); return; @@ -62,7 +63,9 @@ class MyDbusObject { #define LIBBRILLO_BRILLO_DBUS_DBUS_OBJECT_H_ #include <map> +#include <memory> #include <string> +#include <utility> #include <base/bind.h> #include <base/callback_helpers.h> @@ -197,10 +200,10 @@ class BRILLO_EXPORT DBusInterface final { // Register sync DBus method handler for |method_name| as base::Callback. // Passing the method sender as a first parameter to the callback. - template<typename... Args> + template <typename... Args> inline void AddSimpleMethodHandlerWithErrorAndMessage( const std::string& method_name, - const base::Callback<bool(ErrorPtr*, dbus::Message*, Args...)>& + const base::Callback<bool(ErrorPtr*, ::dbus::Message*, Args...)>& handler) { Handler<SimpleDBusInterfaceMethodHandlerWithErrorAndMessage<Args...>>::Add( this, method_name, handler); @@ -209,10 +212,10 @@ class BRILLO_EXPORT DBusInterface final { // Register sync D-Bus method handler for |method_name| as a static // function. Passing the method D-Bus message as the second parameter to the // callback. - template<typename... Args> + template <typename... Args> inline void AddSimpleMethodHandlerWithErrorAndMessage( const std::string& method_name, - bool(*handler)(ErrorPtr*, dbus::Message*, Args...)) { + bool (*handler)(ErrorPtr*, ::dbus::Message*, Args...)) { Handler<SimpleDBusInterfaceMethodHandlerWithErrorAndMessage<Args...>>::Add( this, method_name, base::Bind(handler)); } @@ -220,21 +223,21 @@ class BRILLO_EXPORT DBusInterface final { // Register sync D-Bus method handler for |method_name| as a class member // function. Passing the method D-Bus message as the second parameter to the // callback. - template<typename Instance, typename Class, typename... Args> + template <typename Instance, typename Class, typename... Args> inline void AddSimpleMethodHandlerWithErrorAndMessage( const std::string& method_name, Instance instance, - bool(Class::*handler)(ErrorPtr*, dbus::Message*, Args...)) { + bool (Class::*handler)(ErrorPtr*, ::dbus::Message*, Args...)) { Handler<SimpleDBusInterfaceMethodHandlerWithErrorAndMessage<Args...>>::Add( this, method_name, base::Bind(handler, instance)); } // Same as above but for const-method of a class. - template<typename Instance, typename Class, typename... Args> + template <typename Instance, typename Class, typename... Args> inline void AddSimpleMethodHandlerWithErrorAndMessage( const std::string& method_name, Instance instance, - bool(Class::*handler)(ErrorPtr*, dbus::Message*, Args...) const) { + bool (Class::*handler)(ErrorPtr*, ::dbus::Message*, Args...) const) { Handler<SimpleDBusInterfaceMethodHandlerWithErrorAndMessage<Args...>>::Add( this, method_name, base::Bind(handler, instance)); } @@ -294,11 +297,11 @@ class BRILLO_EXPORT DBusInterface final { } // Register an async DBus method handler for |method_name| as base::Callback. - template<typename Response, typename... Args> + template <typename Response, typename... Args> inline void AddMethodHandlerWithMessage( const std::string& method_name, - const base::Callback<void(std::unique_ptr<Response>, dbus::Message*, - Args...)>& handler) { + const base::Callback<void( + std::unique_ptr<Response>, ::dbus::Message*, Args...)>& handler) { static_assert(std::is_base_of<DBusMethodResponseBase, Response>::value, "Response must be DBusMethodResponse<T...>"); Handler<DBusInterfaceMethodHandlerWithMessage<Response, Args...>>::Add( @@ -307,10 +310,10 @@ class BRILLO_EXPORT DBusInterface final { // Register an async D-Bus method handler for |method_name| as a static // function. - template<typename Response, typename... Args> + template <typename Response, typename... Args> inline void AddMethodHandlerWithMessage( const std::string& method_name, - void (*handler)(std::unique_ptr<Response>, dbus::Message*, Args...)) { + void (*handler)(std::unique_ptr<Response>, ::dbus::Message*, Args...)) { static_assert(std::is_base_of<DBusMethodResponseBase, Response>::value, "Response must be DBusMethodResponse<T...>"); Handler<DBusInterfaceMethodHandlerWithMessage<Response, Args...>>::Add( @@ -319,15 +322,16 @@ class BRILLO_EXPORT DBusInterface final { // Register an async D-Bus method handler for |method_name| as a class member // function. - template<typename Response, - typename Instance, - typename Class, - typename... Args> + template <typename Response, + typename Instance, + typename Class, + typename... Args> inline void AddMethodHandlerWithMessage( const std::string& method_name, Instance instance, - void(Class::*handler)(std::unique_ptr<Response>, - dbus::Message*, Args...)) { + void (Class::*handler)(std::unique_ptr<Response>, + ::dbus::Message*, + Args...)) { static_assert(std::is_base_of<DBusMethodResponseBase, Response>::value, "Response must be DBusMethodResponse<T...>"); Handler<DBusInterfaceMethodHandlerWithMessage<Response, Args...>>::Add( @@ -335,15 +339,16 @@ class BRILLO_EXPORT DBusInterface final { } // Same as above but for const-method of a class. - template<typename Response, - typename Instance, - typename Class, - typename... Args> + template <typename Response, + typename Instance, + typename Class, + typename... Args> inline void AddMethodHandlerWithMessage( const std::string& method_name, Instance instance, - void(Class::*handler)(std::unique_ptr<Response>, dbus::Message*, - Args...) const) { + void (Class::*handler)(std::unique_ptr<Response>, + ::dbus::Message*, + Args...) const) { static_assert(std::is_base_of<DBusMethodResponseBase, Response>::value, "Response must be DBusMethodResponse<T...>"); Handler<DBusInterfaceMethodHandlerWithMessage<Response, Args...>>::Add( @@ -353,17 +358,18 @@ class BRILLO_EXPORT DBusInterface final { // Register a raw D-Bus method handler for |method_name| as base::Callback. inline void AddRawMethodHandler( const std::string& method_name, - const base::Callback<void(dbus::MethodCall*, ResponseSender)>& handler) { + const base::Callback<void(::dbus::MethodCall*, ResponseSender)>& + handler) { Handler<RawDBusInterfaceMethodHandler>::Add(this, method_name, handler); } // Register a raw D-Bus method handler for |method_name| as a class member // function. - template<typename Instance, typename Class> - inline void AddRawMethodHandler( - const std::string& method_name, - Instance instance, - void(Class::*handler)(dbus::MethodCall*, ResponseSender)) { + template <typename Instance, typename Class> + inline void AddRawMethodHandler(const std::string& method_name, + Instance instance, + void (Class::*handler)(::dbus::MethodCall*, + ResponseSender)) { Handler<RawDBusInterfaceMethodHandler>::Add( this, method_name, base::Bind(handler, instance)); } @@ -444,7 +450,7 @@ class BRILLO_EXPORT DBusInterface final { // A generic D-Bus method handler for the interface. It extracts the method // name from |method_call|, looks up a registered handler from |handlers_| // map and dispatched the call to that handler. - void HandleMethodCall(dbus::MethodCall* method_call, ResponseSender sender); + void HandleMethodCall(::dbus::MethodCall* method_call, ResponseSender sender); // Helper to add a handler for method |method_name| to the |handlers_| map. // Not marked BRILLO_PRIVATE because it needs to be called by the inline // template functions AddMethodHandler(...) @@ -467,9 +473,9 @@ class BRILLO_EXPORT DBusInterface final { // registration operation is completed. BRILLO_PRIVATE void ExportAsync( ExportedObjectManager* object_manager, - dbus::Bus* bus, - dbus::ExportedObject* exported_object, - const dbus::ObjectPath& object_path, + ::dbus::Bus* bus, + ::dbus::ExportedObject* exported_object, + const ::dbus::ObjectPath& object_path, const AsyncEventSequencer::CompletionAction& completion_callback); // Exports all the methods and properties of this interface and claims the // D-Bus interface synchronously. @@ -478,15 +484,24 @@ class BRILLO_EXPORT DBusInterface final { // exported_object - instance of D-Bus object the interface is being added to. // object_path - D-Bus object path for the object instance. // interface_name - name of interface being registered. - BRILLO_PRIVATE void ExportAndBlock( + BRILLO_PRIVATE void ExportAndBlock(ExportedObjectManager* object_manager, + ::dbus::Bus* bus, + ::dbus::ExportedObject* exported_object, + const ::dbus::ObjectPath& object_path); + // Releases the D-Bus interface and unexports all the methods asynchronously. + BRILLO_PRIVATE void UnexportAsync( ExportedObjectManager* object_manager, - dbus::Bus* bus, - dbus::ExportedObject* exported_object, - const dbus::ObjectPath& object_path); + ::dbus::ExportedObject* exported_object, + const ::dbus::ObjectPath& object_path, + const AsyncEventSequencer::CompletionAction& completion_callback); + // Releases the D-Bus interface and unexports all the methods synchronously. + BRILLO_PRIVATE void UnexportAndBlock(ExportedObjectManager* object_manager, + ::dbus::ExportedObject* exported_object, + const ::dbus::ObjectPath& object_path); BRILLO_PRIVATE void ClaimInterface( base::WeakPtr<ExportedObjectManager> object_manager, - const dbus::ObjectPath& object_path, + const ::dbus::ObjectPath& object_path, const ExportedPropertySet::PropertyWriter& writer, bool all_succeeded); @@ -518,8 +533,8 @@ class BRILLO_EXPORT DBusObject { // changes on those interfaces. // object_path - D-Bus object path for the object instance. DBusObject(ExportedObjectManager* object_manager, - const scoped_refptr<dbus::Bus>& bus, - const dbus::ObjectPath& object_path); + const scoped_refptr<::dbus::Bus>& bus, + const ::dbus::ObjectPath& object_path); // property_handler_setup_callback - To be called when setting up property // method handlers. Clients can register @@ -527,8 +542,8 @@ class BRILLO_EXPORT DBusObject { // (GetAll/Get/Set) by passing in this // callback. DBusObject(ExportedObjectManager* object_manager, - const scoped_refptr<dbus::Bus>& bus, - const dbus::ObjectPath& object_path, + const scoped_refptr<::dbus::Bus>& bus, + const ::dbus::ObjectPath& object_path, PropertyHandlerSetupCallback property_handler_setup_callback); virtual ~DBusObject(); @@ -551,6 +566,28 @@ class BRILLO_EXPORT DBusObject { const std::string& interface_name, const AsyncEventSequencer::CompletionAction& completion_callback); + // Exports a proxy handler for the interface |interface_name|. If the + // interface proxy does not exist yet, it will be automatically created. This + // call is synchronous and will block until all methods of the interface are + // registered and the interface is claimed. + void ExportInterfaceAndBlock(const std::string& interface_name); + + // Unexports the interface |interface_name| and unexports all method handlers. + // In some cases, one may want to export an interface even after it's removed. + // In that case, they should call this method before removing the interface + // to make sure it will start with a clean state of method handlers. + void UnexportInterfaceAsync( + const std::string& interface_name, + const AsyncEventSequencer::CompletionAction& completion_callback); + + // Unexports the interface |interface_name| and unexports all method handlers. + // In some cases, one may want to export an interface even after it's removed. + // In that case, they should call this method before removing the interface + // to make sure it will start with a clean state of method handlers. + // This call is synchronous and will block until the interface is released and + // all of its methods of are unregistered. + void UnexportInterfaceAndBlock(const std::string& interface_name); + // Registers the object instance with D-Bus. This is an asynchronous call // that will call |completion_callback| when the object and all of its // interfaces are registered. @@ -576,10 +613,10 @@ class BRILLO_EXPORT DBusObject { } // Sends a signal from the exported D-Bus object. - bool SendSignal(dbus::Signal* signal); + bool SendSignal(::dbus::Signal* signal); // Returns the reference to dbus::Bus this object is associated with. - scoped_refptr<dbus::Bus> GetBus() { return bus_; } + scoped_refptr<::dbus::Bus> GetBus() { return bus_; } private: // Add the org.freedesktop.DBus.Properties interface to the object. @@ -593,11 +630,11 @@ class BRILLO_EXPORT DBusObject { // Delegate object implementing org.freedesktop.DBus.ObjectManager interface. base::WeakPtr<ExportedObjectManager> object_manager_; // D-Bus bus object. - scoped_refptr<dbus::Bus> bus_; + scoped_refptr<::dbus::Bus> bus_; // D-Bus object path for this object. - dbus::ObjectPath object_path_; + ::dbus::ObjectPath object_path_; // D-Bus object instance once this object is successfully exported. - dbus::ExportedObject* exported_object_ = nullptr; // weak; owned by |bus_|. + ::dbus::ExportedObject* exported_object_ = nullptr; // weak; owned by |bus_|. // Sets up property method handlers. PropertyHandlerSetupCallback property_handler_setup_callback_; diff --git a/brillo/dbus/dbus_object_internal_impl.h b/brillo/dbus/dbus_object_internal_impl.h index 3c5e8d7..a521776 100644 --- a/brillo/dbus/dbus_object_internal_impl.h +++ b/brillo/dbus/dbus_object_internal_impl.h @@ -32,6 +32,7 @@ #include <memory> #include <string> #include <type_traits> +#include <utility> #include <brillo/dbus/data_serialization.h> #include <brillo/dbus/dbus_method_response.h> @@ -52,7 +53,7 @@ class DBusInterfaceMethodHandlerInterface { // Returns true if the method has been handled synchronously (whether or not // a success or error response message had been sent). - virtual void HandleMethod(dbus::MethodCall* method_call, + virtual void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) = 0; }; @@ -76,7 +77,7 @@ class SimpleDBusInterfaceMethodHandler explicit SimpleDBusInterfaceMethodHandler( const base::Callback<R(Args...)>& handler) : handler_(handler) {} - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { DBusMethodResponse<R> method_response(method_call, sender); auto invoke_callback = [this, &method_response](const Args&... args) { @@ -84,7 +85,7 @@ class SimpleDBusInterfaceMethodHandler }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); // The handler is expected a return value, don't allow output parameters. if (!DBusParamReader<false, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { @@ -110,19 +111,19 @@ class SimpleDBusInterfaceMethodHandler<void, Args...> explicit SimpleDBusInterfaceMethodHandler( const base::Callback<void(Args...)>& handler) : handler_(handler) {} - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { DBusMethodResponseBase method_response(method_call, sender); auto invoke_callback = [this, &method_response](const Args&... args) { handler_.Run(args...); auto response = method_response.CreateCustomResponse(); - dbus::MessageWriter writer(response.get()); + ::dbus::MessageWriter writer(response.get()); DBusParamWriter::AppendDBusOutParams(&writer, args...); method_response.SendRawResponse(std::move(response)); }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); if (!DBusParamReader<true, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { // Error parsing method arguments. @@ -156,7 +157,7 @@ class SimpleDBusInterfaceMethodHandlerWithError const base::Callback<bool(ErrorPtr*, Args...)>& handler) : handler_(handler) {} - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { DBusMethodResponseBase method_response(method_call, sender); auto invoke_callback = [this, &method_response](const Args&... args) { @@ -165,14 +166,14 @@ class SimpleDBusInterfaceMethodHandlerWithError method_response.ReplyWithError(error.get()); } else { auto response = method_response.CreateCustomResponse(); - dbus::MessageWriter writer(response.get()); + ::dbus::MessageWriter writer(response.get()); DBusParamWriter::AppendDBusOutParams(&writer, args...); method_response.SendRawResponse(std::move(response)); } }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); if (!DBusParamReader<true, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { // Error parsing method arguments. @@ -204,10 +205,10 @@ class SimpleDBusInterfaceMethodHandlerWithErrorAndMessage // A constructor that takes a |handler| to be called when HandleMethod() // virtual function is invoked. explicit SimpleDBusInterfaceMethodHandlerWithErrorAndMessage( - const base::Callback<bool(ErrorPtr*, dbus::Message*, Args...)>& handler) + const base::Callback<bool(ErrorPtr*, ::dbus::Message*, Args...)>& handler) : handler_(handler) {} - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { DBusMethodResponseBase method_response(method_call, sender); auto invoke_callback = @@ -217,14 +218,14 @@ class SimpleDBusInterfaceMethodHandlerWithErrorAndMessage method_response.ReplyWithError(error.get()); } else { auto response = method_response.CreateCustomResponse(); - dbus::MessageWriter writer(response.get()); + ::dbus::MessageWriter writer(response.get()); DBusParamWriter::AppendDBusOutParams(&writer, args...); method_response.SendRawResponse(std::move(response)); } }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); if (!DBusParamReader<true, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { // Error parsing method arguments. @@ -234,7 +235,7 @@ class SimpleDBusInterfaceMethodHandlerWithErrorAndMessage private: // C++ callback to be called when a DBus method is dispatched. - base::Callback<bool(ErrorPtr*, dbus::Message*, Args...)> handler_; + base::Callback<bool(ErrorPtr*, ::dbus::Message*, Args...)> handler_; DISALLOW_COPY_AND_ASSIGN(SimpleDBusInterfaceMethodHandlerWithErrorAndMessage); }; @@ -257,7 +258,7 @@ class DBusInterfaceMethodHandler : public DBusInterfaceMethodHandlerInterface { // This method forwards the call to |handler_| after extracting the required // arguments from the DBus message buffer specified in |method_call|. // The output parameters of |handler_| (if any) are sent back to the called. - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { auto invoke_callback = [this, method_call, &sender](const Args&... args) { std::unique_ptr<Response> response(new Response(method_call, sender)); @@ -265,7 +266,7 @@ class DBusInterfaceMethodHandler : public DBusInterfaceMethodHandlerInterface { }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); if (!DBusParamReader<false, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { // Error parsing method arguments. @@ -297,14 +298,14 @@ class DBusInterfaceMethodHandlerWithMessage // A constructor that takes a |handler| to be called when HandleMethod() // virtual function is invoked. explicit DBusInterfaceMethodHandlerWithMessage( - const base::Callback<void(std::unique_ptr<Response>, dbus::Message*, - Args...)>& handler) + const base::Callback< + void(std::unique_ptr<Response>, ::dbus::Message*, Args...)>& handler) : handler_(handler) {} // This method forwards the call to |handler_| after extracting the required // arguments from the DBus message buffer specified in |method_call|. // The output parameters of |handler_| (if any) are sent back to the called. - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { auto invoke_callback = [this, method_call, &sender](const Args&... args) { std::unique_ptr<Response> response(new Response(method_call, sender)); @@ -312,7 +313,7 @@ class DBusInterfaceMethodHandlerWithMessage }; ErrorPtr param_reader_error; - dbus::MessageReader reader(method_call); + ::dbus::MessageReader reader(method_call); if (!DBusParamReader<false, Args...>::Invoke( invoke_callback, &reader, ¶m_reader_error)) { // Error parsing method arguments. @@ -323,8 +324,8 @@ class DBusInterfaceMethodHandlerWithMessage private: // C++ callback to be called when a D-Bus method is dispatched. - base::Callback<void(std::unique_ptr<Response>, - dbus::Message*, Args...)> handler_; + base::Callback<void(std::unique_ptr<Response>, ::dbus::Message*, Args...)> + handler_; DISALLOW_COPY_AND_ASSIGN(DBusInterfaceMethodHandlerWithMessage); }; @@ -341,18 +342,18 @@ class RawDBusInterfaceMethodHandler public: // A constructor that takes a |handler| to be called when HandleMethod() // virtual function is invoked. - explicit RawDBusInterfaceMethodHandler( - const base::Callback<void(dbus::MethodCall*, ResponseSender)>& handler) + RawDBusInterfaceMethodHandler( + const base::Callback<void(::dbus::MethodCall*, ResponseSender)>& handler) : handler_(handler) {} - void HandleMethod(dbus::MethodCall* method_call, + void HandleMethod(::dbus::MethodCall* method_call, ResponseSender sender) override { handler_.Run(method_call, sender); } private: // C++ callback to be called when a D-Bus method is dispatched. - base::Callback<void(dbus::MethodCall*, ResponseSender)> handler_; + base::Callback<void(::dbus::MethodCall*, ResponseSender)> handler_; DISALLOW_COPY_AND_ASSIGN(RawDBusInterfaceMethodHandler); }; diff --git a/brillo/dbus/dbus_object_unittest.cc b/brillo/dbus/dbus_object_test.cc index 932a5c8..09615c8 100644 --- a/brillo/dbus/dbus_object_unittest.cc +++ b/brillo/dbus/dbus_object_test.cc @@ -17,8 +17,6 @@ using ::testing::AnyNumber; using ::testing::Return; -using ::testing::Invoke; -using ::testing::Mock; using ::testing::_; namespace brillo { @@ -335,7 +333,39 @@ TEST_F(DBusObjectTest, TestRemovedInterface) { EXPECT_EQ(DBUS_ERROR_UNKNOWN_INTERFACE, response->GetErrorName()); } -TEST_F(DBusObjectTest, TestInterfaceExportedLate) { +TEST_F(DBusObjectTest, TestUnexportInterfaceAsync) { + // Unexport the interface to be tested. It should unexport the methods on that + // interface. + EXPECT_CALL(*mock_exported_object_, + UnexportMethod(kTestInterface3, kTestMethod_NoOp, _)) + .Times(1); + EXPECT_CALL(*mock_exported_object_, + UnexportMethod(kTestInterface3, kTestMethod_WithMessage, _)) + .Times(1); + EXPECT_CALL(*mock_exported_object_, + UnexportMethod(kTestInterface3, kTestMethod_WithMessageAsync, _)) + .Times(1); + dbus_object_->UnexportInterfaceAsync(kTestInterface3, + base::Bind(&OnInterfaceExported)); +} + +TEST_F(DBusObjectTest, TestUnexportInterfaceBlocking) { + // Unexport the interface to be tested. It should unexport the methods on that + // interface. + EXPECT_CALL(*mock_exported_object_, + UnexportMethodAndBlock(kTestInterface3, kTestMethod_NoOp)) + .WillOnce(Return(true)); + EXPECT_CALL(*mock_exported_object_, + UnexportMethodAndBlock(kTestInterface3, kTestMethod_WithMessage)) + .WillOnce(Return(true)); + EXPECT_CALL( + *mock_exported_object_, + UnexportMethodAndBlock(kTestInterface3, kTestMethod_WithMessageAsync)) + .WillOnce(Return(true)); + dbus_object_->UnexportInterfaceAndBlock(kTestInterface3); +} + +TEST_F(DBusObjectTest, TestInterfaceExportedLateAsync) { // Registers a new interface late. dbus_object_->ExportInterfaceAsync(kTestInterface4, base::Bind(&OnInterfaceExported)); @@ -350,6 +380,20 @@ TEST_F(DBusObjectTest, TestInterfaceExportedLate) { EXPECT_EQ(DBUS_ERROR_UNKNOWN_METHOD, response->GetErrorName()); } +TEST_F(DBusObjectTest, TestInterfaceExportedLateBlocking) { + // Registers a new interface late. + dbus_object_->ExportInterfaceAndBlock(kTestInterface4); + + const std::string sender{":1.2345"}; + dbus::MethodCall method_call(kTestInterface4, kTestMethod_WithMessage); + method_call.SetSerial(123); + method_call.SetSender(sender); + auto response = testing::CallMethod(*dbus_object_, &method_call); + // The response should contain error UnknownMethod rather than + // UnknownInterface since the interface has been registered late. + EXPECT_EQ(DBUS_ERROR_UNKNOWN_METHOD, response->GetErrorName()); +} + TEST_F(DBusObjectTest, TooFewParams) { dbus::MethodCall method_call(kTestInterface1, kTestMethod_Add); method_call.SetSerial(123); diff --git a/brillo/dbus/dbus_object_test_helpers.h b/brillo/dbus/dbus_object_test_helpers.h index 59c4a06..4a9287f 100644 --- a/brillo/dbus/dbus_object_test_helpers.h +++ b/brillo/dbus/dbus_object_test_helpers.h @@ -12,6 +12,9 @@ #ifndef LIBBRILLO_BRILLO_DBUS_DBUS_OBJECT_TEST_HELPERS_H_ #define LIBBRILLO_BRILLO_DBUS_DBUS_OBJECT_TEST_HELPERS_H_ +#include <memory> +#include <utility> + #include <base/bind.h> #include <base/memory/weak_ptr.h> #include <brillo/dbus/dbus_method_invoker.h> @@ -25,7 +28,7 @@ namespace dbus_utils { class DBusInterfaceTestHelper final { public: static void HandleMethodCall(DBusInterface* itf, - dbus::MethodCall* method_call, + ::dbus::MethodCall* method_call, ResponseSender sender) { itf->HandleMethodCall(method_call, sender); } @@ -40,11 +43,11 @@ namespace testing { // ResponseHolder::ReceiveResponse() will not be called since we bind the // callback to the object instance via a weak pointer. struct ResponseHolder final : public base::SupportsWeakPtr<ResponseHolder> { - void ReceiveResponse(std::unique_ptr<dbus::Response> response) { + void ReceiveResponse(std::unique_ptr<::dbus::Response> response) { response_ = std::move(response); } - std::unique_ptr<dbus::Response> response_; + std::unique_ptr<::dbus::Response> response_; }; // Dispatches a D-Bus method call to the corresponding handler. @@ -53,10 +56,10 @@ struct ResponseHolder final : public base::SupportsWeakPtr<ResponseHolder> { // call sites. Returns a response from the method handler or nullptr if the // method hasn't provided the response message immediately // (i.e. it is asynchronous). -inline std::unique_ptr<dbus::Response> CallMethod( - const DBusObject& object, dbus::MethodCall* method_call) { +inline std::unique_ptr<::dbus::Response> CallMethod( + const DBusObject& object, ::dbus::MethodCall* method_call) { DBusInterface* itf = object.FindInterface(method_call->GetInterface()); - std::unique_ptr<dbus::Response> response; + std::unique_ptr<::dbus::Response> response; if (!itf) { response = CreateDBusErrorResponse( method_call, @@ -95,7 +98,7 @@ struct MethodHandlerInvoker { Params...), Args... args) { ResponseHolder response_holder; - dbus::MethodCall method_call("test.interface", "TestMethod"); + ::dbus::MethodCall method_call("test.interface", "TestMethod"); method_call.SetSerial(123); std::unique_ptr<DBusMethodResponse<RetType>> method_response{ new DBusMethodResponse<RetType>( @@ -122,7 +125,7 @@ struct MethodHandlerInvoker<void> { void(Class::*method)(std::unique_ptr<DBusMethodResponse<>>, Params...), Args... args) { ResponseHolder response_holder; - dbus::MethodCall method_call("test.interface", "TestMethod"); + ::dbus::MethodCall method_call("test.interface", "TestMethod"); method_call.SetSerial(123); std::unique_ptr<DBusMethodResponse<>> method_response{ new DBusMethodResponse<>(&method_call, diff --git a/brillo/dbus/dbus_param_reader.h b/brillo/dbus/dbus_param_reader.h index 228cfb6..f5c4541 100644 --- a/brillo/dbus/dbus_param_reader.h +++ b/brillo/dbus/dbus_param_reader.h @@ -51,9 +51,9 @@ struct DBusParamReader<allow_out_params, CurrentParam, RestOfParams...> { // method_call - D-Bus method call object we are processing. // reader - D-Bus message reader to pop the current argument value from. // args... - the callback parameters processed so far. - template<typename CallbackType, typename... Args> + template <typename CallbackType, typename... Args> static bool Invoke(const CallbackType& handler, - dbus::MessageReader* reader, + ::dbus::MessageReader* reader, ErrorPtr* error, const Args&... args) { return InvokeHelper<CurrentParam, CallbackType, Args...>( @@ -70,10 +70,10 @@ struct DBusParamReader<allow_out_params, CurrentParam, RestOfParams...> { // parameters should be sent back in the method call response message. // Overload 1: ParamType is not a pointer. - template<typename ParamType, typename CallbackType, typename... Args> + template <typename ParamType, typename CallbackType, typename... Args> static typename std::enable_if<!std::is_pointer<ParamType>::value, bool>::type InvokeHelper(const CallbackType& handler, - dbus::MessageReader* reader, + ::dbus::MessageReader* reader, ErrorPtr* error, const Args&... args) { if (!reader->HasMoreData()) { @@ -112,13 +112,14 @@ struct DBusParamReader<allow_out_params, CurrentParam, RestOfParams...> { } // Overload 2: ParamType is a pointer. - template<typename ParamType, typename CallbackType, typename... Args> + template <typename ParamType, typename CallbackType, typename... Args> static typename std::enable_if<allow_out_params && - std::is_pointer<ParamType>::value, bool>::type - InvokeHelper(const CallbackType& handler, - dbus::MessageReader* reader, - ErrorPtr* error, - const Args&... args) { + std::is_pointer<ParamType>::value, + bool>::type + InvokeHelper(const CallbackType& handler, + ::dbus::MessageReader* reader, + ErrorPtr* error, + const Args&... args) { // ParamType is a pointer. This is expected to be an output parameter. // Create storage for it and the handler will provide a value for it. using ParamValueType = typename std::remove_pointer<ParamType>::type; @@ -143,9 +144,9 @@ struct DBusParamReader<allow_out_params, CurrentParam, RestOfParams...> { // handler with all the accumulated arguments. template<bool allow_out_params> struct DBusParamReader<allow_out_params> { - template<typename CallbackType, typename... Args> + template <typename CallbackType, typename... Args> static bool Invoke(const CallbackType& handler, - dbus::MessageReader* reader, + ::dbus::MessageReader* reader, ErrorPtr* error, const Args&... args) { if (reader->HasMoreData()) { diff --git a/brillo/dbus/dbus_param_reader_unittest.cc b/brillo/dbus/dbus_param_reader_test.cc index fd9f243..abf1da3 100644 --- a/brillo/dbus/dbus_param_reader_unittest.cc +++ b/brillo/dbus/dbus_param_reader_test.cc @@ -4,6 +4,7 @@ #include <brillo/dbus/dbus_param_reader.h> +#include <memory> #include <string> #include <brillo/variant_dictionary.h> diff --git a/brillo/dbus/dbus_param_writer.h b/brillo/dbus/dbus_param_writer.h index 7c7f45e..779ea61 100644 --- a/brillo/dbus/dbus_param_writer.h +++ b/brillo/dbus/dbus_param_writer.h @@ -24,8 +24,8 @@ class DBusParamWriter final { public: // Generic writer method that takes 1 or more arguments. It recursively calls // itself (each time with one fewer arguments) until no more is left. - template<typename ParamType, typename... RestOfParams> - static void Append(dbus::MessageWriter* writer, + template <typename ParamType, typename... RestOfParams> + static void Append(::dbus::MessageWriter* writer, const ParamType& param, const RestOfParams&... rest) { // Append the current |param| to D-Bus, then call Append() with one @@ -38,13 +38,13 @@ class DBusParamWriter final { // The final overload of DBusParamWriter::Append() used when no more // parameters are remaining to be written. // Does nothing and finishes meta-recursion. - static void Append(dbus::MessageWriter* /*writer*/) {} + static void Append(::dbus::MessageWriter* /*writer*/) {} // Generic writer method that takes 1 or more arguments. It recursively calls // itself (each time with one fewer arguments) until no more is left. // Handles non-pointer parameter by just skipping over it. - template<typename ParamType, typename... RestOfParams> - static void AppendDBusOutParams(dbus::MessageWriter* writer, + template <typename ParamType, typename... RestOfParams> + static void AppendDBusOutParams(::dbus::MessageWriter* writer, const ParamType& /* param */, const RestOfParams&... rest) { // Skip the current |param| and call Append() with one fewer arguments, @@ -57,8 +57,8 @@ class DBusParamWriter final { // itself (each time with one fewer arguments) until no more is left. // Handles only a parameter of pointer type and writes the data pointed to // to the output message buffer. - template<typename ParamType, typename... RestOfParams> - static void AppendDBusOutParams(dbus::MessageWriter* writer, + template <typename ParamType, typename... RestOfParams> + static void AppendDBusOutParams(::dbus::MessageWriter* writer, ParamType* param, const RestOfParams&... rest) { // Append the current |param| to D-Bus, then call Append() with one @@ -71,7 +71,7 @@ class DBusParamWriter final { // The final overload of DBusParamWriter::AppendDBusOutParams() used when no // more parameters are remaining to be written. // Does nothing and finishes meta-recursion. - static void AppendDBusOutParams(dbus::MessageWriter* /*writer*/) {} + static void AppendDBusOutParams(::dbus::MessageWriter* /*writer*/) {} }; } // namespace dbus_utils diff --git a/brillo/dbus/dbus_param_writer_unittest.cc b/brillo/dbus/dbus_param_writer_test.cc index 6ab863a..2611ada 100644 --- a/brillo/dbus/dbus_param_writer_unittest.cc +++ b/brillo/dbus/dbus_param_writer_test.cc @@ -4,6 +4,7 @@ #include <brillo/dbus/dbus_param_writer.h> +#include <memory> #include <string> #include <brillo/any.h> diff --git a/brillo/dbus/dbus_property.h b/brillo/dbus/dbus_property.h index 01b850d..f82759e 100644 --- a/brillo/dbus/dbus_property.h +++ b/brillo/dbus/dbus_property.h @@ -5,6 +5,8 @@ #ifndef LIBBRILLO_BRILLO_DBUS_DBUS_PROPERTY_H_ #define LIBBRILLO_BRILLO_DBUS_DBUS_PROPERTY_H_ +#include <utility> + #include <brillo/dbus/data_serialization.h> #include <dbus/property.h> @@ -16,8 +18,8 @@ namespace dbus_utils { // This class is pretty much a copy of dbus::Property<T> from dbus/property.h // except that it provides the implementations for PopValueFromReader and // AppendSetValueToWriter. -template<class T> -class Property : public dbus::PropertyBase { +template <class T> +class Property : public ::dbus::PropertyBase { public: Property() = default; @@ -27,7 +29,7 @@ class Property : public dbus::PropertyBase { // Requests an updated value from the remote object incurring a // round-trip. |callback| will be called when the new value is available. // This may not be implemented by some interfaces. - void Get(dbus::PropertySet::GetCallback callback) { + void Get(::dbus::PropertySet::GetCallback callback) { property_set()->Get(this, callback); } @@ -40,9 +42,9 @@ class Property : public dbus::PropertyBase { // |callback| will be called to indicate the success or failure of the // request, however the new value may not be available depending on the // remote object. - void Set(const T& value, dbus::PropertySet::SetCallback callback) { + void Set(const T& value, ::dbus::PropertySet::SetCallback callback) { set_value_ = value; - property_set()->Set(this, callback); + property_set()->Set(this, std::move(callback)); } // Synchronous version of Set(). @@ -54,14 +56,14 @@ class Property : public dbus::PropertyBase { // Method used by PropertySet to retrieve the value from a MessageReader, // no knowledge of the contained type is required, this method returns // true if its expected type was found, false if not. - bool PopValueFromReader(dbus::MessageReader* reader) override { + bool PopValueFromReader(::dbus::MessageReader* reader) override { return PopVariantValueFromReader(reader, &value_); } // Method used by PropertySet to append the set value to a MessageWriter, // no knowledge of the contained type is required. // Implementation provided by specialization. - void AppendSetValueToWriter(dbus::MessageWriter* writer) override { + void AppendSetValueToWriter(::dbus::MessageWriter* writer) override { AppendValueToWriterAsVariant(writer, set_value_); } diff --git a/brillo/dbus/dbus_service_watcher.h b/brillo/dbus/dbus_service_watcher.h index 0031771..b747161 100644 --- a/brillo/dbus/dbus_service_watcher.h +++ b/brillo/dbus/dbus_service_watcher.h @@ -29,7 +29,7 @@ namespace dbus_utils { // cause the Bus to crash the process on destruction. class BRILLO_EXPORT DBusServiceWatcher { public: - DBusServiceWatcher(scoped_refptr<dbus::Bus> bus, + DBusServiceWatcher(scoped_refptr<::dbus::Bus> bus, const std::string& connection_name, const base::Closure& on_connection_vanish); virtual ~DBusServiceWatcher(); @@ -38,9 +38,9 @@ class BRILLO_EXPORT DBusServiceWatcher { private: void OnServiceOwnerChange(const std::string& service_owner); - scoped_refptr<dbus::Bus> bus_; + scoped_refptr<::dbus::Bus> bus_; const std::string connection_name_; - dbus::Bus::GetServiceOwnerCallback monitoring_callback_; + ::dbus::Bus::GetServiceOwnerCallback monitoring_callback_; base::Closure on_connection_vanish_; base::WeakPtrFactory<DBusServiceWatcher> weak_factory_{this}; diff --git a/brillo/dbus/dbus_signal.h b/brillo/dbus/dbus_signal.h index bda322a..d1fcced 100644 --- a/brillo/dbus/dbus_signal.h +++ b/brillo/dbus/dbus_signal.h @@ -30,7 +30,7 @@ class BRILLO_EXPORT DBusSignalBase { virtual ~DBusSignalBase() = default; protected: - bool SendSignal(dbus::Signal* signal) const; + bool SendSignal(::dbus::Signal* signal) const; std::string interface_name_; std::string signal_name_; @@ -51,9 +51,11 @@ class DBusSignal : public DBusSignalBase { ~DBusSignal() override = default; // DBusSignal<...>::Send(...) dispatches the signal with the given arguments. + // Note: This function can be called from any thread/task runner, as it'll + // eventually post the actual signal sending to the DBus thread. bool Send(const Args&... args) const { - dbus::Signal signal(interface_name_, signal_name_); - dbus::MessageWriter signal_writer(&signal); + ::dbus::Signal signal(interface_name_, signal_name_); + ::dbus::MessageWriter signal_writer(&signal); DBusParamWriter::Append(&signal_writer, args...); return SendSignal(&signal); } diff --git a/brillo/dbus/dbus_signal_handler.h b/brillo/dbus/dbus_signal_handler.h index 15cdae1..e89f867 100644 --- a/brillo/dbus/dbus_signal_handler.h +++ b/brillo/dbus/dbus_signal_handler.h @@ -7,8 +7,9 @@ #include <functional> #include <string> +#include <utility> -#include <brillo/bind_lambda.h> +#include <base/bind.h> #include <brillo/dbus/dbus_param_reader.h> #include <dbus/message.h> #include <dbus/object_proxy.h> @@ -31,39 +32,36 @@ namespace dbus_utils { // If the signal message doesn't contain correct number or types of arguments, // an error message is logged to the system log and the signal is ignored // (|signal_callback| is not invoked). -template<typename... Args> +template <typename... Args> void ConnectToSignal( - dbus::ObjectProxy* object_proxy, + ::dbus::ObjectProxy* object_proxy, const std::string& interface_name, const std::string& signal_name, base::Callback<void(Args...)> signal_callback, - dbus::ObjectProxy::OnConnectedCallback on_connected_callback) { + ::dbus::ObjectProxy::OnConnectedCallback on_connected_callback) { + // DBusParamReader::Invoke() needs a functor object, not a base::Callback. + // Wrap the callback with lambda so we can redirect the call. + auto signal_callback_wrapper = [signal_callback](const Args&... args) { + if (!signal_callback.is_null()) { + signal_callback.Run(args...); + } + }; + // Raw signal handler stub method. When called, unpacks the signal arguments // from |signal| message buffer and redirects the call to // |signal_callback_wrapper| which, in turn, would call the user-provided // |signal_callback|. - auto dbus_signal_callback = []( - const base::Callback<void(Args...)>& signal_callback, - dbus::Signal* signal) { - // DBusParamReader::Invoke() needs a functor object, not a base::Callback. - // Wrap the callback with lambda so we can redirect the call. - auto signal_callback_wrapper = [signal_callback](const Args&... args) { - if (!signal_callback.is_null()) { - signal_callback.Run(args...); - } - }; - - dbus::MessageReader reader(signal); - DBusParamReader<false, Args...>::Invoke( - signal_callback_wrapper, &reader, nullptr); + auto dbus_signal_callback = [](std::function<void(const Args&...)> callback, + ::dbus::Signal* signal) { + ::dbus::MessageReader reader(signal); + DBusParamReader<false, Args...>::Invoke(callback, &reader, nullptr); }; // Register our stub handler with D-Bus ObjectProxy. object_proxy->ConnectToSignal( - interface_name, - signal_name, - base::Bind(dbus_signal_callback, signal_callback), - on_connected_callback); + interface_name, signal_name, + base::Bind(dbus_signal_callback, signal_callback_wrapper), + std::move(on_connected_callback)); } } // namespace dbus_utils diff --git a/brillo/dbus/dbus_signal_handler_unittest.cc b/brillo/dbus/dbus_signal_handler_test.cc index e0bea10..35bad65 100644 --- a/brillo/dbus/dbus_signal_handler_unittest.cc +++ b/brillo/dbus/dbus_signal_handler_test.cc @@ -6,7 +6,7 @@ #include <string> -#include <brillo/bind_lambda.h> +#include <base/bind.h> #include <brillo/dbus/dbus_param_writer.h> #include <dbus/mock_bus.h> #include <dbus/mock_object_proxy.h> @@ -49,7 +49,8 @@ class DBusSignalHandlerTest : public testing::Test { template<typename SignalHandlerSink, typename... Args> void CallSignal(SignalHandlerSink* sink, Args... args) { dbus::ObjectProxy::SignalCallback signal_callback; - EXPECT_CALL(*mock_object_proxy_, ConnectToSignal(kInterface, kSignal, _, _)) + EXPECT_CALL(*mock_object_proxy_, + DoConnectToSignal(kInterface, kSignal, _, _)) .WillOnce(SaveArg<2>(&signal_callback)); brillo::dbus_utils::ConnectToSignal( @@ -70,7 +71,7 @@ class DBusSignalHandlerTest : public testing::Test { }; TEST_F(DBusSignalHandlerTest, ConnectToSignal) { - EXPECT_CALL(*mock_object_proxy_, ConnectToSignal(kInterface, kSignal, _, _)) + EXPECT_CALL(*mock_object_proxy_, DoConnectToSignal(kInterface, kSignal, _, _)) .Times(1); brillo::dbus_utils::ConnectToSignal( @@ -80,7 +81,7 @@ TEST_F(DBusSignalHandlerTest, ConnectToSignal) { TEST_F(DBusSignalHandlerTest, CallSignal_3Args) { class SignalHandlerSink { public: - MOCK_METHOD3(Handler, void(int, int, double)); + MOCK_METHOD(void, Handler, (int, int, double)); } sink; EXPECT_CALL(sink, Handler(10, 20, 30.5)).Times(1); @@ -91,7 +92,7 @@ TEST_F(DBusSignalHandlerTest, CallSignal_2Args) { class SignalHandlerSink { public: // Take string both by reference and by value to make sure this works too. - MOCK_METHOD2(Handler, void(const std::string&, std::string)); + MOCK_METHOD(void, Handler, (const std::string&, std::string)); } sink; EXPECT_CALL(sink, Handler(std::string{"foo"}, std::string{"bar"})).Times(1); @@ -101,7 +102,7 @@ TEST_F(DBusSignalHandlerTest, CallSignal_2Args) { TEST_F(DBusSignalHandlerTest, CallSignal_NoArgs) { class SignalHandlerSink { public: - MOCK_METHOD0(Handler, void()); + MOCK_METHOD(void, Handler, ()); } sink; EXPECT_CALL(sink, Handler()).Times(1); @@ -111,7 +112,7 @@ TEST_F(DBusSignalHandlerTest, CallSignal_NoArgs) { TEST_F(DBusSignalHandlerTest, CallSignal_Error_TooManyArgs) { class SignalHandlerSink { public: - MOCK_METHOD0(Handler, void()); + MOCK_METHOD(void, Handler, ()); } sink; // Handler() expects no args, but we send an int. @@ -122,7 +123,7 @@ TEST_F(DBusSignalHandlerTest, CallSignal_Error_TooManyArgs) { TEST_F(DBusSignalHandlerTest, CallSignal_Error_TooFewArgs) { class SignalHandlerSink { public: - MOCK_METHOD2(Handler, void(std::string, bool)); + MOCK_METHOD(void, Handler, (std::string, bool)); } sink; // Handler() expects 2 args while we send it just one. @@ -133,7 +134,7 @@ TEST_F(DBusSignalHandlerTest, CallSignal_Error_TooFewArgs) { TEST_F(DBusSignalHandlerTest, CallSignal_Error_TypeMismatchArgs) { class SignalHandlerSink { public: - MOCK_METHOD2(Handler, void(std::string, bool)); + MOCK_METHOD(void, Handler, (std::string, bool)); } sink; // Handler() expects "sb" while we send it "ii". diff --git a/brillo/dbus/exported_object_manager.cc b/brillo/dbus/exported_object_manager.cc index 61dae68..a2ae1fe 100644 --- a/brillo/dbus/exported_object_manager.cc +++ b/brillo/dbus/exported_object_manager.cc @@ -89,11 +89,11 @@ ExportedObjectManager::HandleGetManagedObjects() { // DICT<STRING,VARIANT>>> ) bus_->AssertOnOriginThread(); ExportedObjectManager::ObjectMap objects; - for (const auto path_pair : registered_objects_) { + for (const auto& path_pair : registered_objects_) { std::map<std::string, VariantDictionary>& interfaces = objects[path_pair.first]; const InterfaceProperties& interface2properties = path_pair.second; - for (const auto interface : interface2properties) { + for (const auto& interface : interface2properties) { interface.second.Run(&interfaces[interface.first]); } } diff --git a/brillo/dbus/exported_object_manager.h b/brillo/dbus/exported_object_manager.h index ea68f33..9534009 100644 --- a/brillo/dbus/exported_object_manager.h +++ b/brillo/dbus/exported_object_manager.h @@ -6,6 +6,7 @@ #define LIBBRILLO_BRILLO_DBUS_EXPORTED_OBJECT_MANAGER_H_ #include <map> +#include <memory> #include <string> #include <vector> @@ -80,12 +81,12 @@ class BRILLO_EXPORT ExportedObjectManager : public base::SupportsWeakPtr<ExportedObjectManager> { public: using ObjectMap = - std::map<dbus::ObjectPath, std::map<std::string, VariantDictionary>>; + std::map<::dbus::ObjectPath, std::map<std::string, VariantDictionary>>; using InterfaceProperties = std::map<std::string, ExportedPropertySet::PropertyWriter>; - ExportedObjectManager(scoped_refptr<dbus::Bus> bus, - const dbus::ObjectPath& path); + ExportedObjectManager(scoped_refptr<::dbus::Bus> bus, + const ::dbus::ObjectPath& path); virtual ~ExportedObjectManager() = default; // Registers methods implementing the ObjectManager interface on the object @@ -98,35 +99,28 @@ class BRILLO_EXPORT ExportedObjectManager // Trigger a signal that |path| has added an interface |interface_name| // with properties as given by |writer|. virtual void ClaimInterface( - const dbus::ObjectPath& path, + const ::dbus::ObjectPath& path, const std::string& interface_name, const ExportedPropertySet::PropertyWriter& writer); // Trigger a signal that |path| has removed an interface |interface_name|. - virtual void ReleaseInterface(const dbus::ObjectPath& path, + virtual void ReleaseInterface(const ::dbus::ObjectPath& path, const std::string& interface_name); - const scoped_refptr<dbus::Bus>& GetBus() const { return bus_; } - - // Due to D-Bus forwarding, clients may need to access the underlying - // DBusObject to handle signals/methods. - // TODO(sonnysasaka): Refactor this accessor into a stricter API once we know - // what D-Bus forwarding needs when it's completed, without exposing - // DBusObject directly. - brillo::dbus_utils::DBusObject* dbus_object() { return &dbus_object_; }; + const scoped_refptr<::dbus::Bus>& GetBus() const { return bus_; } private: BRILLO_PRIVATE ObjectMap HandleGetManagedObjects(); - scoped_refptr<dbus::Bus> bus_; + scoped_refptr<::dbus::Bus> bus_; brillo::dbus_utils::DBusObject dbus_object_; // Tracks all objects currently known to the ExportedObjectManager. - std::map<dbus::ObjectPath, InterfaceProperties> registered_objects_; + std::map<::dbus::ObjectPath, InterfaceProperties> registered_objects_; using SignalInterfacesAdded = - DBusSignal<dbus::ObjectPath, std::map<std::string, VariantDictionary>>; + DBusSignal<::dbus::ObjectPath, std::map<std::string, VariantDictionary>>; using SignalInterfacesRemoved = - DBusSignal<dbus::ObjectPath, std::vector<std::string>>; + DBusSignal<::dbus::ObjectPath, std::vector<std::string>>; std::weak_ptr<SignalInterfacesAdded> signal_itf_added_; std::weak_ptr<SignalInterfacesRemoved> signal_itf_removed_; diff --git a/brillo/dbus/exported_object_manager_unittest.cc b/brillo/dbus/exported_object_manager_test.cc index 00fe108..6837399 100644 --- a/brillo/dbus/exported_object_manager_unittest.cc +++ b/brillo/dbus/exported_object_manager_test.cc @@ -4,6 +4,8 @@ #include <brillo/dbus/exported_object_manager.h> +#include <utility> + #include <base/bind.h> #include <brillo/dbus/dbus_object_test_helpers.h> #include <brillo/dbus/utils.h> diff --git a/brillo/dbus/exported_property_set.cc b/brillo/dbus/exported_property_set.cc index 018843e..c71aab6 100644 --- a/brillo/dbus/exported_property_set.cc +++ b/brillo/dbus/exported_property_set.cc @@ -4,16 +4,15 @@ #include <brillo/dbus/exported_property_set.h> +#include <utility> + #include <base/bind.h> #include <dbus/bus.h> #include <dbus/property.h> // For kPropertyInterface -#include <brillo/dbus/async_event_sequencer.h> #include <brillo/dbus/dbus_object.h> #include <brillo/errors/error_codes.h> -using brillo::dbus_utils::AsyncEventSequencer; - namespace brillo { namespace dbus_utils { diff --git a/brillo/dbus/exported_property_set.h b/brillo/dbus/exported_property_set.h index 971e932..08d0ae4 100644 --- a/brillo/dbus/exported_property_set.h +++ b/brillo/dbus/exported_property_set.h @@ -8,6 +8,7 @@ #include <stdint.h> #include <map> +#include <memory> #include <string> #include <vector> @@ -97,7 +98,7 @@ class BRILLO_EXPORT ExportedPropertySet { public: using PropertyWriter = base::Callback<void(VariantDictionary* dict)>; - explicit ExportedPropertySet(dbus::Bus* bus); + explicit ExportedPropertySet(::dbus::Bus* bus); virtual ~ExportedPropertySet() = default; // Called to notify ExportedPropertySet that the Properties interface of the @@ -148,7 +149,7 @@ class BRILLO_EXPORT ExportedPropertySet { const std::string& property_name, const ExportedPropertyBase* exported_property); - dbus::Bus* bus_; // weak; owned by outer DBusObject containing this object. + ::dbus::Bus* bus_; // weak; owned by outer DBusObject containing this object. // This is a map from interface name -> property name -> pointer to property. std::map<std::string, std::map<std::string, ExportedPropertyBase*>> properties_; diff --git a/brillo/dbus/exported_property_set_unittest.cc b/brillo/dbus/exported_property_set_test.cc index 93aceb4..6f9dbd7 100644 --- a/brillo/dbus/exported_property_set_unittest.cc +++ b/brillo/dbus/exported_property_set_test.cc @@ -177,8 +177,7 @@ class PropertyValidatorObserver { base::Unretained(this))) {} virtual ~PropertyValidatorObserver() {} - MOCK_METHOD2_T(ValidateProperty, - bool(brillo::ErrorPtr* error, const T& value)); + MOCK_METHOD(bool, ValidateProperty, (brillo::ErrorPtr*, const T&)); const base::Callback<bool(brillo::ErrorPtr*, const T&)>& validate_property_callback() const { diff --git a/brillo/dbus/file_descriptor.h b/brillo/dbus/file_descriptor.h index f7be44f..2cf1b02 100644 --- a/brillo/dbus/file_descriptor.h +++ b/brillo/dbus/file_descriptor.h @@ -5,6 +5,8 @@ #ifndef LIBBRILLO_BRILLO_DBUS_FILE_DESCRIPTOR_H_ #define LIBBRILLO_BRILLO_DBUS_FILE_DESCRIPTOR_H_ +#include <utility> + #include <base/files/scoped_file.h> #include <base/macros.h> diff --git a/brillo/dbus/introspectable_helper.cc b/brillo/dbus/introspectable_helper.cc new file mode 100644 index 0000000..68ec78c --- /dev/null +++ b/brillo/dbus/introspectable_helper.cc @@ -0,0 +1,81 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/dbus/introspectable_helper.h> + +#include <memory> + +#include <base/bind.h> +#include <dbus/dbus-shared.h> + +namespace brillo { +namespace dbus_utils { + +using base::Bind; +using std::string; +using std::unique_ptr; + +void IntrospectableInterfaceHelper::AddInterfaceXml(string xml) { + interface_xmls.push_back(xml); +} + +void IntrospectableInterfaceHelper::RegisterWithDBusObject(DBusObject* object) { + DBusInterface* itf = object->AddOrGetInterface(DBUS_INTERFACE_INTROSPECTABLE); + + itf->AddMethodHandler("Introspect", GetHandler()); +} + +IntrospectableInterfaceHelper::IntrospectCallback +IntrospectableInterfaceHelper::GetHandler() { + return Bind( + [](const string& xml, StringResponse response) { response->Return(xml); }, + GetXmlString()); +} + +string IntrospectableInterfaceHelper::GetXmlString() { + constexpr const char header[] = + "<!DOCTYPE node PUBLIC " + "\"-//freedesktop//DTD D-BUS Object Introspection 1.0//EN\"\n" + "\"http://www.freedesktop.org/standards/dbus/1.0/introspect.dtd\">\n" + "\n" + "<node>\n" + " <interface name=\"org.freedesktop.DBus.Introspectable\">\n" + " <method name=\"Introspect\">\n" + " <arg name=\"data\" direction=\"out\" type=\"s\"/>\n" + " </method>\n" + " </interface>\n" + " <interface name=\"org.freedesktop.DBus.Properties\">\n" + " <method name=\"Get\">\n" + " <arg name=\"interface\" direction=\"in\" type=\"s\"/>\n" + " <arg name=\"propname\" direction=\"in\" type=\"s\"/>\n" + " <arg name=\"value\" direction=\"out\" type=\"v\"/>\n" + " </method>\n" + " <method name=\"Set\">\n" + " <arg name=\"interface\" direction=\"in\" type=\"s\"/>\n" + " <arg name=\"propname\" direction=\"in\" type=\"s\"/>\n" + " <arg name=\"value\" direction=\"in\" type=\"v\"/>\n" + " </method>\n" + " <method name=\"GetAll\">\n" + " <arg name=\"interface\" direction=\"in\" type=\"s\"/>\n" + " <arg name=\"props\" direction=\"out\" type=\"a{sv}\"/>\n" + " </method>\n" + " </interface>\n"; + constexpr const char footer[] = "</node>\n"; + + size_t result_len = strlen(header) + strlen(footer); + for (const string& xml : interface_xmls) { + result_len += xml.size(); + } + + string result = header; + result.reserve(result_len + 1); // +1 for null terminator + for (const string& xml : interface_xmls) { + result.append(xml); + } + result.append(footer); + return result; +} + +} // namespace dbus_utils +} // namespace brillo diff --git a/brillo/dbus/introspectable_helper.h b/brillo/dbus/introspectable_helper.h new file mode 100644 index 0000000..e1a398f --- /dev/null +++ b/brillo/dbus/introspectable_helper.h @@ -0,0 +1,68 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_DBUS_INTROSPECTABLE_HELPER_H_ +#define LIBBRILLO_BRILLO_DBUS_INTROSPECTABLE_HELPER_H_ + +#include <memory> +#include <string> +#include <vector> + +#include <brillo/brillo_export.h> +#include <brillo/dbus/dbus_method_response.h> +#include <brillo/dbus/dbus_object.h> + +namespace brillo { +namespace dbus_utils { + +// Note that brillo/dbus/dbus_object.h include files that include this file, so +// we'll need this forward declaration. +// class DBusObject; + +// This is a helper class that is used for creating the DBus Introspectable +// Interface. Each of the interfaces that is exported under a DBus Object will +// add its dbus interface introspection XML to this class, and then the user of +// this class will call RegisterWithDBusObject on the DBus object. Then this +// class can be freed. Note that this class is usually used in conjunction with +// the chromeos-dbus-bindings tool. Simply pass the string returned by +// GetIntrospectionXML() of the generated adaptor. Usage example: +// { +// IntrospectableInterfaceHelper helper; +// helper.AddInterfaceXML("<interface...> ...</interface>"); +// helper.AddInterfaceXML("<interface...> ...</interface>"); +// helper.AddInterfaceXML(XXXAdaptor::GetIntrospect()); +// helper.RegisterWithDBusObject(object); +// } +class BRILLO_EXPORT IntrospectableInterfaceHelper { + public: + IntrospectableInterfaceHelper() = default; + + // Add the Introspection XML for an interface to this class. The |xml| string + // should contain an interface XML tag and its content. + void AddInterfaceXml(std::string xml); + + // Register the Introspectable Interface with a DBus object. Note that this + // class can be freed after registering with DBus object. + void RegisterWithDBusObject(DBusObject* object); + + private: + // Internal alias for convenience. + using StringResponse = std::unique_ptr<DBusMethodResponse<std::string>>; + using IntrospectCallback = base::Callback<void(StringResponse)>; + + // Create the method handler for Introspect method call. + IntrospectCallback GetHandler(); + + // Get the complete introspection XML. + std::string GetXmlString(); + + // Stores the list of introspection XMLs for each of the interfaces that was + // added to this class. + std::vector<std::string> interface_xmls; +}; + +} // namespace dbus_utils +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_DBUS_INTROSPECTABLE_HELPER_H_ diff --git a/brillo/dbus/mock_dbus_object.h b/brillo/dbus/mock_dbus_object.h index 82e2fc7..d65f9ab 100644 --- a/brillo/dbus/mock_dbus_object.h +++ b/brillo/dbus/mock_dbus_object.h @@ -17,13 +17,15 @@ namespace dbus_utils { class MockDBusObject : public DBusObject { public: MockDBusObject(ExportedObjectManager* object_manager, - const scoped_refptr<dbus::Bus>& bus, - const dbus::ObjectPath& object_path) + const scoped_refptr<::dbus::Bus>& bus, + const ::dbus::ObjectPath& object_path) : DBusObject(object_manager, bus, object_path) {} ~MockDBusObject() override = default; - MOCK_METHOD1(RegisterAsync, - void(const AsyncEventSequencer::CompletionAction&)); + MOCK_METHOD(void, + RegisterAsync, + (const AsyncEventSequencer::CompletionAction&), + (override)); }; // class MockDBusObject } // namespace dbus_utils diff --git a/brillo/dbus/mock_exported_object_manager.h b/brillo/dbus/mock_exported_object_manager.h index d8abc0a..02bb073 100644 --- a/brillo/dbus/mock_exported_object_manager.h +++ b/brillo/dbus/mock_exported_object_manager.h @@ -24,15 +24,17 @@ class MockExportedObjectManager : public ExportedObjectManager { using ExportedObjectManager::ExportedObjectManager; ~MockExportedObjectManager() override = default; - MOCK_METHOD1(RegisterAsync, - void(const CompletionAction& completion_callback)); - MOCK_METHOD3(ClaimInterface, - void(const dbus::ObjectPath& path, - const std::string& interface_name, - const ExportedPropertySet::PropertyWriter& writer)); - MOCK_METHOD2(ReleaseInterface, - void(const dbus::ObjectPath& path, - const std::string& interface_name)); + MOCK_METHOD(void, RegisterAsync, (const CompletionAction&), (override)); + MOCK_METHOD(void, + ClaimInterface, + (const ::dbus::ObjectPath&, + const std::string&, + const ExportedPropertySet::PropertyWriter&), + (override)); + MOCK_METHOD(void, + ReleaseInterface, + (const ::dbus::ObjectPath&, const std::string&), + (override)); }; } // namespace dbus_utils diff --git a/brillo/dbus/test.proto b/brillo/dbus/test.proto index 84607a3..709bf71 100644 --- a/brillo/dbus/test.proto +++ b/brillo/dbus/test.proto @@ -1,3 +1,9 @@ +// Copyright 2015 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +syntax = "proto2"; + option optimize_for = LITE_RUNTIME; package dbus_utils_test; diff --git a/brillo/dbus/utils.h b/brillo/dbus/utils.h index a548756..163849e 100644 --- a/brillo/dbus/utils.h +++ b/brillo/dbus/utils.h @@ -18,8 +18,8 @@ namespace brillo { namespace dbus_utils { // A helper function to create a D-Bus error response object as unique_ptr<>. -BRILLO_EXPORT std::unique_ptr<dbus::Response> CreateDBusErrorResponse( - dbus::MethodCall* method_call, +BRILLO_EXPORT std::unique_ptr<::dbus::Response> CreateDBusErrorResponse( + ::dbus::MethodCall* method_call, const std::string& error_name, const std::string& error_message); @@ -28,9 +28,8 @@ BRILLO_EXPORT std::unique_ptr<dbus::Response> CreateDBusErrorResponse( // and message are directly translated to D-Bus error code and message. // Any inner errors are formatted as "domain/code:message" string and appended // to the D-Bus error message, delimited by semi-colons. -BRILLO_EXPORT std::unique_ptr<dbus::Response> GetDBusError( - dbus::MethodCall* method_call, - const brillo::Error* error); +BRILLO_EXPORT std::unique_ptr<::dbus::Response> GetDBusError( + ::dbus::MethodCall* method_call, const brillo::Error* error); // AddDBusError() is the opposite of GetDBusError(). It de-serializes the Error // object received over D-Bus. diff --git a/brillo/enum_flags.h b/brillo/enum_flags.h index 9630dd0..227cafd 100644 --- a/brillo/enum_flags.h +++ b/brillo/enum_flags.h @@ -57,7 +57,8 @@ template <typename T, typename = void> struct IsFlagEnum : std::false_type {}; template <typename T> -struct IsFlagEnum<T, Void<typename FlagEnumTraits<T>::EnumFlagType>> : std::true_type {}; +struct IsFlagEnum<T, Void<typename FlagEnumTraits<T>::EnumFlagType>> + : std::true_type {}; } // namespace enum_details @@ -68,7 +69,8 @@ struct IsFlagEnum<T, Void<typename FlagEnumTraits<T>::EnumFlagType>> : std::true template <typename T> constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type operator~(const T& l) { - return static_cast<T>( ~static_cast<typename std::underlying_type<T>::type>(l)); + return static_cast<T>( + ~static_cast<typename std::underlying_type<T>::type>(l)); } // T operator|(T&, T&) @@ -91,37 +93,37 @@ operator&(const T& l, const T& r) { // T operator^(T&, T&) template <typename T> -constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type operator^( - const T& l, const T& r) { +constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type +operator^(const T& l, const T& r) { return static_cast<T>(static_cast<typename std::underlying_type<T>::type>(l) ^ static_cast<typename std::underlying_type<T>::type>(r)); -}; +} // T operator|=(T&, T&) template <typename T> -constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type operator|=( - T& l, const T& r) { +constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type +operator|=(T& l, const T& r) { return l = static_cast<T>( static_cast<typename std::underlying_type<T>::type>(l) | static_cast<typename std::underlying_type<T>::type>(r)); -}; +} // T operator&=(T&, T&) template <typename T> -constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type operator&=( - T& l, const T& r) { +constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type +operator&=(T& l, const T& r) { return l = static_cast<T>( static_cast<typename std::underlying_type<T>::type>(l) & static_cast<typename std::underlying_type<T>::type>(r)); -}; +} // T operator^=(T&, T&) template <typename T> -constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type operator^=( - T& l, const T& r) { +constexpr typename std::enable_if<enum_details::IsFlagEnum<T>::value, T>::type +operator^=(T& l, const T& r) { return l = static_cast<T>( static_cast<typename std::underlying_type<T>::type>(l) ^ static_cast<typename std::underlying_type<T>::type>(r)); -}; +} #endif // LIBBRILLO_BRILLO_ENUM_FLAGS_H_ diff --git a/brillo/enum_flags_unittest.cc b/brillo/enum_flags_test.cc index e57b4ad..e57b4ad 100644 --- a/brillo/enum_flags_unittest.cc +++ b/brillo/enum_flags_test.cc diff --git a/brillo/errors/error.cc b/brillo/errors/error.cc index f229bd7..ccae1fa 100644 --- a/brillo/errors/error.cc +++ b/brillo/errors/error.cc @@ -4,6 +4,8 @@ #include <brillo/errors/error.h> +#include <utility> + #include <base/logging.h> #include <base/strings/stringprintf.h> @@ -19,16 +21,11 @@ inline void LogError(const base::Location& location, // the current error location with the location passed in to the Error object. // This way the log will contain the actual location of the error, and not // as if it always comes from brillo/errors/error.cc(22). - if (location.function_name() == nullptr) { - logging::LogMessage(location.file_name(), location.line_number(), - logging::LOG_ERROR) - .stream() - << "Domain=" << domain << ", Code=" << code << ", Message=" << message; - return; - } - logging::LogMessage( - location.file_name(), location.line_number(), logging::LOG_ERROR).stream() - << location.function_name() << "(...): " + logging::LogMessage(location.file_name(), location.line_number(), + logging::LOG_ERROR) + .stream() + << (location.function_name() ? location.function_name() : "unknown") + << "(...): " << "Domain=" << domain << ", Code=" << code << ", Message=" << message; } } // anonymous namespace diff --git a/brillo/errors/error.h b/brillo/errors/error.h index d08f0e7..1a6a91e 100644 --- a/brillo/errors/error.h +++ b/brillo/errors/error.h @@ -8,8 +8,8 @@ #include <memory> #include <string> -#include <base/macros.h> #include <base/location.h> +#include <base/macros.h> #include <brillo/brillo_export.h> namespace brillo { @@ -110,6 +110,7 @@ class BRILLO_EXPORT Error { // Human-readable error message. std::string message_; // Error origin in the source code. + // TODO(crbug.com/980935): Consider dropping this. base::Location location_; // Pointer to inner error, if any. This forms a chain of errors. ErrorPtr inner_error_; diff --git a/brillo/errors/error_codes.h b/brillo/errors/error_codes.h index 4f1bc09..664fb03 100644 --- a/brillo/errors/error_codes.h +++ b/brillo/errors/error_codes.h @@ -7,6 +7,7 @@ #include <string> +#include <base/location.h> #include <brillo/brillo_export.h> #include <brillo/errors/error.h> diff --git a/brillo/errors/error_codes_unittest.cc b/brillo/errors/error_codes_test.cc index 2baa28f..2baa28f 100644 --- a/brillo/errors/error_codes_unittest.cc +++ b/brillo/errors/error_codes_test.cc diff --git a/brillo/errors/error_unittest.cc b/brillo/errors/error_test.cc index 93f4372..7dd011e 100644 --- a/brillo/errors/error_unittest.cc +++ b/brillo/errors/error_test.cc @@ -4,6 +4,9 @@ #include <brillo/errors/error.h> +#include <utility> + +#include <base/location.h> #include <gtest/gtest.h> using brillo::Error; @@ -12,9 +15,9 @@ namespace { brillo::ErrorPtr GenerateNetworkError() { base::Location loc("GenerateNetworkError", - "error_unittest.cc", - 15, - ::base::GetProgramCounter()); + "error_test.cc", + 15, + ::base::GetProgramCounter()); return Error::Create(loc, "network", "not_found", "Resource not found"); } @@ -31,7 +34,7 @@ TEST(Error, Single) { EXPECT_EQ("not_found", err->GetCode()); EXPECT_EQ("Resource not found", err->GetMessage()); EXPECT_EQ("GenerateNetworkError", err->GetLocation().function_name()); - EXPECT_EQ("error_unittest.cc", err->GetLocation().file_name()); + EXPECT_EQ("error_test.cc", err->GetLocation().file_name()); EXPECT_EQ(15, err->GetLocation().line_number()); EXPECT_EQ(nullptr, err->GetInnerError()); EXPECT_TRUE(err->HasDomain("network")); @@ -73,7 +76,8 @@ TEST(Error, Clone) { EXPECT_EQ(error1->GetMessage(), error2->GetMessage()); EXPECT_EQ(error1->GetLocation().function_name(), error2->GetLocation().function_name()); - EXPECT_EQ(error1->GetLocation().file_name(), error2->GetLocation().file_name()); + EXPECT_EQ(error1->GetLocation().file_name(), + error2->GetLocation().file_name()); EXPECT_EQ(error1->GetLocation().line_number(), error2->GetLocation().line_number()); error1 = error1->GetInnerError(); diff --git a/brillo/file_utils.cc b/brillo/file_utils.cc index 8faa1b7..3661551 100644 --- a/brillo/file_utils.cc +++ b/brillo/file_utils.cc @@ -7,13 +7,19 @@ #include <fcntl.h> #include <unistd.h> +#include <limits> +#include <utility> +#include <vector> + +#include <base/files/file_enumerator.h> #include <base/files/file_path.h> #include <base/files/file_util.h> -#include <base/files/scoped_file.h> #include <base/logging.h> #include <base/posix/eintr_wrapper.h> #include <base/rand_util.h> +#include <base/stl_util.h> #include <base/strings/string_number_conversions.h> +#include <base/strings/stringprintf.h> #include <base/time/time.h> namespace brillo { @@ -25,7 +31,8 @@ constexpr const base::TimeDelta kLongSync = base::TimeDelta::FromSeconds(10); enum { kPermissions600 = S_IRUSR | S_IWUSR, - kPermissions777 = S_IRWXU | S_IRWXG | S_IRWXO + kPermissions777 = S_IRWXU | S_IRWXG | S_IRWXO, + kPermissions755 = S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH }; // Verify that base file permission enums are compatible with S_Ixxx. If these @@ -135,7 +142,7 @@ bool TouchFileInternal(const base::FilePath& path, std::string GetRandomSuffix() { const int kBufferSize = 6; unsigned char buffer[kBufferSize]; - base::RandBytes(buffer, arraysize(buffer)); + base::RandBytes(buffer, base::size(buffer)); std::string suffix; for (int i = 0; i < kBufferSize; ++i) { int random_value = buffer[i] % (2 * 26 + 10); @@ -150,6 +157,80 @@ std::string GetRandomSuffix() { return suffix; } +base::ScopedFD OpenPathComponentInternal(int parent_fd, + const std::string& file, + int flags, + mode_t mode) { + DCHECK(file == "/" || file.find("/") == std::string::npos); + base::ScopedFD fd; + + // O_NONBLOCK is used to avoid hanging on edge cases (e.g. a serial port with + // flow control, or a FIFO without a writer). + if (parent_fd >= 0 || parent_fd == AT_FDCWD) { + fd.reset(HANDLE_EINTR(openat(parent_fd, file.c_str(), + flags | O_NONBLOCK | O_NOFOLLOW | O_CLOEXEC, + mode))); + } else if (file == "/") { + fd.reset(HANDLE_EINTR(open( + file.c_str(), + flags | O_RDONLY | O_DIRECTORY | O_NONBLOCK | O_NOFOLLOW | O_CLOEXEC, + mode))); + } + + if (!fd.is_valid()) { + // open(2) fails with ELOOP when the last component of the |path| is a + // symlink. It fails with ENXIO when |path| is a FIFO and |flags| is for + // writing because of the O_NONBLOCK flag added above. + if (errno == ELOOP || errno == ENXIO) { + PLOG(WARNING) << "Failed to open " << file << " safely."; + } else { + PLOG(WARNING) << "Failed to open " << file << "."; + } + return base::ScopedFD(); + } + + // Remove the O_NONBLOCK flag unless the original |flags| have it. + if ((flags & O_NONBLOCK) == 0) { + flags = fcntl(fd.get(), F_GETFL); + if (flags == -1) { + PLOG(ERROR) << "Failed to get fd flags for " << file; + return base::ScopedFD(); + } + if (fcntl(fd.get(), F_SETFL, flags & ~O_NONBLOCK)) { + PLOG(ERROR) << "Failed to set fd flags for " << file; + return base::ScopedFD(); + } + } + + return fd; +} + +base::ScopedFD OpenSafelyInternal(int parent_fd, + const base::FilePath& path, + int flags, + mode_t mode) { + std::vector<std::string> components; + path.GetComponents(&components); + + auto itr = components.begin(); + if (itr == components.end()) { + LOG(ERROR) << "A path is required."; + return base::ScopedFD(); // This is an invalid fd. + } + + base::ScopedFD child_fd; + int parent_flags = flags | O_NONBLOCK | O_RDONLY | O_DIRECTORY | O_PATH; + for (; itr + 1 != components.end(); ++itr) { + child_fd = OpenPathComponentInternal(parent_fd, *itr, parent_flags, 0); + if (!child_fd.is_valid()) { + return base::ScopedFD(); + } + parent_fd = child_fd.get(); + } + + return OpenPathComponentInternal(parent_fd, *itr, flags, mode); +} + } // namespace bool TouchFile(const base::FilePath& path, @@ -184,9 +265,129 @@ bool TouchFile(const base::FilePath& path) { return TouchFile(path, kPermissions600, geteuid(), getegid()); } -bool WriteBlobToFile(const base::FilePath& path, const Blob& blob) { - return WriteToFile(path, reinterpret_cast<const char*>(blob.data()), - blob.size()); +base::ScopedFD OpenSafely(const base::FilePath& path, int flags, mode_t mode) { + if (!path.IsAbsolute()) { + LOG(ERROR) << "An absolute path is required."; + return base::ScopedFD(); // This is an invalid fd. + } + + base::ScopedFD fd(OpenSafelyInternal(-1, path, flags, mode)); + if (!fd.is_valid()) + return base::ScopedFD(); + + // Ensure the opened file is a regular file or directory. + struct stat st; + if (fstat(fd.get(), &st) < 0) { + PLOG(ERROR) << "Failed to fstat " << path.value(); + return base::ScopedFD(); + } + + // This detects a FIFO opened for reading, for example. + if (flags & O_DIRECTORY) { + if (!S_ISDIR(st.st_mode)) { + LOG(ERROR) << path.value() << " is not a directory: " << st.st_mode; + return base::ScopedFD(); + } + } else if (!S_ISREG(st.st_mode) && !S_ISDIR(st.st_mode)) { + LOG(ERROR) << path.value() + << " is not a regular file or directory: " << st.st_mode; + return base::ScopedFD(); + } + + return fd; +} + +base::ScopedFD OpenAtSafely(int parent_fd, + const base::FilePath& path, + int flags, + mode_t mode) { + base::ScopedFD fd(OpenSafelyInternal(parent_fd, path, flags, mode)); + if (!fd.is_valid()) + return base::ScopedFD(); + + // Ensure the opened file is a regular file or directory. + struct stat st; + if (fstat(fd.get(), &st) < 0) { + PLOG(ERROR) << "Failed to fstat " << path.value(); + return base::ScopedFD(); + } + + // This detects a FIFO opened for reading, for example. + if (flags & O_DIRECTORY) { + if (!S_ISDIR(st.st_mode)) { + LOG(ERROR) << path.value() << " is not a directory: " << st.st_mode; + return base::ScopedFD(); + } + } else if (!S_ISREG(st.st_mode)) { + LOG(ERROR) << path.value() << " is not a regular file: " << st.st_mode; + return base::ScopedFD(); + } + + return fd; +} + +base::ScopedFD OpenFifoSafely(const base::FilePath& path, + int flags, + mode_t mode) { + if (!path.IsAbsolute()) { + LOG(ERROR) << "An absolute path is required."; + return base::ScopedFD(); // This is an invalid fd. + } + + base::ScopedFD fd(OpenSafelyInternal(-1, path, flags, mode)); + if (!fd.is_valid()) + return base::ScopedFD(); + + // Ensure the opened file is a FIFO. + struct stat st; + if (fstat(fd.get(), &st) < 0) { + PLOG(ERROR) << "Failed to fstat " << path.value(); + return base::ScopedFD(); + } + + if (!S_ISFIFO(st.st_mode)) { + LOG(ERROR) << path.value() << " is not a FIFO: " << st.st_mode; + return base::ScopedFD(); + } + + return fd; +} + +base::ScopedFD MkdirRecursively(const base::FilePath& full_path, mode_t mode) { + std::vector<std::string> components; + full_path.GetComponents(&components); + + auto itr = components.begin(); + if (!full_path.IsAbsolute() || itr == components.end()) { + LOG(ERROR) << "An absolute path is required."; + return base::ScopedFD(); // This is an invalid fd. + } + + base::ScopedFD parent_fd; + int parent_flags = O_NONBLOCK | O_RDONLY | O_DIRECTORY | O_PATH; + while (itr + 1 != components.end()) { + base::ScopedFD child( + OpenPathComponentInternal(parent_fd.get(), *itr, parent_flags, 0)); + if (!child.is_valid()) { + return base::ScopedFD(); + } + parent_fd = std::move(child); + + ++itr; + + // Try to create the directory. Note that Chromium's MkdirRecursively() uses + // 0700, but we use 0755. + if (mkdirat(parent_fd.get(), itr->c_str(), mode) != 0) { + if (errno != EEXIST) { + PLOG(ERROR) << "Failed to mkdirat " << *itr + << ": full_path=" << full_path.value(); + return base::ScopedFD(); + } + } + } + + return OpenPathComponentInternal(parent_fd.get(), *itr, + O_RDONLY | O_DIRECTORY, 0); } bool WriteStringToFile(const base::FilePath& path, const std::string& data) { @@ -306,11 +507,20 @@ bool WriteToFileAtomic(const base::FilePath& path, return true; } -bool WriteBlobToFileAtomic(const base::FilePath& path, - const Blob& blob, - mode_t mode) { - return WriteToFileAtomic(path, reinterpret_cast<const char*>(blob.data()), - blob.size(), mode); +int64_t ComputeDirectoryDiskUsage(const base::FilePath& root_path) { + constexpr size_t S_BLKSIZE = 512; + int64_t running_blocks = 0; + base::FileEnumerator file_iter(root_path, true, + base::FileEnumerator::FILES | + base::FileEnumerator::DIRECTORIES | + base::FileEnumerator::SHOW_SYM_LINKS); + while (!file_iter.Next().empty()) { + // st_blocks in struct stat is the number of S_BLKSIZE (512) bytes sized + // blocks occupied by this file. + running_blocks += file_iter.GetInfo().stat().st_blocks; + } + // Each block is S_BLKSIZE (512) bytes so *S_BLKSIZE. + return running_blocks * S_BLKSIZE; } } // namespace brillo diff --git a/brillo/file_utils.h b/brillo/file_utils.h index 663d640..f328165 100644 --- a/brillo/file_utils.h +++ b/brillo/file_utils.h @@ -7,7 +7,10 @@ #include <sys/types.h> +#include <string> + #include <base/files/file_path.h> +#include <base/files/scoped_file.h> #include <brillo/brillo_export.h> #include <brillo/secure_blob.h> @@ -31,6 +34,62 @@ BRILLO_EXPORT bool TouchFile(const base::FilePath& path, // bit set. BRILLO_EXPORT bool TouchFile(const base::FilePath& path); +// Opens the absolute |path| to a regular file or directory ensuring that none +// of the path components are symbolic links and returns a FD. If |path| is +// relative, or contains any symbolic links, or points to a non-regular file or +// directory, an invalid FD is returned instead. |mode| is ignored unless +// |flags| has either O_CREAT or O_TMPFILE. Note that O_CLOEXEC is set so the +// file descriptor will not be inherited across exec calls. +// +// Parameters +// path - An absolute path of the file to open +// flags - Flags to pass to open. +// mode - Mode to pass to open. +BRILLO_EXPORT base::ScopedFD OpenSafely(const base::FilePath& path, + int flags, + mode_t mode); + +// Opens the |path| relative to the |parent_fd| to a regular file or directory +// ensuring that none of the path components are symbolic links and returns a +// FD. If |path| contains any symbolic links, or points to a non-regular file or +// directory, an invalid FD is returned instead. |mode| is ignored unless +// |flags| has either O_CREAT or O_TMPFILE. Note that O_CLOEXEC is set so the +// file descriptor will not be inherited across exec calls. +// +// Parameters +// parent_fd - The file descriptor of the parent directory +// path - An absolute path of the file to open +// flags - Flags to pass to open. +// mode - Mode to pass to open. +BRILLO_EXPORT base::ScopedFD OpenAtSafely(int parent_fd, + const base::FilePath& path, + int flags, + mode_t mode); + +// Opens the absolute |path| to a FIFO ensuring that none of the path components +// are symbolic links and returns a FD. If |path| is relative, or contains any +// symbolic links, or points to a non-regular file or directory, an invalid FD +// is returned instead. |mode| is ignored unless |flags| has either O_CREAT or +// O_TMPFILE. +// +// Parameters +// path - An absolute path of the file to open +// flags - Flags to pass to open. +// mode - Mode to pass to open. +BRILLO_EXPORT base::ScopedFD OpenFifoSafely(const base::FilePath& path, + int flags, + mode_t mode); + +// Iterates through the path components and creates any missing ones. Guarantees +// the ancestor paths are not symlinks. This function returns an invalid FD on +// failure. Newly created directories will have |mode| permissions. The returned +// file descriptor was opened with both O_RDONLY and O_CLOEXEC. +// +// Parameters +// full_path - An absolute path of the directory to create and open. +BRILLO_EXPORT base::ScopedFD MkdirRecursively(const base::FilePath& full_path, + mode_t mode); + // Writes the entirety of the given data to |path| with 0640 permissions // (modulo umask). If missing, parent (and parent of parent etc.) directories // are created with 0700 permissions (modulo umask). Returns true on success. @@ -39,13 +98,16 @@ BRILLO_EXPORT bool TouchFile(const base::FilePath& path); // path - Path of the file to write // blob/data - blob/string/array to populate from // (size - array size) -BRILLO_EXPORT bool WriteBlobToFile(const base::FilePath& path, - const Blob& blob); BRILLO_EXPORT bool WriteStringToFile(const base::FilePath& path, const std::string& data); BRILLO_EXPORT bool WriteToFile(const base::FilePath& path, const char* data, size_t size); +template <class T> +BRILLO_EXPORT bool WriteBlobToFile(const base::FilePath& path, const T& blob) { + return WriteToFile(path, reinterpret_cast<const char*>(blob.data()), + blob.size()); +} // Calls fdatasync() on file if data_sync is true or fsync() on directory or // file when data_sync is false. Returns true on success. @@ -70,13 +132,59 @@ BRILLO_EXPORT bool SyncFileOrDirectory(const base::FilePath& path, // blob/data - blob/array to populate from // (size - array size) // mode - File permission bit-pattern, eg. 0644 for rw-r--r-- -BRILLO_EXPORT bool WriteBlobToFileAtomic(const base::FilePath& path, - const Blob& blob, - mode_t mode); BRILLO_EXPORT bool WriteToFileAtomic(const base::FilePath& path, const char* data, size_t size, mode_t mode); +template <class T> +BRILLO_EXPORT bool WriteBlobToFileAtomic(const base::FilePath& path, + const T& blob, + mode_t mode) { + return WriteToFileAtomic(path, reinterpret_cast<const char*>(blob.data()), + blob.size(), mode); +} + +// ComputeDirectoryDiskUsage() is similar to base::ComputeDirectorySize() in +// libbase, but it returns the actual disk usage instead of the apparent size. +// In another word, ComputeDirectoryDiskUsage() behaves like "du -s +// --apparent-size", and ComputeDirectorySize() behaves like "du -s". The +// primary difference is that sparse file and files on filesystem with +// transparent compression will report smaller file size than +// ComputeDirectorySize(). Returns the total used bytes. +// The following behaviours of this function is guaranteed and is verified by +// unit tests: +// - This function recursively processes directory down the tree, so disk space +// used by files in all the subdirectories are counted. +// - Symbolic links will not be followed (the size of link itself is counted, +// the target is not) +// - Hidden files are counted as well. +// The following behaviours are not guaranteed, and it is recommended to avoid +// them in the field. Their current behaviour is provided for reference only: +// - This function doesn't care about filesystem boundaries, so it'll cross +// filesystem boundary to count file size if there's one in the specified +// directory. +// - Hard links will be treated like normal files, so they could be +// over-reported. +// - Directories that the current user doesn't have permission to list/stat will +// be ignored, and an error will be logged but the returned result could be +// under-reported without error in the returned value. +// - Deduplication (should the filesystem support it) is ignored, and the result +// could be over-reported. +// - Doesn't check if |root_path| exists, a non-existent directory will results +// in 0 bytes without any error. +// - There are no limit on the depth of file system tree, the program will crash +// if it run out of memory to hold the entire depth of file system tree. +// - If the directory is modified during this function call, there's no +// guarantee on if the function will count the updated or original file system +// state. The function could choose to count the updated state for one file and +// original state for another file. +// - Non-POSIX system is not supported. +// - Disk space used by directory (and its subdirectories) itself is counted. +// +// Parameters +// root_path - The directory to compute the size for +BRILLO_EXPORT int64_t +ComputeDirectoryDiskUsage(const base::FilePath& root_path); } // namespace brillo diff --git a/brillo/file_utils_test.cc b/brillo/file_utils_test.cc new file mode 100644 index 0000000..3407cd1 --- /dev/null +++ b/brillo/file_utils_test.cc @@ -0,0 +1,508 @@ +// Copyright 2014 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/file_utils.h" + +#include <fcntl.h> +#include <sys/stat.h> +#include <unistd.h> + +#include <string> + +#include <base/files/file_util.h> +#include <base/files/scoped_temp_dir.h> +#include <base/rand_util.h> +#include <base/stl_util.h> +#include <base/strings/string_number_conversions.h> +#include <gtest/gtest.h> + +namespace brillo { + +namespace { + +constexpr int kPermissions600 = + base::FILE_PERMISSION_READ_BY_USER | base::FILE_PERMISSION_WRITE_BY_USER; +constexpr int kPermissions700 = base::FILE_PERMISSION_USER_MASK; +constexpr int kPermissions777 = base::FILE_PERMISSION_MASK; +constexpr int kPermissions755 = S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH; + +std::string GetRandomSuffix() { + const int kBufferSize = 6; + unsigned char buffer[kBufferSize]; + base::RandBytes(buffer, base::size(buffer)); + return base::HexEncode(buffer, base::size(buffer)); +} + +bool IsNonBlockingFD(int fd) { + return fcntl(fd, F_GETFL) & O_NONBLOCK; +} + +} // namespace + +class FileUtilsTest : public testing::Test { + public: + FileUtilsTest() { + CHECK(temp_dir_.CreateUniqueTempDir()); + file_path_ = temp_dir_.GetPath().Append("test.temp"); + } + + protected: + base::FilePath file_path_; + base::ScopedTempDir temp_dir_; + + // Writes |contents| to |file_path_|. Pulled into a separate function just + // to improve readability of tests. + void WriteFile(const std::string& contents) { + EXPECT_EQ(contents.length(), + base::WriteFile(file_path_, contents.c_str(), contents.length())); + } + + // Verifies that the file at |file_path_| exists and contains |contents|. + void ExpectFileContains(const std::string& contents) { + EXPECT_TRUE(base::PathExists(file_path_)); + std::string new_contents; + EXPECT_TRUE(base::ReadFileToString(file_path_, &new_contents)); + EXPECT_EQ(contents, new_contents); + } + + // Verifies that the file at |file_path_| has |permissions|. + void ExpectFilePermissions(int permissions) { + int actual_permissions; + EXPECT_TRUE(base::GetPosixFilePermissions(file_path_, &actual_permissions)); + EXPECT_EQ(permissions, actual_permissions); + } + + // Creates a file with a random name in the temporary directory. + base::FilePath GetTempName() { + return temp_dir_.GetPath().Append(GetRandomSuffix()); + } +}; + +TEST_F(FileUtilsTest, TouchFileCreate) { + EXPECT_TRUE(TouchFile(file_path_)); + ExpectFileContains(""); + ExpectFilePermissions(kPermissions600); +} + +TEST_F(FileUtilsTest, TouchFileCreateThroughUmask) { + mode_t old_umask = umask(kPermissions777); + EXPECT_TRUE(TouchFile(file_path_)); + umask(old_umask); + ExpectFileContains(""); + ExpectFilePermissions(kPermissions600); +} + +TEST_F(FileUtilsTest, TouchFileCreateDirectoryStructure) { + file_path_ = temp_dir_.GetPath().Append("foo/bar/baz/test.temp"); + EXPECT_TRUE(TouchFile(file_path_)); + ExpectFileContains(""); +} + +TEST_F(FileUtilsTest, TouchFileExisting) { + WriteFile("abcd"); + EXPECT_TRUE(TouchFile(file_path_)); + ExpectFileContains("abcd"); +} + +TEST_F(FileUtilsTest, TouchFileReplaceDirectory) { + EXPECT_TRUE(base::CreateDirectory(file_path_)); + EXPECT_TRUE(TouchFile(file_path_)); + EXPECT_FALSE(base::DirectoryExists(file_path_)); + ExpectFileContains(""); +} + +TEST_F(FileUtilsTest, TouchFileReplaceSymlink) { + base::FilePath symlink_target = temp_dir_.GetPath().Append("target.temp"); + EXPECT_TRUE(base::CreateSymbolicLink(symlink_target, file_path_)); + EXPECT_TRUE(TouchFile(file_path_)); + EXPECT_FALSE(base::IsLink(file_path_)); + ExpectFileContains(""); +} + +TEST_F(FileUtilsTest, TouchFileReplaceOtherUser) { + WriteFile("abcd"); + EXPECT_TRUE(TouchFile(file_path_, kPermissions777, geteuid() + 1, getegid())); + ExpectFileContains(""); +} + +TEST_F(FileUtilsTest, TouchFileReplaceOtherGroup) { + WriteFile("abcd"); + EXPECT_TRUE(TouchFile(file_path_, kPermissions777, geteuid(), getegid() + 1)); + ExpectFileContains(""); +} + +TEST_F(FileUtilsTest, TouchFileCreateWithAllPermissions) { + EXPECT_TRUE(TouchFile(file_path_, kPermissions777, geteuid(), getegid())); + ExpectFileContains(""); + ExpectFilePermissions(kPermissions777); +} + +TEST_F(FileUtilsTest, TouchFileCreateWithOwnerPermissions) { + EXPECT_TRUE(TouchFile(file_path_, kPermissions700, geteuid(), getegid())); + ExpectFileContains(""); + ExpectFilePermissions(kPermissions700); +} + +TEST_F(FileUtilsTest, TouchFileExistingPermissionsUnchanged) { + EXPECT_TRUE(TouchFile(file_path_, kPermissions777, geteuid(), getegid())); + EXPECT_TRUE(TouchFile(file_path_, kPermissions700, geteuid(), getegid())); + ExpectFileContains(""); + ExpectFilePermissions(kPermissions777); +} + +// Other parts of OpenSafely are tested in Arcsetup.TestInstallDirectory*. +TEST_F(FileUtilsTest, TestOpenSafelyWithoutNonblocking) { + ASSERT_TRUE(TouchFile(file_path_, kPermissions700, geteuid(), getegid())); + base::ScopedFD fd(OpenSafely(file_path_, O_RDONLY, 0)); + EXPECT_TRUE(fd.is_valid()); + EXPECT_FALSE(IsNonBlockingFD(fd.get())); +} + +TEST_F(FileUtilsTest, TestOpenSafelyWithNonblocking) { + ASSERT_TRUE(TouchFile(file_path_, kPermissions700, geteuid(), getegid())); + base::ScopedFD fd = OpenSafely(file_path_, O_RDONLY | O_NONBLOCK, 0); + EXPECT_TRUE(fd.is_valid()); + EXPECT_TRUE(IsNonBlockingFD(fd.get())); +} + +TEST_F(FileUtilsTest, TestOpenFifoSafelySuccess) { + ASSERT_EQ(0, mkfifo(file_path_.value().c_str(), kPermissions700)); + base::ScopedFD fd(OpenFifoSafely(file_path_, O_RDONLY, 0)); + EXPECT_TRUE(fd.is_valid()); + EXPECT_FALSE(IsNonBlockingFD(fd.get())); +} + +TEST_F(FileUtilsTest, TestOpenFifoSafelyRegularFile) { + ASSERT_TRUE(TouchFile(file_path_, kPermissions700, geteuid(), getegid())); + base::ScopedFD fd = OpenFifoSafely(file_path_, O_RDONLY, 0); + EXPECT_FALSE(fd.is_valid()); +} + +TEST_F(FileUtilsTest, TestMkdirRecursivelyRoot) { + // Try to create an existing directory ("/") should still succeed. + EXPECT_TRUE( + MkdirRecursively(base::FilePath("/"), kPermissions755).is_valid()); +} + +TEST_F(FileUtilsTest, TestMkdirRecursivelySuccess) { + // Set |temp_directory| to 0707. + EXPECT_TRUE(base::SetPosixFilePermissions(temp_dir_.GetPath(), 0707)); + + EXPECT_TRUE( + MkdirRecursively(temp_dir_.GetPath().Append("a/b/c"), kPermissions755) + .is_valid()); + // Confirm the 3 directories are there. + EXPECT_TRUE(base::DirectoryExists(temp_dir_.GetPath().Append("a"))); + EXPECT_TRUE(base::DirectoryExists(temp_dir_.GetPath().Append("a/b"))); + EXPECT_TRUE(base::DirectoryExists(temp_dir_.GetPath().Append("a/b/c"))); + + // Confirm that the newly created directories have 0755 mode. + int mode = 0; + EXPECT_TRUE( + base::GetPosixFilePermissions(temp_dir_.GetPath().Append("a"), &mode)); + EXPECT_EQ(kPermissions755, mode); + mode = 0; + EXPECT_TRUE( + base::GetPosixFilePermissions(temp_dir_.GetPath().Append("a/b"), &mode)); + EXPECT_EQ(kPermissions755, mode); + mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions(temp_dir_.GetPath().Append("a/b/c"), + &mode)); + EXPECT_EQ(kPermissions755, mode); + + // Confirm that the existing directory |temp_directory| still has 0707 mode. + mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions(temp_dir_.GetPath(), &mode)); + EXPECT_EQ(0707, mode); + + // Call the API again which should still succeed. + EXPECT_TRUE( + MkdirRecursively(temp_dir_.GetPath().Append("a/b/c"), kPermissions755) + .is_valid()); + EXPECT_TRUE( + MkdirRecursively(temp_dir_.GetPath().Append("a/b/c/d"), kPermissions755) + .is_valid()); + EXPECT_TRUE(base::DirectoryExists(temp_dir_.GetPath().Append("a/b/c/d"))); + mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions( + temp_dir_.GetPath().Append("a/b/c/d"), &mode)); + EXPECT_EQ(kPermissions755, mode); + + // Call the API again which should still succeed. + EXPECT_TRUE( + MkdirRecursively(temp_dir_.GetPath().Append("a/b"), kPermissions755) + .is_valid()); + EXPECT_TRUE(MkdirRecursively(temp_dir_.GetPath().Append("a"), kPermissions755) + .is_valid()); +} + +TEST_F(FileUtilsTest, TestMkdirRecursivelyRelativePath) { + // Try to pass a relative or empty directory. They should all fail. + EXPECT_FALSE( + MkdirRecursively(base::FilePath("foo"), kPermissions755).is_valid()); + EXPECT_FALSE( + MkdirRecursively(base::FilePath("bar/"), kPermissions755).is_valid()); + EXPECT_FALSE(MkdirRecursively(base::FilePath(), kPermissions755).is_valid()); +} + +TEST_F(FileUtilsTest, WriteFileCanBeReadBack) { + const base::FilePath filename(GetTempName()); + const std::string content("blablabla"); + EXPECT_TRUE(WriteStringToFile(filename, content)); + std::string output; + EXPECT_TRUE(ReadFileToString(filename, &output)); + EXPECT_EQ(content, output); +} + +TEST_F(FileUtilsTest, WriteFileSets0666) { + const mode_t mask = 0000; + const mode_t mode = 0666; + const base::FilePath filename(GetTempName()); + const std::string content("blablabla"); + const mode_t old_mask = umask(mask); + EXPECT_TRUE(WriteStringToFile(filename, content)); + int file_mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions(filename, &file_mode)); + EXPECT_EQ(mode & ~mask, file_mode & 0777); + umask(old_mask); +} + +TEST_F(FileUtilsTest, WriteFileCreatesMissingParentDirectoriesWith0700) { + const mode_t mask = 0000; + const mode_t mode = 0700; + const base::FilePath dirname(GetTempName()); + const base::FilePath subdirname(dirname.Append(GetRandomSuffix())); + const base::FilePath filename(subdirname.Append(GetRandomSuffix())); + const std::string content("blablabla"); + EXPECT_TRUE(WriteStringToFile(filename, content)); + int dir_mode = 0; + int subdir_mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions(dirname, &dir_mode)); + EXPECT_TRUE(base::GetPosixFilePermissions(subdirname, &subdir_mode)); + EXPECT_EQ(mode & ~mask, dir_mode & 0777); + EXPECT_EQ(mode & ~mask, subdir_mode & 0777); + const mode_t old_mask = umask(mask); + umask(old_mask); +} + +TEST_F(FileUtilsTest, WriteToFileAtomicCanBeReadBack) { + const base::FilePath filename(GetTempName()); + const std::string content("blablabla"); + EXPECT_TRUE( + WriteToFileAtomic(filename, content.data(), content.size(), 0644)); + std::string output; + EXPECT_TRUE(ReadFileToString(filename, &output)); + EXPECT_EQ(content, output); +} + +TEST_F(FileUtilsTest, WriteToFileAtomicHonorsMode) { + const mode_t mask = 0000; + const mode_t mode = 0616; + const base::FilePath filename(GetTempName()); + const std::string content("blablabla"); + const mode_t old_mask = umask(mask); + EXPECT_TRUE( + WriteToFileAtomic(filename, content.data(), content.size(), mode)); + int file_mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions(filename, &file_mode)); + EXPECT_EQ(mode & ~mask, file_mode & 0777); + umask(old_mask); +} + +TEST_F(FileUtilsTest, WriteToFileAtomicHonorsUmask) { + const mode_t mask = 0073; + const mode_t mode = 0777; + const base::FilePath filename(GetTempName()); + const std::string content("blablabla"); + const mode_t old_mask = umask(mask); + EXPECT_TRUE( + WriteToFileAtomic(filename, content.data(), content.size(), mode)); + int file_mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions(filename, &file_mode)); + EXPECT_EQ(mode & ~mask, file_mode & 0777); + umask(old_mask); +} + +TEST_F(FileUtilsTest, + WriteToFileAtomicCreatesMissingParentDirectoriesWith0700) { + const mode_t mask = 0000; + const mode_t mode = 0700; + const base::FilePath dirname(GetTempName()); + const base::FilePath subdirname(dirname.Append(GetRandomSuffix())); + const base::FilePath filename(subdirname.Append(GetRandomSuffix())); + const std::string content("blablabla"); + EXPECT_TRUE( + WriteToFileAtomic(filename, content.data(), content.size(), 0777)); + int dir_mode = 0; + int subdir_mode = 0; + EXPECT_TRUE(base::GetPosixFilePermissions(dirname, &dir_mode)); + EXPECT_TRUE(base::GetPosixFilePermissions(subdirname, &subdir_mode)); + EXPECT_EQ(mode & ~mask, dir_mode & 0777); + EXPECT_EQ(mode & ~mask, subdir_mode & 0777); + const mode_t old_mask = umask(mask); + umask(old_mask); +} + +TEST_F(FileUtilsTest, ComputeDirectoryDiskUsageNormalRandomFile) { + // 2MB test file. + constexpr size_t kFileSize = 2 * 1024 * 1024; + + const base::FilePath dirname(GetTempName()); + EXPECT_TRUE(base::CreateDirectory(dirname)); + const base::FilePath filename = dirname.Append("test.temp"); + + std::string file_content = base::RandBytesAsString(kFileSize); + EXPECT_TRUE(WriteStringToFile(filename, file_content)); + + int64_t result_usage = ComputeDirectoryDiskUsage(dirname); + int64_t result_size = base::ComputeDirectorySize(dirname); + + // result_usage (what we are testing here) should be within +/-10% of ground + // truth. The variation is to account for filesystem overhead variations. + EXPECT_GT(result_usage, kFileSize / 10 * 9); + EXPECT_LT(result_usage, kFileSize / 10 * 11); + + // result_usage should be close to result_size, because the test file is + // random so it's disk usage should be similar to apparent size. + EXPECT_GT(result_usage, result_size / 10 * 9); + EXPECT_LT(result_usage, result_size / 10 * 11); +} + +TEST_F(FileUtilsTest, ComputeDirectoryDiskUsageDeepRandomFile) { + // 2MB test file. + constexpr size_t kFileSize = 2 * 1024 * 1024; + + const base::FilePath dirname(GetTempName()); + EXPECT_TRUE(base::CreateDirectory(dirname)); + base::FilePath currentlevel = dirname; + for (int i = 0; i < 10; i++) { + base::FilePath nextlevel = currentlevel.Append("test.dir"); + EXPECT_TRUE(base::CreateDirectory(nextlevel)); + currentlevel = nextlevel; + } + const base::FilePath filename = currentlevel.Append("test.temp"); + + std::string file_content = base::RandBytesAsString(kFileSize); + EXPECT_TRUE(WriteStringToFile(filename, file_content)); + + int64_t result_usage = ComputeDirectoryDiskUsage(dirname); + int64_t result_size = base::ComputeDirectorySize(dirname); + + // result_usage (what we are testing here) should be within +/-10% of ground + // truth. The variation is to account for filesystem overhead variations. + EXPECT_GT(result_usage, kFileSize / 10 * 9); + EXPECT_LT(result_usage, kFileSize / 10 * 11); + + // result_usage should be close to result_size, because the test file is + // random so it's disk usage should be similar to apparent size. + EXPECT_GT(result_usage, result_size / 10 * 9); + EXPECT_LT(result_usage, result_size / 10 * 11); +} + +TEST_F(FileUtilsTest, ComputeDirectoryDiskUsageHiddenRandomFile) { + // 2MB test file. + constexpr size_t kFileSize = 2 * 1024 * 1024; + + const base::FilePath dirname(GetTempName()); + EXPECT_TRUE(base::CreateDirectory(dirname)); + // File name starts with a dot, so it's a hidden file. + const base::FilePath filename = dirname.Append(".test.temp"); + + std::string file_content = base::RandBytesAsString(kFileSize); + EXPECT_TRUE(WriteStringToFile(filename, file_content)); + + int64_t result_usage = ComputeDirectoryDiskUsage(dirname); + int64_t result_size = base::ComputeDirectorySize(dirname); + + // result_usage (what we are testing here) should be within +/-10% of ground + // truth. The variation is to account for filesystem overhead variations. + EXPECT_GT(result_usage, kFileSize / 10 * 9); + EXPECT_LT(result_usage, kFileSize / 10 * 11); + + // result_usage should be close to result_size, because the test file is + // random so it's disk usage should be similar to apparent size. + EXPECT_GT(result_usage, result_size / 10 * 9); + EXPECT_LT(result_usage, result_size / 10 * 11); +} + +TEST_F(FileUtilsTest, ComputeDirectoryDiskUsageSparseFile) { + // 128MB sparse test file. + constexpr size_t kFileSize = 128 * 1024 * 1024; + constexpr size_t kFileSizeThreshold = 64 * 1024; + + const base::FilePath dirname(GetTempName()); + EXPECT_TRUE(base::CreateDirectory(dirname)); + const base::FilePath filename = dirname.Append("test.temp"); + + int fd = + open(filename.value().c_str(), O_CREAT | O_WRONLY, S_IRUSR | S_IWUSR); + EXPECT_NE(fd, -1); + // Calling ftruncate on an empty file will create a sparse file. + EXPECT_EQ(0, ftruncate(fd, kFileSize)); + + int64_t result_usage = ComputeDirectoryDiskUsage(dirname); + int64_t result_size = base::ComputeDirectorySize(dirname); + + // result_usage (what we are testing here) should be less than + // kFileSizeThreshold, the threshold is to account for filesystem overhead + // variations. + EXPECT_LT(result_usage, kFileSizeThreshold); + + // Since we are dealing with sparse files here, the apparent size should be + // much much larger than the actual disk usage. + EXPECT_LT(result_usage, result_size / 1000); +} + +TEST_F(FileUtilsTest, ComputeDirectoryDiskUsageSymlinkFile) { + // 2MB test file. + constexpr size_t kFileSize = 2 * 1024 * 1024; + + const base::FilePath dirname(GetTempName()); + EXPECT_TRUE(base::CreateDirectory(dirname)); + const base::FilePath filename = dirname.Append("test.temp"); + const base::FilePath linkname = dirname.Append("test.link"); + + std::string file_content = base::RandBytesAsString(kFileSize); + EXPECT_TRUE(WriteStringToFile(filename, file_content)); + + // Create a symlink. + EXPECT_TRUE(base::CreateSymbolicLink(filename, linkname)); + + int64_t result_usage = ComputeDirectoryDiskUsage(dirname); + + // result_usage (what we are testing here) should be within +/-10% of ground + // truth. The variation is to account for filesystem overhead variations. + // Note that it's not 2x kFileSize because symblink is not counted twice. + EXPECT_GT(result_usage, kFileSize / 10 * 9); + EXPECT_LT(result_usage, kFileSize / 10 * 11); +} + +TEST_F(FileUtilsTest, ComputeDirectoryDiskUsageSymlinkDir) { + // 2MB test file. + constexpr size_t kFileSize = 2 * 1024 * 1024; + + const base::FilePath parentname(GetTempName()); + EXPECT_TRUE(base::CreateDirectory(parentname)); + const base::FilePath dirname = parentname.Append("target.dir"); + EXPECT_TRUE(base::CreateDirectory(dirname)); + const base::FilePath linkname = parentname.Append("link.dir"); + + const base::FilePath filename = dirname.Append("test.temp"); + + std::string file_content = base::RandBytesAsString(kFileSize); + EXPECT_TRUE(WriteStringToFile(filename, file_content)); + + // Create a symlink. + EXPECT_TRUE(base::CreateSymbolicLink(dirname, linkname)); + + int64_t result_usage = ComputeDirectoryDiskUsage(dirname); + + // result_usage (what we are testing here) should be within +/-10% of ground + // truth. The variation is to account for filesystem overhead variations. + // Note that it's not 2x kFileSize because symblink is not counted twice. + EXPECT_GT(result_usage, kFileSize / 10 * 9); + EXPECT_LT(result_usage, kFileSize / 10 * 11); +} + +} // namespace brillo diff --git a/brillo/file_utils_unittest.cc b/brillo/file_utils_unittest.cc deleted file mode 100644 index 7a730f0..0000000 --- a/brillo/file_utils_unittest.cc +++ /dev/null @@ -1,245 +0,0 @@ -// Copyright 2014 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "brillo/file_utils.h" - -#include <sys/stat.h> -#include <unistd.h> - -#include <string> - -#include <base/files/file_util.h> -#include <base/files/scoped_temp_dir.h> -#include <base/rand_util.h> -#include <base/strings/string_number_conversions.h> -#include <gtest/gtest.h> - -namespace brillo { - -namespace { - -constexpr int kPermissions600 = - base::FILE_PERMISSION_READ_BY_USER | base::FILE_PERMISSION_WRITE_BY_USER; -constexpr int kPermissions700 = base::FILE_PERMISSION_USER_MASK; -constexpr int kPermissions777 = base::FILE_PERMISSION_MASK; - -std::string GetRandomSuffix() { - const int kBufferSize = 6; - unsigned char buffer[kBufferSize]; - base::RandBytes(buffer, arraysize(buffer)); - return base::HexEncode(buffer, arraysize(buffer)); -} - -} // namespace - -class FileUtilsTest : public testing::Test { - public: - FileUtilsTest() { - CHECK(temp_dir_.CreateUniqueTempDir()); - file_path_ = temp_dir_.GetPath().Append("test.temp"); - } - - protected: - base::FilePath file_path_; - base::ScopedTempDir temp_dir_; - - // Writes |contents| to |file_path_|. Pulled into a separate function just - // to improve readability of tests. - void WriteFile(const std::string& contents) { - EXPECT_EQ(contents.length(), - base::WriteFile(file_path_, contents.c_str(), contents.length())); - } - - // Verifies that the file at |file_path_| exists and contains |contents|. - void ExpectFileContains(const std::string& contents) { - EXPECT_TRUE(base::PathExists(file_path_)); - std::string new_contents; - EXPECT_TRUE(base::ReadFileToString(file_path_, &new_contents)); - EXPECT_EQ(contents, new_contents); - } - - // Verifies that the file at |file_path_| has |permissions|. - void ExpectFilePermissions(int permissions) { - int actual_permissions; - EXPECT_TRUE(base::GetPosixFilePermissions(file_path_, &actual_permissions)); - EXPECT_EQ(permissions, actual_permissions); - } - - // Creates a file with a random name in the temporary directory. - base::FilePath GetTempName() { - return temp_dir_.GetPath().Append(GetRandomSuffix()); - } -}; - -TEST_F(FileUtilsTest, TouchFileCreate) { - EXPECT_TRUE(TouchFile(file_path_)); - ExpectFileContains(""); - ExpectFilePermissions(kPermissions600); -} - -TEST_F(FileUtilsTest, TouchFileCreateThroughUmask) { - mode_t old_umask = umask(kPermissions777); - EXPECT_TRUE(TouchFile(file_path_)); - umask(old_umask); - ExpectFileContains(""); - ExpectFilePermissions(kPermissions600); -} - -TEST_F(FileUtilsTest, TouchFileCreateDirectoryStructure) { - file_path_ = temp_dir_.GetPath().Append("foo/bar/baz/test.temp"); - EXPECT_TRUE(TouchFile(file_path_)); - ExpectFileContains(""); -} - -TEST_F(FileUtilsTest, TouchFileExisting) { - WriteFile("abcd"); - EXPECT_TRUE(TouchFile(file_path_)); - ExpectFileContains("abcd"); -} - -TEST_F(FileUtilsTest, TouchFileReplaceDirectory) { - EXPECT_TRUE(base::CreateDirectory(file_path_)); - EXPECT_TRUE(TouchFile(file_path_)); - EXPECT_FALSE(base::DirectoryExists(file_path_)); - ExpectFileContains(""); -} - -TEST_F(FileUtilsTest, TouchFileReplaceSymlink) { - base::FilePath symlink_target = temp_dir_.GetPath().Append("target.temp"); - EXPECT_TRUE(base::CreateSymbolicLink(symlink_target, file_path_)); - EXPECT_TRUE(TouchFile(file_path_)); - EXPECT_FALSE(base::IsLink(file_path_)); - ExpectFileContains(""); -} - -TEST_F(FileUtilsTest, TouchFileReplaceOtherUser) { - WriteFile("abcd"); - EXPECT_TRUE(TouchFile(file_path_, kPermissions777, geteuid() + 1, getegid())); - ExpectFileContains(""); -} - -TEST_F(FileUtilsTest, TouchFileReplaceOtherGroup) { - WriteFile("abcd"); - EXPECT_TRUE(TouchFile(file_path_, kPermissions777, geteuid(), getegid() + 1)); - ExpectFileContains(""); -} - -TEST_F(FileUtilsTest, TouchFileCreateWithAllPermissions) { - EXPECT_TRUE(TouchFile(file_path_, kPermissions777, geteuid(), getegid())); - ExpectFileContains(""); - ExpectFilePermissions(kPermissions777); -} - -TEST_F(FileUtilsTest, TouchFileCreateWithOwnerPermissions) { - EXPECT_TRUE(TouchFile(file_path_, kPermissions700, geteuid(), getegid())); - ExpectFileContains(""); - ExpectFilePermissions(kPermissions700); -} - -TEST_F(FileUtilsTest, TouchFileExistingPermissionsUnchanged) { - EXPECT_TRUE(TouchFile(file_path_, kPermissions777, geteuid(), getegid())); - EXPECT_TRUE(TouchFile(file_path_, kPermissions700, geteuid(), getegid())); - ExpectFileContains(""); - ExpectFilePermissions(kPermissions777); -} - -TEST_F(FileUtilsTest, WriteFileCanBeReadBack) { - const base::FilePath filename(GetTempName()); - const std::string content("blablabla"); - EXPECT_TRUE(WriteStringToFile(filename, content)); - std::string output; - EXPECT_TRUE(ReadFileToString(filename, &output)); - EXPECT_EQ(content, output); -} - -TEST_F(FileUtilsTest, WriteFileSets0666) { - const mode_t mask = 0000; - const mode_t mode = 0666; - const base::FilePath filename(GetTempName()); - const std::string content("blablabla"); - const mode_t old_mask = umask(mask); - EXPECT_TRUE(WriteStringToFile(filename, content)); - int file_mode = 0; - EXPECT_TRUE(base::GetPosixFilePermissions(filename, &file_mode)); - EXPECT_EQ(mode & ~mask, file_mode & 0777); - umask(old_mask); -} - -TEST_F(FileUtilsTest, WriteFileCreatesMissingParentDirectoriesWith0700) { - const mode_t mask = 0000; - const mode_t mode = 0700; - const base::FilePath dirname(GetTempName()); - const base::FilePath subdirname(dirname.Append(GetRandomSuffix())); - const base::FilePath filename(subdirname.Append(GetRandomSuffix())); - const std::string content("blablabla"); - EXPECT_TRUE(WriteStringToFile(filename, content)); - int dir_mode = 0; - int subdir_mode = 0; - EXPECT_TRUE(base::GetPosixFilePermissions(dirname, &dir_mode)); - EXPECT_TRUE(base::GetPosixFilePermissions(subdirname, &subdir_mode)); - EXPECT_EQ(mode & ~mask, dir_mode & 0777); - EXPECT_EQ(mode & ~mask, subdir_mode & 0777); - const mode_t old_mask = umask(mask); - umask(old_mask); -} - -TEST_F(FileUtilsTest, WriteToFileAtomicCanBeReadBack) { - const base::FilePath filename(GetTempName()); - const std::string content("blablabla"); - EXPECT_TRUE( - WriteToFileAtomic(filename, content.data(), content.size(), 0644)); - std::string output; - EXPECT_TRUE(ReadFileToString(filename, &output)); - EXPECT_EQ(content, output); -} - -TEST_F(FileUtilsTest, WriteToFileAtomicHonorsMode) { - const mode_t mask = 0000; - const mode_t mode = 0616; - const base::FilePath filename(GetTempName()); - const std::string content("blablabla"); - const mode_t old_mask = umask(mask); - EXPECT_TRUE( - WriteToFileAtomic(filename, content.data(), content.size(), mode)); - int file_mode = 0; - EXPECT_TRUE(base::GetPosixFilePermissions(filename, &file_mode)); - EXPECT_EQ(mode & ~mask, file_mode & 0777); - umask(old_mask); -} - -TEST_F(FileUtilsTest, WriteToFileAtomicHonorsUmask) { - const mode_t mask = 0073; - const mode_t mode = 0777; - const base::FilePath filename(GetTempName()); - const std::string content("blablabla"); - const mode_t old_mask = umask(mask); - EXPECT_TRUE( - WriteToFileAtomic(filename, content.data(), content.size(), mode)); - int file_mode = 0; - EXPECT_TRUE(base::GetPosixFilePermissions(filename, &file_mode)); - EXPECT_EQ(mode & ~mask, file_mode & 0777); - umask(old_mask); -} - -TEST_F(FileUtilsTest, - WriteToFileAtomicCreatesMissingParentDirectoriesWith0700) { - const mode_t mask = 0000; - const mode_t mode = 0700; - const base::FilePath dirname(GetTempName()); - const base::FilePath subdirname(dirname.Append(GetRandomSuffix())); - const base::FilePath filename(subdirname.Append(GetRandomSuffix())); - const std::string content("blablabla"); - EXPECT_TRUE( - WriteToFileAtomic(filename, content.data(), content.size(), 0777)); - int dir_mode = 0; - int subdir_mode = 0; - EXPECT_TRUE(base::GetPosixFilePermissions(dirname, &dir_mode)); - EXPECT_TRUE(base::GetPosixFilePermissions(subdirname, &subdir_mode)); - EXPECT_EQ(mode & ~mask, dir_mode & 0777); - EXPECT_EQ(mode & ~mask, subdir_mode & 0777); - const mode_t old_mask = umask(mask); - umask(old_mask); -} - -} // namespace brillo diff --git a/brillo/files/OWNERS b/brillo/files/OWNERS new file mode 100644 index 0000000..da09356 --- /dev/null +++ b/brillo/files/OWNERS @@ -0,0 +1,3 @@ +allenwebb@chromium.org +jorgelo@chromium.org +mnissler@chromium.org diff --git a/brillo/files/file_util.cc b/brillo/files/file_util.cc new file mode 100644 index 0000000..c642d14 --- /dev/null +++ b/brillo/files/file_util.cc @@ -0,0 +1,112 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/files/file_util.h" + +#include <fcntl.h> +#include <sys/stat.h> +#include <unistd.h> + +#include <utility> + +#include <base/files/file_util.h> +#include <base/logging.h> +#include <base/strings/stringprintf.h> +#include <brillo/syslog_logging.h> + +namespace brillo { + +namespace { + +enum class FSObjectType { + RegularFile = 0, + Directory, +}; + +SafeFD::SafeFDResult OpenOrRemake(SafeFD* parent, + const std::string& name, + FSObjectType type, + int permissions, + uid_t uid, + gid_t gid, + int flags) { + SafeFD::Error err = IsValidFilename(name); + if (SafeFD::IsError(err)) { + return std::make_pair(SafeFD(), err); + } + + SafeFD::SafeFDResult (SafeFD::*maker)(const base::FilePath&, mode_t, uid_t, + gid_t, int); + if (type == FSObjectType::Directory) { + maker = &SafeFD::MakeDir; + } else { + maker = &SafeFD::MakeFile; + } + + SafeFD child; + std::tie(child, err) = + (parent->*maker)(base::FilePath(name), permissions, uid, gid, flags); + if (child.is_valid()) { + return std::make_pair(std::move(child), err); + } + + // Rmdir should be used on directories. However, kWrongType indicates when + // a directory was expected and a non-directory was found or when a + // directory was found but not expected, so XOR was used. + if ((type == FSObjectType::Directory) ^ (err == SafeFD::Error::kWrongType)) { + err = parent->Rmdir(name, true /*recursive*/); + } else { + err = parent->Unlink(name); + } + if (SafeFD::IsError(err)) { + PLOG(ERROR) << "Failed to clean up \"" << name << "\""; + return std::make_pair(SafeFD(), err); + } + + std::tie(child, err) = + (parent->*maker)(base::FilePath(name), permissions, uid, gid, flags); + return std::make_pair(std::move(child), err); +} + +} // namespace + +SafeFD::Error IsValidFilename(const std::string& filename) { + if (filename == "." || filename == ".." || + filename.find("/") != std::string::npos) { + return SafeFD::Error::kBadArgument; + } + return SafeFD::Error::kNoError; +} + +base::FilePath GetFDPath(int fd) { + const base::FilePath proc_fd(base::StringPrintf("/proc/self/fd/%d", fd)); + base::FilePath resolved; + if (!base::ReadSymbolicLink(proc_fd, &resolved)) { + LOG(ERROR) << "Failed to read " << proc_fd.value(); + return base::FilePath(); + } + return resolved; +} + +SafeFD::SafeFDResult OpenOrRemakeDir(SafeFD* parent, + const std::string& name, + int permissions, + uid_t uid, + gid_t gid, + int flags) { + return OpenOrRemake(parent, name, FSObjectType::Directory, permissions, uid, + gid, flags); +} + +SafeFD::SafeFDResult OpenOrRemakeFile(SafeFD* parent, + const std::string& name, + int permissions, + uid_t uid, + gid_t gid, + int flags) { + return OpenOrRemake(parent, name, FSObjectType::RegularFile, permissions, uid, + gid, flags); +} + +} // namespace brillo diff --git a/brillo/files/file_util.h b/brillo/files/file_util.h new file mode 100644 index 0000000..c020667 --- /dev/null +++ b/brillo/files/file_util.h @@ -0,0 +1,62 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Filesystem-related utility functions. + +#ifndef LIBBRILLO_BRILLO_FILES_FILE_UTIL_H_ +#define LIBBRILLO_BRILLO_FILES_FILE_UTIL_H_ + +#include <string> + +#include <brillo/files/safe_fd.h> + +namespace brillo { + +SafeFD::Error IsValidFilename(const std::string& filename); + +// Obtain the canonical path of the file descriptor or base::FilePath() on +// failure. +BRILLO_EXPORT base::FilePath GetFDPath(int fd); + +// Open or create a child directory named |name| as a child of |parent| with +// the specified permissions and ownership. Custom open flags can be set with +// |flags|. The directory will be re-created if: +// * The open operation fails (e.g. if |name| is not a directory). +// * The permissions do not match. +// * The ownership is different. +// +// Parameters +// parent - An open SafeFD to the parent directory. +// name - the name of the directory being created. It cannot have more than one +// path component. +BRILLO_EXPORT SafeFD::SafeFDResult OpenOrRemakeDir( + SafeFD* parent, + const std::string& name, + int permissions = SafeFD::kDefaultDirPermissions, + uid_t uid = getuid(), + gid_t gid = getgid(), + int flags = O_RDONLY | O_CLOEXEC); + +// Open or create a file named |name| under the directory |parent| with +// the specified permissions and ownership. Custom open flags can be set with +// |flags|. The file will be re-created if: +// * The open operation fails (e.g. |name| is a directory). +// * The permissions do not match. +// * The ownership is different. +// +// Parameters +// parent - An open SafeFD to the parent directory. +// name - the name of the file being created. It cannot have more than one +// path component. +BRILLO_EXPORT SafeFD::SafeFDResult OpenOrRemakeFile( + SafeFD* parent, + const std::string& name, + int permissions = SafeFD::kDefaultFilePermissions, + uid_t uid = getuid(), + gid_t gid = getgid(), + int flags = O_RDWR | O_CLOEXEC); + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_FILES_FILE_UTIL_H_ diff --git a/brillo/files/file_util_test.cc b/brillo/files/file_util_test.cc new file mode 100644 index 0000000..f1ba527 --- /dev/null +++ b/brillo/files/file_util_test.cc @@ -0,0 +1,284 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/files/file_util_test.h" + +#include <base/files/file_util.h> +#include <base/rand_util.h> +#include <base/stl_util.h> +#include <base/strings/string_number_conversions.h> +#include <brillo/files/file_util.h> +#include <brillo/files/safe_fd.h> + +namespace brillo { + +#define TO_STRING_HELPER(x) \ + case brillo::SafeFD::Error::x: \ + return #x; +std::string to_string(brillo::SafeFD::Error err) { + switch (err) { + TO_STRING_HELPER(kNoError) + TO_STRING_HELPER(kBadArgument) + TO_STRING_HELPER(kNotInitialized) + TO_STRING_HELPER(kIOError) + TO_STRING_HELPER(kDoesNotExist) + TO_STRING_HELPER(kSymlinkDetected) + TO_STRING_HELPER(kWrongType) + TO_STRING_HELPER(kWrongUID) + TO_STRING_HELPER(kWrongGID) + TO_STRING_HELPER(kWrongPermissions) + TO_STRING_HELPER(kExceededMaximum) + default: + return std::string("unknown (") + std::to_string(static_cast<int>(err)) + + ")"; + } +} +#undef TO_STRING_HELPER + +std::ostream& operator<<(std::ostream& os, const brillo::SafeFD::Error err) { + return os << to_string(err); // whatever needed to print bar to os +} + +std::string GetRandomSuffix() { + const int kBufferSize = 6; + unsigned char buffer[kBufferSize]; + base::RandBytes(buffer, base::size(buffer)); + return base::HexEncode(buffer, base::size(buffer)); +} + +void FileTest::SetUpTestCase() { + umask(0); +} + +FileTest::FileTest() { + CHECK(temp_dir_.CreateUniqueTempDir()) << strerror(errno); + sub_dir_path_ = temp_dir_.GetPath().Append(kSubdirName); + file_path_ = sub_dir_path_.Append(kFileName); + + std::string path = temp_dir_.GetPath().value(); + temp_dir_path_.reserve(path.size() + 1); + temp_dir_path_.assign(temp_dir_.GetPath().value().begin(), + temp_dir_.GetPath().value().end()); + temp_dir_path_.push_back('\0'); + + CHECK_EQ(chmod(temp_dir_path_.data(), SafeFD::kDefaultDirPermissions), 0); + SafeFD::SetRootPathForTesting(temp_dir_path_.data()); + root_ = SafeFD::Root().first; + CHECK(root_.is_valid()); +} + +bool FileTest::SetupSubdir() { + if (!base::CreateDirectory(sub_dir_path_)) { + PLOG(ERROR) << "Failed to create '" << sub_dir_path_.value() << "'"; + return false; + } + if (chmod(sub_dir_path_.value().c_str(), SafeFD::kDefaultDirPermissions) != + 0) { + PLOG(ERROR) << "Failed to set permissions of '" << sub_dir_path_.value() + << "'"; + return false; + } + return true; +} + +bool FileTest::SetupSymlinks() { + symlink_file_path_ = temp_dir_.GetPath().Append(kSymbolicFileName); + symlink_dir_path_ = temp_dir_.GetPath().Append(kSymbolicDirName); + if (!base::CreateSymbolicLink(file_path_, symlink_file_path_)) { + PLOG(ERROR) << "Failed to create symlink to '" << symlink_file_path_.value() + << "'"; + return false; + } + if (!base::CreateSymbolicLink(temp_dir_.GetPath(), symlink_dir_path_)) { + PLOG(ERROR) << "Failed to create symlink to'" << symlink_dir_path_.value() + << "'"; + return false; + } + return true; +} + +bool FileTest::WriteFile(const std::string& contents) { + if (!SetupSubdir()) { + return false; + } + if (contents.length() != + base::WriteFile(file_path_, contents.c_str(), contents.length())) { + PLOG(ERROR) << "base::WriteFile failed"; + return false; + } + if (chmod(file_path_.value().c_str(), SafeFD::kDefaultFilePermissions) != 0) { + PLOG(ERROR) << "chmod failed"; + return false; + } + return true; +} + +void FileTest::ExpectFileContains(const std::string& contents) { + EXPECT_TRUE(base::PathExists(file_path_)); + std::string new_contents; + EXPECT_TRUE(base::ReadFileToString(file_path_, &new_contents)); + EXPECT_EQ(contents, new_contents); +} + +void FileTest::ExpectPermissions(base::FilePath path, int permissions) { + int actual_permissions = 0; + // This breaks out of the ExpectPermissions() call but not the test case. + ASSERT_TRUE(base::GetPosixFilePermissions(path, &actual_permissions)); + EXPECT_EQ(permissions, actual_permissions); +} + +// Creates a file with a random name in the temporary directory. +base::FilePath FileTest::GetTempName() { + return temp_dir_.GetPath().Append(GetRandomSuffix()); +} + +constexpr char FileTest::kFileName[]; +constexpr char FileTest::kSubdirName[]; +constexpr char FileTest::kSymbolicFileName[]; +constexpr char FileTest::kSymbolicDirName[]; + +class FileUtilTest : public FileTest {}; + +TEST_F(FileUtilTest, GetFDPath_SimpleSuccess) { + EXPECT_EQ(GetFDPath(root_.get()), temp_dir_.GetPath()); +} + +TEST_F(FileUtilTest, GetFDPath_BadFD) { + base::FilePath path = GetFDPath(-1); + EXPECT_TRUE(path.empty()); +} + +TEST_F(FileUtilTest, OpenOrRemakeDir_SimpleSuccess) { + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + + SafeFD subdir; + std::tie(subdir, err) = OpenOrRemakeDir(&dir, kSubdirName); + EXPECT_EQ(err, SafeFD::Error::kNoError); + EXPECT_TRUE(subdir.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeDir_SuccessAfterRetry) { + ASSERT_NE(base::WriteFile(sub_dir_path_, "", 0), -1); + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + + SafeFD subdir; + std::tie(subdir, err) = OpenOrRemakeDir(&dir, kSubdirName); + EXPECT_EQ(err, SafeFD::Error::kNoError); + EXPECT_TRUE(subdir.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeDir_BadArgument) { + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + + SafeFD subdir; + std::tie(subdir, err) = OpenOrRemakeDir(&dir, "."); + EXPECT_EQ(err, SafeFD::Error::kBadArgument); + EXPECT_FALSE(subdir.is_valid()); + std::tie(subdir, err) = OpenOrRemakeDir(&dir, ".."); + EXPECT_EQ(err, SafeFD::Error::kBadArgument); + EXPECT_FALSE(subdir.is_valid()); + std::tie(subdir, err) = OpenOrRemakeDir(&dir, "a/a"); + EXPECT_EQ(err, SafeFD::Error::kBadArgument); + EXPECT_FALSE(subdir.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeDir_NotInitialized) { + SafeFD::Error err; + SafeFD dir; + + SafeFD subdir; + std::tie(subdir, err) = OpenOrRemakeDir(&dir, kSubdirName); + EXPECT_EQ(err, SafeFD::Error::kNotInitialized); + EXPECT_FALSE(subdir.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeDir_IOError) { + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + ASSERT_EQ(chmod(temp_dir_path_.data(), 0000), 0); + + SafeFD subdir; + std::tie(subdir, err) = OpenOrRemakeDir(&dir, kSubdirName); + EXPECT_EQ(err, SafeFD::Error::kIOError); + EXPECT_FALSE(subdir.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeFile_SimpleSuccess) { + ASSERT_TRUE(SetupSubdir()); + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(sub_dir_path_); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + + SafeFD file; + std::tie(file, err) = OpenOrRemakeFile(&dir, kFileName); + EXPECT_EQ(err, SafeFD::Error::kNoError); + EXPECT_TRUE(file.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeFile_SuccessAfterRetry) { + ASSERT_TRUE(SetupSubdir()); + ASSERT_TRUE(base::CreateDirectory(file_path_)); + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(sub_dir_path_); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + + SafeFD file; + std::tie(file, err) = OpenOrRemakeFile(&dir, kFileName); + EXPECT_EQ(err, SafeFD::Error::kNoError); + EXPECT_TRUE(file.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeFile_NotInitialized) { + ASSERT_TRUE(SetupSubdir()); + SafeFD::Error err; + SafeFD dir; + + SafeFD file; + std::tie(file, err) = OpenOrRemakeFile(&dir, kFileName); + EXPECT_EQ(err, SafeFD::Error::kNotInitialized); + EXPECT_FALSE(file.is_valid()); +} + +TEST_F(FileUtilTest, OpenOrRemakeFile_IOError) { + ASSERT_TRUE(SetupSubdir()); + SafeFD::Error err; + SafeFD dir; + + std::tie(dir, err) = root_.OpenExistingDir(sub_dir_path_); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + ASSERT_EQ(chmod(sub_dir_path_.value().c_str(), 0000), 0); + + SafeFD file; + std::tie(file, err) = OpenOrRemakeFile(&dir, kFileName); + EXPECT_EQ(err, SafeFD::Error::kIOError); + EXPECT_FALSE(file.is_valid()); +} + +} // namespace brillo diff --git a/brillo/files/file_util_test.h b/brillo/files/file_util_test.h new file mode 100644 index 0000000..182cdf4 --- /dev/null +++ b/brillo/files/file_util_test.h @@ -0,0 +1,70 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Filesystem-related utility functions. + +#ifndef LIBBRILLO_BRILLO_FILES_FILE_UTIL_TEST_H_ +#define LIBBRILLO_BRILLO_FILES_FILE_UTIL_TEST_H_ + +#include <string> +#include <vector> + +#include <base/files/scoped_temp_dir.h> +#include <brillo/files/safe_fd.h> +#include <gtest/gtest.h> + +namespace brillo { + +// Convert the SafeFD::Error enum class to a string for readability of +// test results. +std::string to_string(brillo::SafeFD::Error err); + +// Helper to enable gtest to print SafeFD::Error results in a way that is easier +// to read. +std::ostream& operator<<(std::ostream& os, const brillo::SafeFD::Error err); + +// Gets a short random string that can be used as part of a file name. +std::string GetRandomSuffix(); + +class FileTest : public testing::Test { + public: + static constexpr char kFileName[] = "test.temp"; + static constexpr char kSubdirName[] = "test_dir"; + static constexpr char kSymbolicFileName[] = "sym_test.temp"; + static constexpr char kSymbolicDirName[] = "sym_dir"; + + static void SetUpTestCase(); + + FileTest(); + + protected: + std::vector<char> temp_dir_path_; + base::FilePath file_path_; + base::FilePath sub_dir_path_; + base::FilePath symlink_file_path_; + base::FilePath symlink_dir_path_; + base::ScopedTempDir temp_dir_; + SafeFD root_; + + bool SetupSubdir() WARN_UNUSED_RESULT; + + bool SetupSymlinks() WARN_UNUSED_RESULT; + + // Writes |contents| to |file_path_|. Pulled into a separate function just + // to improve readability of tests. + bool WriteFile(const std::string& contents) WARN_UNUSED_RESULT; + + // Verifies that the file at |file_path_| exists and contains |contents|. + void ExpectFileContains(const std::string& contents); + + // Verifies that the file at |file_path_| has |permissions|. + void ExpectPermissions(base::FilePath path, int permissions); + + // Creates a file with a random name in the temporary directory. + base::FilePath GetTempName() WARN_UNUSED_RESULT; +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_FILES_FILE_UTIL_TEST_H_ diff --git a/brillo/files/safe_fd.cc b/brillo/files/safe_fd.cc new file mode 100644 index 0000000..ac19dc3 --- /dev/null +++ b/brillo/files/safe_fd.cc @@ -0,0 +1,552 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/files/safe_fd.h" + +#include <fcntl.h> +#include <sys/stat.h> +#include <unistd.h> + +#include <base/files/file_util.h> +#include <base/logging.h> +#include <base/posix/eintr_wrapper.h> +#include <brillo/files/file_util.h> +#include <brillo/files/scoped_dir.h> +#include <brillo/syslog_logging.h> + +namespace brillo { + +namespace { + +SafeFD::SafeFDResult MakeErrorResult(SafeFD::Error error) { + return std::make_pair(SafeFD(), error); +} + +SafeFD::SafeFDResult MakeSuccessResult(SafeFD&& fd) { + return std::make_pair(std::move(fd), SafeFD::Error::kNoError); +} + +SafeFD::SafeFDResult OpenPathComponentInternal(int parent_fd, + const std::string& file, + int flags, + mode_t mode) { + if (file != "/" && file.find("/") != std::string::npos) { + return MakeErrorResult(SafeFD::Error::kBadArgument); + } + SafeFD fd; + + // O_NONBLOCK is used to avoid hanging on edge cases (e.g. a serial port with + // flow control, or a FIFO without a writer). + if (parent_fd >= 0 || parent_fd == AT_FDCWD) { + fd.UnsafeReset(HANDLE_EINTR(openat(parent_fd, file.c_str(), + flags | O_NONBLOCK | O_NOFOLLOW, mode))); + } else if (file == "/") { + fd.UnsafeReset(HANDLE_EINTR(open( + file.c_str(), flags | O_DIRECTORY | O_NONBLOCK | O_NOFOLLOW, mode))); + } + + if (!fd.is_valid()) { + // open(2) fails with ELOOP when the last component of the |path| is a + // symlink. It fails with ENXIO when |path| is a FIFO and |flags| is for + // writing because of the O_NONBLOCK flag added above. + switch (errno) { + case ENOENT: + // Do not write to the log because opening a non-existent file is a + // frequent occurrence. + return MakeErrorResult(SafeFD::Error::kDoesNotExist); + case ELOOP: + // PLOG prints something along the lines of the symlink depth being too + // great which is is misleading so LOG is used instead. + LOG(ERROR) << "Symlink detected! failed to open \"" << file + << "\" safely."; + return MakeErrorResult(SafeFD::Error::kSymlinkDetected); + case EISDIR: + PLOG(ERROR) << "Directory detected! failed to open \"" << file + << "\" safely"; + return MakeErrorResult(SafeFD::Error::kWrongType); + case ENOTDIR: + PLOG(ERROR) << "Not a directory! failed to open \"" << file + << "\" safely"; + return MakeErrorResult(SafeFD::Error::kWrongType); + case ENXIO: + PLOG(ERROR) << "FIFO detected! failed to open \"" << file + << "\" safely"; + return MakeErrorResult(SafeFD::Error::kWrongType); + default: + PLOG(ERROR) << "Failed to open \"" << file << '"'; + return MakeErrorResult(SafeFD::Error::kIOError); + } + } + + // Remove the O_NONBLOCK flag unless the original |flags| have it. + if ((flags & O_NONBLOCK) == 0) { + flags = fcntl(fd.get(), F_GETFL); + if (flags == -1) { + PLOG(ERROR) << "Failed to get fd flags for " << file; + return MakeErrorResult(SafeFD::Error::kIOError); + } + if (fcntl(fd.get(), F_SETFL, flags & ~O_NONBLOCK)) { + PLOG(ERROR) << "Failed to set fd flags for " << file; + return MakeErrorResult(SafeFD::Error::kIOError); + } + } + + return MakeSuccessResult(std::move(fd)); +} + +SafeFD::SafeFDResult OpenSafelyInternal(int parent_fd, + const base::FilePath& path, + int flags, + mode_t mode) { + std::vector<std::string> components; + path.GetComponents(&components); + + auto itr = components.begin(); + if (itr == components.end()) { + LOG(ERROR) << "A path is required."; + return MakeErrorResult(SafeFD::Error::kBadArgument); + } + + SafeFD::SafeFDResult child_fd; + int parent_flags = flags | O_NONBLOCK | O_RDONLY | O_DIRECTORY | O_PATH; + for (; itr + 1 != components.end(); ++itr) { + child_fd = OpenPathComponentInternal(parent_fd, *itr, parent_flags, 0); + // Operation failed, so directly return the error result. + if (!child_fd.first.is_valid()) { + return child_fd; + } + parent_fd = child_fd.first.get(); + } + + return OpenPathComponentInternal(parent_fd, *itr, flags, mode); +} + +SafeFD::Error CheckAttributes(int fd, + mode_t permissions, + uid_t uid, + gid_t gid) { + struct stat fd_attributes; + if (fstat(fd, &fd_attributes) != 0) { + PLOG(ERROR) << "fstat failed"; + return SafeFD::Error::kIOError; + } + + if (fd_attributes.st_uid != uid) { + LOG(ERROR) << "Owner uid is " << fd_attributes.st_uid << " instead of " + << uid; + return SafeFD::Error::kWrongUID; + } + + if (fd_attributes.st_gid != gid) { + LOG(ERROR) << "Owner gid is " << fd_attributes.st_gid << " instead of " + << gid; + return SafeFD::Error::kWrongGID; + } + + if ((0777 & (fd_attributes.st_mode ^ permissions)) != 0) { + mode_t mask = umask(0); + umask(mask); + LOG(ERROR) << "Permissions are " << std::oct + << (0777 & fd_attributes.st_mode) << " instead of " + << (0777 & permissions) << ". Umask is " << std::oct << mask + << std::dec; + return SafeFD::Error::kWrongPermissions; + } + + return SafeFD::Error::kNoError; +} + +SafeFD::Error GetFileSize(int fd, size_t* file_size) { + struct stat fd_attributes; + if (fstat(fd, &fd_attributes) != 0) { + return SafeFD::Error::kIOError; + } + + *file_size = fd_attributes.st_size; + return SafeFD::Error::kNoError; +} + +} // namespace + +bool SafeFD::IsError(SafeFD::Error err) { + return err != Error::kNoError; +} + +const char* SafeFD::RootPath = "/"; + +SafeFD::SafeFDResult SafeFD::Root() { + SafeFD::SafeFDResult root = + OpenPathComponentInternal(-1, "/", O_DIRECTORY, 0); + if (strcmp(SafeFD::RootPath, "/") == 0) { + return root; + } + + if (!root.first.is_valid()) { + LOG(ERROR) << "Failed to open root directory!"; + return root; + } + return root.first.OpenExistingDir(base::FilePath(SafeFD::RootPath)); +} + +void SafeFD::SetRootPathForTesting(const char* new_root_path) { + SafeFD::RootPath = new_root_path; +} + +int SafeFD::get() const { + return fd_.get(); +} + +bool SafeFD::is_valid() const { + return fd_.is_valid(); +} + +void SafeFD::reset() { + return fd_.reset(); +} + +void SafeFD::UnsafeReset(int fd) { + return fd_.reset(fd); +} + +SafeFD::Error SafeFD::Write(const char* data, size_t size) { + if (!fd_.is_valid()) { + return SafeFD::Error::kNotInitialized; + } + errno = 0; + if (!base::WriteFileDescriptor(fd_.get(), data, size)) { + PLOG(ERROR) << "Failed to write to file"; + return SafeFD::Error::kIOError; + } + + if (HANDLE_EINTR(ftruncate(fd_.get(), size)) != 0) { + PLOG(ERROR) << "Failed to truncate file"; + return SafeFD::Error::kIOError; + } + return SafeFD::Error::kNoError; +} + +std::pair<std::vector<char>, SafeFD::Error> SafeFD::ReadContents( + size_t max_size) { + std::vector<char> buffer; + if (!fd_.is_valid()) { + return std::make_pair(std::move(buffer), SafeFD::Error::kNotInitialized); + } + + size_t file_size = 0; + SafeFD::Error err = GetFileSize(fd_.get(), &file_size); + if (IsError(err)) { + return std::make_pair(std::move(buffer), err); + } + + if (file_size > max_size) { + return std::make_pair(std::move(buffer), SafeFD::Error::kExceededMaximum); + } + + buffer.resize(file_size); + + err = Read(buffer.data(), buffer.size()); + if (IsError(err)) { + buffer.clear(); + } + return std::make_pair(std::move(buffer), err); +} + +SafeFD::Error SafeFD::Read(char* data, size_t size) { + if (!fd_.is_valid()) { + return SafeFD::Error::kNotInitialized; + } + + if (!base::ReadFromFD(fd_.get(), data, size)) { + PLOG(ERROR) << "Failed to read file"; + return SafeFD::Error::kIOError; + } + return SafeFD::Error::kNoError; +} + +SafeFD::SafeFDResult SafeFD::OpenExistingFile(const base::FilePath& path, + int flags) { + if (!fd_.is_valid()) { + return MakeErrorResult(SafeFD::Error::kNotInitialized); + } + + return OpenSafelyInternal(get(), path, flags, 0 /*mode*/); +} + +SafeFD::SafeFDResult SafeFD::OpenExistingDir(const base::FilePath& path, + int flags) { + if (!fd_.is_valid()) { + return MakeErrorResult(SafeFD::Error::kNotInitialized); + } + + return OpenSafelyInternal(get(), path, O_DIRECTORY | flags /*flags*/, + 0 /*mode*/); +} + +SafeFD::SafeFDResult SafeFD::MakeFile(const base::FilePath& path, + mode_t permissions, + uid_t uid, + gid_t gid, + int flags) { + if (!fd_.is_valid()) { + return MakeErrorResult(SafeFD::Error::kNotInitialized); + } + + // Open (and create if necessary) the parent directory. + base::FilePath dir_name = path.DirName(); + SafeFD::SafeFDResult parent_dir; + int parent_dir_fd = get(); + if (!dir_name.empty() && + dir_name.value() != base::FilePath::kCurrentDirectory) { + // Apply execute permission where read permission are present for parent + // directories. + int dir_permissions = permissions | ((permissions & 0444) >> 2); + parent_dir = + MakeDir(dir_name, dir_permissions, uid, gid, O_RDONLY | O_CLOEXEC); + if (!parent_dir.first.is_valid()) { + return parent_dir; + } + parent_dir_fd = parent_dir.first.get(); + } + + // If file already exists, validate permissions. + SafeFDResult file = OpenPathComponentInternal( + parent_dir_fd, path.BaseName().value(), flags, permissions /*mode*/); + if (file.first.is_valid()) { + SafeFD::Error err = + CheckAttributes(file.first.get(), permissions, uid, gid); + if (IsError(err)) { + return MakeErrorResult(err); + } + return file; + } else if (errno != ENOENT) { + return file; + } + + // The file does exist, create it and set the ownership. + file = + OpenPathComponentInternal(parent_dir_fd, path.BaseName().value(), + O_CREAT | O_EXCL | flags, permissions /*mode*/); + if (!file.first.is_valid()) { + return file; + } + if (HANDLE_EINTR(fchown(file.first.get(), uid, gid)) != 0) { + PLOG(ERROR) << "Failed to set ownership in MakeFile() for \"" + << path.value() << '"'; + return MakeErrorResult(SafeFD::Error::kIOError); + } + return file; +} + +SafeFD::SafeFDResult SafeFD::MakeDir(const base::FilePath& path, + mode_t permissions, + uid_t uid, + gid_t gid, + int flags) { + if (!fd_.is_valid()) { + return MakeErrorResult(SafeFD::Error::kNotInitialized); + } + + std::vector<std::string> components; + path.GetComponents(&components); + if (components.empty()) { + LOG(ERROR) << "Called MakeDir() with an empty path"; + return MakeErrorResult(SafeFD::Error::kBadArgument); + } + + // Walk the path creating directories as necessary. + SafeFD dir; + SafeFDResult child_dir; + int parent_dir_fd = get(); + int dir_flags = O_NONBLOCK | O_DIRECTORY | O_PATH; + bool made_dir = false; + for (const auto& component : components) { + if (mkdirat(parent_dir_fd, component.c_str(), permissions) != 0) { + if (errno != EEXIST) { + PLOG(ERROR) << "Failed to mkdirat() " << component << ": full_path=\"" + << path.value() << '"'; + return MakeErrorResult(SafeFD::Error::kIOError); + } + } else { + made_dir = true; + } + + // For the last component in the path, use the flags provided by the caller. + if (&component == &components.back()) { + dir_flags = flags | O_DIRECTORY; + } + child_dir = OpenPathComponentInternal(parent_dir_fd, component, dir_flags, + 0 /*mode*/); + if (!child_dir.first.is_valid()) { + return child_dir; + } + + dir = std::move(child_dir.first); + parent_dir_fd = dir.get(); + } + + if (made_dir) { + // If the directory was created, set the ownership. + if (HANDLE_EINTR(fchown(dir.get(), uid, gid)) != 0) { + PLOG(ERROR) << "Failed to set ownership in MakeDir() for \"" + << path.value() << '"'; + return MakeErrorResult(SafeFD::Error::kIOError); + } + } + // If the directory already existed, validate the permissions. + SafeFD::Error err = CheckAttributes(dir.get(), permissions, uid, gid); + if (IsError(err)) { + return MakeErrorResult(err); + } + + return MakeSuccessResult(std::move(dir)); +} + +SafeFD::Error SafeFD::Link(const SafeFD& source_dir, + const std::string& source_name, + const std::string& destination_name) { + if (!fd_.is_valid() || !source_dir.is_valid()) { + return SafeFD::Error::kNotInitialized; + } + + SafeFD::Error err = IsValidFilename(source_name); + if (IsError(err)) { + return err; + } + + err = IsValidFilename(destination_name); + if (IsError(err)) { + return err; + } + + if (HANDLE_EINTR(linkat(source_dir.get(), source_name.c_str(), fd_.get(), + destination_name.c_str(), 0)) != 0) { + PLOG(ERROR) << "Failed to link \"" << destination_name << "\""; + return SafeFD::Error::kIOError; + } + return SafeFD::Error::kNoError; +} + +SafeFD::Error SafeFD::Unlink(const std::string& name) { + if (!fd_.is_valid()) { + return SafeFD::Error::kNotInitialized; + } + + SafeFD::Error err = IsValidFilename(name); + if (IsError(err)) { + return err; + } + + if (HANDLE_EINTR(unlinkat(fd_.get(), name.c_str(), 0 /*flags*/)) != 0) { + PLOG(ERROR) << "Failed to unlink \"" << name << "\""; + return SafeFD::Error::kIOError; + } + return SafeFD::Error::kNoError; +} + +SafeFD::Error SafeFD::Rmdir(const std::string& name, + bool recursive, + size_t max_depth, + bool keep_going) { + if (!fd_.is_valid()) { + return SafeFD::Error::kNotInitialized; + } + + if (max_depth == 0) { + return SafeFD::Error::kExceededMaximum; + } + + SafeFD::Error err = IsValidFilename(name); + if (IsError(err)) { + return err; + } + + SafeFD::Error last_err = SafeFD::Error::kNoError; + + if (recursive) { + SafeFD dir_fd; + std::tie(dir_fd, err) = + OpenPathComponentInternal(fd_.get(), name, O_DIRECTORY, 0); + if (!dir_fd.is_valid()) { + return err; + } + + // The ScopedDIR takes ownership of this so dup_fd is not scoped on its own. + int dup_fd = dup(dir_fd.get()); + if (dup_fd < 0) { + PLOG(ERROR) << "dup failed"; + return SafeFD::Error::kIOError; + } + + ScopedDIR dir(fdopendir(dup_fd)); + if (!dir.is_valid()) { + PLOG(ERROR) << "fdopendir failed"; + close(dup_fd); + return SafeFD::Error::kIOError; + } + + struct stat dir_info; + if (fstat(dir_fd.get(), &dir_info) != 0) { + return SafeFD::Error::kIOError; + } + + errno = 0; + const dirent* entry = HANDLE_EINTR_IF_EQ(readdir(dir.get()), nullptr); + while (entry != nullptr) { + SafeFD::Error err = [&]() { + if (strcmp(entry->d_name, ".") == 0 || + strcmp(entry->d_name, "..") == 0) { + return SafeFD::Error::kNoError; + } + + struct stat child_info; + if (fstatat(dir_fd.get(), entry->d_name, &child_info, + AT_NO_AUTOMOUNT | AT_SYMLINK_NOFOLLOW) != 0) { + return SafeFD::Error::kIOError; + } + + if (child_info.st_dev != dir_info.st_dev) { + return SafeFD::Error::kBoundaryDetected; + } + + if (entry->d_type != DT_DIR) { + return dir_fd.Unlink(entry->d_name); + } + + return dir_fd.Rmdir(entry->d_name, true, max_depth - 1, keep_going); + }(); + + if (IsError(err)) { + if (!keep_going) { + return err; + } + last_err = err; + } + + errno = 0; + entry = HANDLE_EINTR_IF_EQ(readdir(dir.get()), nullptr); + } + if (errno != 0) { + PLOG(ERROR) << "readdir failed"; + return SafeFD::Error::kIOError; + } + } + + if (HANDLE_EINTR(unlinkat(fd_.get(), name.c_str(), AT_REMOVEDIR)) != 0) { + PLOG(ERROR) << "unlinkat failed"; + if (errno == ENOTDIR) { + return SafeFD::Error::kWrongType; + } + // If there was an error during the recursive delete, we expect unlink + // to fail with ENOTEMPTY and we bubble the error from recursion + // instead. + if (IsError(last_err) && errno == ENOTEMPTY) { + return last_err; + } + return SafeFD::Error::kIOError; + } + + return last_err; +} + +} // namespace brillo diff --git a/brillo/files/safe_fd.h b/brillo/files/safe_fd.h new file mode 100644 index 0000000..3c77362 --- /dev/null +++ b/brillo/files/safe_fd.h @@ -0,0 +1,204 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This provides an API for performing typical filesystem related tasks while +// guaranteeing certain security properties are maintained. Specifically, checks +// are performed to disallow symbolic links, and exotic file objects. The goal +// behind these checks is to thwart attacks that rely on confusing system +// services to perform unintended file operations like ownership changes or +// copy-as-root attack primitives. To accomplish this these operations are +// written to avoid susceptibility to TOCTOU (time-of-check-time-of-use) +// attacks. + +// To use this API start with the root path and work from there. For example: +// SafeFD fd(SafeDirFD::Root().MakeFile(PATH).first); +// if (!fd.is_valid()) { +// LOG(ERROR) << "Failed to open " << PATH; +// return false; +// } +// if (fd.WriteString(CONTENTS) != SafeFD::kNoError) { +// LOG(ERROR) << "Failed to write to " << PATH; +// return false; +// } +// auto read_result = fd.ReadString(); +// if (!read_result.second != SafeFD::kNoError) { +// LOG(ERROR) << "Failed to read from " << PATH; +// return false; +// } + +#ifndef LIBBRILLO_BRILLO_FILES_SAFE_FD_H_ +#define LIBBRILLO_BRILLO_FILES_SAFE_FD_H_ + +#include <fcntl.h> + +#include <string> +#include <utility> +#include <vector> + +#include <base/files/file_path.h> +#include <base/files/scoped_file.h> +#include <base/optional.h> +#include <base/synchronization/lock.h> +#include <brillo/brillo_export.h> + +namespace brillo { + +class SafeFDTest; + +class SafeFD { + public: + enum class Error { + kNoError = 0, + kBadArgument, + kNotInitialized, // Invalid operation on a SafeFD that was not initialized. + kIOError, // Check errno for specific cause. + kDoesNotExist, // The specified path does not exist. + kSymlinkDetected, + kBoundaryDetected, // Detected a file system boundary during recursion. + kWrongType, // (e.g. got a directory and expected a file) + kWrongUID, + kWrongGID, + kWrongPermissions, + kExceededMaximum, // The maximum allowed read size was reached. + }; + + // Returns true if |err| denotes a failed operation. + BRILLO_EXPORT static bool IsError(SafeFD::Error err); + + typedef std::pair<SafeFD, Error> SafeFDResult; + + // 100 MiB + BRILLO_EXPORT static constexpr size_t kDefaultMaxRead = 100 << 20; + BRILLO_EXPORT static constexpr size_t kDefaultMaxPathDepth = 256; + // User read and write only. + BRILLO_EXPORT static constexpr size_t kDefaultFilePermissions = 0640; + // User read, write, and execute. Group read and execute. + BRILLO_EXPORT static constexpr size_t kDefaultDirPermissions = 0750; + + // Get a SafeFD to the root path. + BRILLO_EXPORT static SafeFDResult Root() WARN_UNUSED_RESULT; + BRILLO_EXPORT static void SetRootPathForTesting(const char* new_root_path); + + // Constructs an invalid fd; + BRILLO_EXPORT SafeFD() = default; + + // Move-based constructor and assignment. + BRILLO_EXPORT SafeFD(SafeFD&&) = default; + BRILLO_EXPORT SafeFD& operator=(SafeFD&&) = default; + + // Return the fd number. + BRILLO_EXPORT int get() const WARN_UNUSED_RESULT; + + // Check the validity of the file descriptor. + BRILLO_EXPORT bool is_valid() const WARN_UNUSED_RESULT; + + // Close the scoped file if one was open. + BRILLO_EXPORT void reset(); + + // Wrap |fd| with a SafeFD which will close the fd when this goes out of + // scope. This closes the original fd if one was open. + // This is named "Unsafe" because the recommended way to get a SafeFD + // instance is opening one from SafeFD::Root(). + BRILLO_EXPORT void UnsafeReset(int fd); + + // Writes |size| bytes from |data| into a file and returns kNoError on + // success. Note the file will be truncated to the size of the content. + // + // Parameters + // data - The buffer to write to the file. + // size - The number of bytes to write. + BRILLO_EXPORT Error Write(const char* data, size_t size) WARN_UNUSED_RESULT; + + // Read the contents of the file and return it as a string. + // + // Parameters + // size - The max number of bytes to read. + BRILLO_EXPORT std::pair<std::vector<char>, Error> ReadContents( + size_t max_size = kDefaultMaxRead) WARN_UNUSED_RESULT; + + // Reads exactly |size| bytes into |data|. + // + // Parameters + // data - The buffer to read the file into. + // size - The number of bytes to read. + BRILLO_EXPORT Error Read(char* data, size_t size) WARN_UNUSED_RESULT; + + // Open an existing file relative to this directory. + // + // Parameters + // path - The path to open relative to the current directory. + BRILLO_EXPORT SafeFDResult OpenExistingFile(const base::FilePath& path, + int flags = O_RDWR | O_CLOEXEC) + WARN_UNUSED_RESULT; + + // Open an existing directory relative to this directory. + // + // Parameters + // path - The path to open relative to the current directory. + BRILLO_EXPORT SafeFDResult OpenExistingDir(const base::FilePath& path, + int flags = O_RDONLY | O_CLOEXEC) + WARN_UNUSED_RESULT; + + // Open a file relative to this directory creating the parent directories and + // file if they don't already exist. + BRILLO_EXPORT SafeFDResult + MakeFile(const base::FilePath& path, + mode_t permissions = kDefaultFilePermissions, + uid_t uid = getuid(), + gid_t gid = getgid(), + int flags = O_RDWR | O_CLOEXEC) WARN_UNUSED_RESULT; + + // Create the directories in the relative path with the given ownership and + // permissions and return a file descriptor to the result. + BRILLO_EXPORT SafeFDResult + MakeDir(const base::FilePath& path, + mode_t permissions = kDefaultDirPermissions, + uid_t uid = getuid(), + gid_t gid = getgid(), + int flags = O_RDONLY | O_CLOEXEC) WARN_UNUSED_RESULT; + + // Hard link |fd| in the directory represented by |this| with the specified + // name |filename|. This requires CAP_DAC_READ_SEARCH. + // + // Parameters + // data - The buffer to write to the file. + // size - The number of bytes to write. + BRILLO_EXPORT Error Link(const SafeFD& source_dir, + const std::string& source_name, + const std::string& destination_name) + WARN_UNUSED_RESULT; + + // Deletes the child path named |name|. + // + // Parameters + // name - the name of the filesystem object to delete. + BRILLO_EXPORT Error Unlink(const std::string& name) WARN_UNUSED_RESULT; + + // Deletes a child directory. It will return kBoundaryDetected if a file + // system boundary is reached during recursion. + // + // Parameters + // name - the name of the directory to delete. + // recursive - if true also unlink child paths. + // max_depth - limit on recursion depth to prevent fd exhaustion and stack + // overflows. + // keep_going - in recursive case continue deleting even in the face of + // errors. If all entries cannot be deleted, the last error encountered + // during recursion is returned. + BRILLO_EXPORT Error Rmdir(const std::string& name, + bool recursive = false, + size_t max_depth = kDefaultMaxPathDepth, + bool keep_going = true) WARN_UNUSED_RESULT; + + private: + BRILLO_EXPORT static const char* RootPath; + + base::ScopedFD fd_; + + DISALLOW_COPY_AND_ASSIGN(SafeFD); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_FILES_SAFE_FD_H_ diff --git a/brillo/files/safe_fd_test.cc b/brillo/files/safe_fd_test.cc new file mode 100644 index 0000000..6401f3d --- /dev/null +++ b/brillo/files/safe_fd_test.cc @@ -0,0 +1,698 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/files/safe_fd.h" + +#include <fcntl.h> +#include <sys/stat.h> + +#include <base/files/file_util.h> +#include <brillo/files/file_util_test.h> +#include <brillo/syslog_logging.h> +#include <gtest/gtest.h> + +namespace brillo { + +class SafeFDTest : public FileTest {}; + +TEST_F(SafeFDTest, SafeFD) { + EXPECT_FALSE(SafeFD().is_valid()); +} + +TEST_F(SafeFDTest, SafeFD_Move) { + SafeFD moved_root = std::move(root_); + EXPECT_FALSE(root_.is_valid()); + ASSERT_TRUE(moved_root.is_valid()); + + SafeFD moved_root2(std::move(moved_root)); + EXPECT_FALSE(moved_root.is_valid()); + ASSERT_TRUE(moved_root2.is_valid()); +} + +TEST_F(SafeFDTest, Root) { + SafeFD::SafeFDResult result = SafeFD::Root(); + EXPECT_TRUE(result.first.is_valid()); + EXPECT_EQ(result.second, SafeFD::Error::kNoError); +} + +TEST_F(SafeFDTest, reset) { + root_.reset(); + EXPECT_FALSE(root_.is_valid()); +} + +TEST_F(SafeFDTest, UnsafeReset) { + int fd = + HANDLE_EINTR(open(temp_dir_path_.data(), + O_NONBLOCK | O_DIRECTORY | O_RDONLY | O_CLOEXEC, 0777)); + ASSERT_GE(fd, 0); + + { + SafeFD safe_fd; + safe_fd.UnsafeReset(fd); + EXPECT_EQ(safe_fd.get(), fd); + } + + // Verify the file descriptor is closed. + int result = fcntl(fd, F_GETFD); + int error = errno; + EXPECT_EQ(result, -1); + EXPECT_EQ(error, EBADF); +} + +TEST_F(SafeFDTest, Write_Success) { + std::string random_data = GetRandomSuffix(); + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + EXPECT_EQ(file.first.Write(random_data.data(), random_data.size()), + SafeFD::Error::kNoError); + } + + ExpectFileContains(random_data); + ExpectPermissions(file_path_, SafeFD::kDefaultFilePermissions); +} + +TEST_F(SafeFDTest, Write_NotInitialized) { + SafeFD invalid; + ASSERT_FALSE(invalid.is_valid()); + + std::string random_data = GetRandomSuffix(); + EXPECT_EQ(invalid.Write(random_data.data(), random_data.size()), + SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Write_VerifyTruncate) { + std::string random_data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(random_data)); + + { + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + EXPECT_EQ(file.first.Write("", 0), SafeFD::Error::kNoError); + } + + ExpectFileContains(""); +} + +TEST_F(SafeFDTest, Write_Failure) { + std::string random_data = GetRandomSuffix(); + EXPECT_EQ(root_.Write("", 1), SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, ReadContents_Success) { + std::string random_data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(random_data)); + + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + auto result = file.first.ReadContents(); + EXPECT_EQ(result.second, SafeFD::Error::kNoError); + ASSERT_EQ(random_data.size(), result.first.size()); + EXPECT_EQ(memcmp(random_data.data(), result.first.data(), random_data.size()), + 0); +} + +TEST_F(SafeFDTest, ReadContents_ExceededMaximum) { + std::string random_data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(random_data)); + + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + ASSERT_LT(1, random_data.size()); + auto result = file.first.ReadContents(1); + EXPECT_EQ(result.second, SafeFD::Error::kExceededMaximum); +} + +TEST_F(SafeFDTest, ReadContents_NotInitialized) { + SafeFD invalid; + ASSERT_FALSE(invalid.is_valid()); + + auto result = invalid.ReadContents(); + EXPECT_EQ(result.second, SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Read_Success) { + std::string random_data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(random_data)); + + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + std::vector<char> buffer(random_data.size(), '\0'); + ASSERT_EQ(file.first.Read(buffer.data(), buffer.size()), + SafeFD::Error::kNoError); + EXPECT_EQ(memcmp(random_data.data(), buffer.data(), random_data.size()), 0); +} + +TEST_F(SafeFDTest, Read_NotInitialized) { + SafeFD invalid; + ASSERT_FALSE(invalid.is_valid()); + + char to_read; + EXPECT_EQ(invalid.Read(&to_read, 1), SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Read_IOError) { + std::string random_data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(random_data)); + + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + + std::vector<char> buffer(random_data.size() * 2, '\0'); + ASSERT_EQ(file.first.Read(buffer.data(), buffer.size()), + SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, OpenExistingFile_Success) { + std::string data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(data)); + { + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + } + ExpectFileContains(data); +} + +TEST_F(SafeFDTest, OpenExistingFile_NotInitialized) { + SafeFD::SafeFDResult file = SafeFD().OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNotInitialized); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingFile_DoesNotExist) { + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kDoesNotExist); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingFile_IOError) { + ASSERT_TRUE(WriteFile("")); + EXPECT_EQ(chmod(file_path_.value().c_str(), 0000), 0) << strerror(errno); + + SafeFD::SafeFDResult file = root_.OpenExistingFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kIOError); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingFile_SymlinkDetected) { + ASSERT_TRUE(SetupSymlinks()); + ASSERT_TRUE(WriteFile("")); + SafeFD::SafeFDResult file = root_.OpenExistingFile(symlink_file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kSymlinkDetected); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingFile_WrongType) { + ASSERT_TRUE(SetupSymlinks()); + ASSERT_TRUE(WriteFile("")); + SafeFD::SafeFDResult file = + root_.OpenExistingFile(symlink_dir_path_.Append(kFileName)); + EXPECT_EQ(file.second, SafeFD::Error::kWrongType); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingDir_Success) { + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(dir.second, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingDir_NotInitialized) { + SafeFD::SafeFDResult dir = SafeFD().OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(dir.second, SafeFD::Error::kNotInitialized); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingDir_DoesNotExist) { + SafeFD::SafeFDResult dir = root_.OpenExistingDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kDoesNotExist); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingDir_IOError) { + ASSERT_TRUE(WriteFile("")); + ASSERT_EQ(chmod(sub_dir_path_.value().c_str(), 0000), 0) << strerror(errno); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kIOError); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, OpenExistingDir_WrongType) { + ASSERT_TRUE(SetupSymlinks()); + SafeFD::SafeFDResult dir = root_.OpenExistingDir(symlink_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kWrongType); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, MakeFile_DoesNotExistSuccess) { + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + } + + ExpectPermissions(file_path_, SafeFD::kDefaultFilePermissions); +} + +TEST_F(SafeFDTest, MakeFile_LeadingSelfDirSuccess) { + ASSERT_TRUE(SetupSubdir()); + + SafeFD::Error err; + SafeFD dir; + std::tie(dir, err) = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(err, SafeFD::Error::kNoError); + + { + SafeFD file; + std::tie(file, err) = dir.MakeFile(file_path_.BaseName()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(file.is_valid()); + } + + ExpectPermissions(file_path_, SafeFD::kDefaultFilePermissions); +} + +TEST_F(SafeFDTest, MakeFile_ExistsSuccess) { + std::string data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(data)); + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + } + ExpectPermissions(file_path_, SafeFD::kDefaultFilePermissions); + ExpectFileContains(data); +} + +TEST_F(SafeFDTest, MakeFile_IOError) { + ASSERT_TRUE(SetupSubdir()); + ASSERT_EQ(mkfifo(file_path_.value().c_str(), 0), 0); + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kIOError); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, MakeFile_SymlinkDetected) { + ASSERT_TRUE(SetupSymlinks()); + SafeFD::SafeFDResult file = root_.MakeFile(symlink_file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kSymlinkDetected); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, MakeFile_WrongType) { + ASSERT_TRUE(SetupSubdir()); + SafeFD::SafeFDResult file = root_.MakeFile(sub_dir_path_); + EXPECT_EQ(file.second, SafeFD::Error::kWrongType); + ASSERT_FALSE(file.first.is_valid()); +} + +TEST_F(SafeFDTest, MakeFile_WrongGID) { + ASSERT_TRUE(WriteFile("")); + ASSERT_EQ(chown(file_path_.value().c_str(), getuid(), 0), 0) + << strerror(errno); + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kWrongGID); + ASSERT_FALSE(file.first.is_valid()); + } +} + +TEST_F(SafeFDTest, MakeFile_WrongPermissions) { + ASSERT_TRUE(WriteFile("")); + ASSERT_EQ(chmod(file_path_.value().c_str(), 0777), 0) << strerror(errno); + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kWrongPermissions); + ASSERT_FALSE(file.first.is_valid()); + } + ASSERT_EQ(chmod(file_path_.value().c_str(), SafeFD::kDefaultFilePermissions), + 0) + << strerror(errno); + + EXPECT_EQ(chmod(sub_dir_path_.value().c_str(), 0777), 0) << strerror(errno); + { + SafeFD::SafeFDResult file = root_.MakeFile(file_path_); + EXPECT_EQ(file.second, SafeFD::Error::kWrongPermissions); + ASSERT_FALSE(file.first.is_valid()); + } +} + +TEST_F(SafeFDTest, MakeDir_DoesNotExistSuccess) { + { + SafeFD::SafeFDResult dir = root_.MakeDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.first.is_valid()); + } + + ExpectPermissions(sub_dir_path_, SafeFD::kDefaultDirPermissions); +} + +TEST_F(SafeFDTest, MakeFile_SingleComponentSuccess) { + ASSERT_TRUE(SetupSubdir()); + + SafeFD::Error err; + SafeFD dir; + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(err, SafeFD::Error::kNoError); + + { + SafeFD subdir; + std::tie(subdir, err) = dir.MakeDir(base::FilePath(kSubdirName)); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(subdir.is_valid()); + } + + ExpectPermissions(sub_dir_path_, SafeFD::kDefaultDirPermissions); +} + +TEST_F(SafeFDTest, MakeDir_ExistsSuccess) { + ASSERT_TRUE(SetupSubdir()); + { + SafeFD::SafeFDResult dir = root_.MakeDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.first.is_valid()); + } + + ExpectPermissions(sub_dir_path_, SafeFD::kDefaultDirPermissions); +} + +TEST_F(SafeFDTest, MakeDir_WrongType) { + ASSERT_TRUE(SetupSymlinks()); + SafeFD::SafeFDResult dir = root_.MakeDir(symlink_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kWrongType); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, MakeDir_WrongGID) { + ASSERT_TRUE(SetupSubdir()); + ASSERT_EQ(chown(sub_dir_path_.value().c_str(), getuid(), 0), 0) + << strerror(errno); + { + SafeFD::SafeFDResult dir = root_.MakeDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kWrongGID); + ASSERT_FALSE(dir.first.is_valid()); + } +} + +TEST_F(SafeFDTest, MakeDir_WrongPermissions) { + ASSERT_TRUE(SetupSubdir()); + ASSERT_EQ(chmod(sub_dir_path_.value().c_str(), 0777), 0) << strerror(errno); + + SafeFD::SafeFDResult dir = root_.MakeDir(sub_dir_path_); + EXPECT_EQ(dir.second, SafeFD::Error::kWrongPermissions); + ASSERT_FALSE(dir.first.is_valid()); +} + +TEST_F(SafeFDTest, Link_Success) { + std::string data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(data)); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Link(subdir.first, kFileName, kFileName), + SafeFD::Error::kNoError); + + SafeFD::SafeFDResult new_file = dir.first.OpenExistingFile( + base::FilePath(kFileName), O_RDONLY | O_CLOEXEC); + EXPECT_EQ(new_file.second, SafeFD::Error::kNoError); + std::pair<std::vector<char>, SafeFD::Error> contents = + new_file.first.ReadContents(); + EXPECT_EQ(contents.second, SafeFD::Error::kNoError); + EXPECT_EQ(data.size(), contents.first.size()); + EXPECT_EQ(memcmp(data.data(), contents.first.data(), data.size()), 0); +} + +TEST_F(SafeFDTest, Link_NotInitialized) { + std::string data = GetRandomSuffix(); + ASSERT_TRUE(WriteFile(data)); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(SafeFD().Link(subdir.first, kFileName, kFileName), + SafeFD::Error::kNotInitialized); + + EXPECT_EQ(dir.first.Link(SafeFD(), kFileName, kFileName), + SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Link_BadArgument) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Link(subdir.first, "a/a", kFileName), + SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Link(subdir.first, ".", kFileName), + SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Link(subdir.first, "..", kFileName), + SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Link(subdir.first, kFileName, "a/a"), + SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Link(subdir.first, kFileName, "."), + SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Link(subdir.first, kFileName, ".."), + SafeFD::Error::kBadArgument); +} + +TEST_F(SafeFDTest, Link_IOError) { + ASSERT_TRUE(SetupSubdir()); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Link(subdir.first, kFileName, kFileName), + SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, Unlink_Success) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(subdir.first.Unlink(kFileName), SafeFD::Error::kNoError); + EXPECT_FALSE(base::PathExists(file_path_)); +} + +TEST_F(SafeFDTest, Unlink_NotInitialized) { + ASSERT_TRUE(WriteFile("")); + + EXPECT_EQ(SafeFD().Unlink(kFileName), SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Unlink_BadArgument) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(subdir.first.Unlink("a/a"), SafeFD::Error::kBadArgument); + EXPECT_EQ(subdir.first.Unlink("."), SafeFD::Error::kBadArgument); + EXPECT_EQ(subdir.first.Unlink(".."), SafeFD::Error::kBadArgument); +} + +TEST_F(SafeFDTest, Unlink_IOError_Nonexistent) { + ASSERT_TRUE(SetupSubdir()); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(subdir.first.Unlink(kFileName), SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, Unlink_IOError_IsADir) { + ASSERT_TRUE(SetupSubdir()); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Unlink(kSubdirName), SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, Rmdir_Recursive_Success) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Rmdir(kSubdirName, true /*recursive*/), + SafeFD::Error::kNoError); + EXPECT_FALSE(base::PathExists(file_path_)); + EXPECT_FALSE(base::PathExists(sub_dir_path_)); +} + +TEST_F(SafeFDTest, Rmdir_Recursive_SuccessMaxRecursion) { + SafeFD::Error err; + SafeFD dir; + + // Create directory with the maximum depth. + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + for (size_t x = 0; x < SafeFD::kDefaultMaxPathDepth; ++x) { + std::tie(dir, err) = dir.MakeDir(base::FilePath(kSubdirName)); + EXPECT_EQ(err, SafeFD::Error::kNoError); + ASSERT_TRUE(dir.is_valid()); + } + + // Check if recursive Rmdir succeeds (i.e. there isn't a stack overflow). + std::tie(dir, err) = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(err, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.Rmdir(kSubdirName, true /*recursive*/), + SafeFD::Error::kNoError); + EXPECT_FALSE(base::PathExists(file_path_)); + EXPECT_FALSE(base::PathExists(sub_dir_path_)); +} + +TEST_F(SafeFDTest, Rmdir_NotInitialized) { + ASSERT_TRUE(WriteFile("")); + + EXPECT_EQ(SafeFD().Rmdir(kSubdirName, true /*recursive*/), + SafeFD::Error::kNotInitialized); +} + +TEST_F(SafeFDTest, Rmdir_BadArgument) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + EXPECT_EQ(dir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Rmdir("a/a"), SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Rmdir("."), SafeFD::Error::kBadArgument); + EXPECT_EQ(dir.first.Rmdir(".."), SafeFD::Error::kBadArgument); +} + +TEST_F(SafeFDTest, Rmdir_ExceededMaximum) { + ASSERT_TRUE(SetupSubdir()); + ASSERT_TRUE(base::CreateDirectory(sub_dir_path_.Append(kSubdirName))); + + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(dir.first.Rmdir(kSubdirName, true /*recursive*/, 1), + SafeFD::Error::kExceededMaximum); +} + +TEST_F(SafeFDTest, Rmdir_IOError) { + SafeFD::SafeFDResult dir = root_.OpenExistingDir(temp_dir_.GetPath()); + ASSERT_EQ(dir.second, SafeFD::Error::kNoError); + + // Dir doesn't exist. + EXPECT_EQ(dir.first.Rmdir(kSubdirName), SafeFD::Error::kIOError); + + // Dir not empty. + ASSERT_TRUE(WriteFile("")); + EXPECT_EQ(dir.first.Rmdir(kSubdirName), SafeFD::Error::kIOError); +} + +TEST_F(SafeFDTest, Rmdir_WrongType) { + ASSERT_TRUE(WriteFile("")); + + SafeFD::SafeFDResult subdir = root_.OpenExistingDir(sub_dir_path_); + ASSERT_EQ(subdir.second, SafeFD::Error::kNoError); + + EXPECT_EQ(subdir.first.Rmdir(kFileName), SafeFD::Error::kWrongType); +} + +TEST_F(SafeFDTest, Rmdir_Recursive_KeepGoing) { + ASSERT_TRUE(SetupSubdir()); + + ASSERT_TRUE(base::CreateDirectory(sub_dir_path_.Append(kSubdirName))); + + // Give us something to iterate over. + constexpr int kNumSentinel = 25; + for (int i = 0; i < kNumSentinel; i++) { + SafeFD::SafeFDResult file = + root_.MakeFile(sub_dir_path_.Append(GetRandomSuffix())); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + } + + // Recursively delete with a max level that is too small. Capture errno. + SafeFD::Error result = root_.Rmdir(kSubdirName, true /*recursive*/, + 1 /*max_depth*/, true /*keep_going*/); + int rmdir_errno = errno; + + EXPECT_EQ(result, SafeFD::Error::kExceededMaximum); + + // If we keep going, the last operation will be the post-order unlink of + // the top-level directory. This has to fail with ENOTEMPTY since we did + // not delete the too-deep sub-directories. This particular behavior + // should not be part of the API contract and this can be relaxed if the + // implementation is changed. + EXPECT_EQ(rmdir_errno, ENOTEMPTY); + + // The deep directory must still exist. + ASSERT_TRUE( + base::DeleteFile(sub_dir_path_.Append(kSubdirName), false /*recursive*/)); + + // We cannot control the iteration order so even if we incorrectly + // stopped early the directory might still be empty if the deep + // directories were last in the iteration order. But a non-empty + // directory is always incorrect. + ASSERT_TRUE(base::IsDirectoryEmpty(sub_dir_path_)); +} + +TEST_F(SafeFDTest, Rmdir_Recursive_StopOnError) { + ASSERT_TRUE(SetupSubdir()); + + ASSERT_TRUE(base::CreateDirectory(sub_dir_path_.Append(kSubdirName))); + + // Give us something to iterate over. + constexpr int kNumSentinel = 25; + for (int i = 0; i < kNumSentinel; i++) { + SafeFD::SafeFDResult file = + root_.MakeFile(sub_dir_path_.Append(GetRandomSuffix())); + EXPECT_EQ(file.second, SafeFD::Error::kNoError); + ASSERT_TRUE(file.first.is_valid()); + } + + // Recursively delete with a max level that is too small. Capture errno. + SafeFD::Error result = root_.Rmdir(kSubdirName, true /*recursive*/, + 1 /*max_depth*/, false /*keep_going*/); + int rmdir_errno = errno; + + EXPECT_EQ(result, SafeFD::Error::kExceededMaximum); + + // If we stop on encountering a too-deep directory, we never actually + // make any libc calls that encounter errors. This particular behavior + // should not be part of the API contract and this can be relaxed if the + // implementation is changed. + EXPECT_EQ(rmdir_errno, 0); + + // The deep directory must still exist. + ASSERT_TRUE( + base::DeleteFile(sub_dir_path_.Append(kSubdirName), false /*recursive*/)); +} + +} // namespace brillo diff --git a/brillo/files/scoped_dir.h b/brillo/files/scoped_dir.h new file mode 100644 index 0000000..b4ca6a9 --- /dev/null +++ b/brillo/files/scoped_dir.h @@ -0,0 +1,36 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_FILES_SCOPED_DIR_H_ +#define LIBBRILLO_BRILLO_FILES_SCOPED_DIR_H_ + +#include <dirent.h> + +#include <base/scoped_generic.h> + +#define HANDLE_EINTR_IF_EQ(x, val) \ + ({ \ + decltype(x) eintr_wrapper_result; \ + do { \ + eintr_wrapper_result = (x); \ + } while (eintr_wrapper_result == (val) && errno == EINTR); \ + eintr_wrapper_result; \ + }) + +namespace brillo { + +struct ScopedDIRCloseTraits { + static DIR* InvalidValue() { return nullptr; } + static void Free(DIR* dir) { + if (dir != nullptr) { + closedir(dir); + } + } +}; + +typedef base::ScopedGeneric<DIR*, ScopedDIRCloseTraits> ScopedDIR; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_FILES_SCOPED_DIR_H_ diff --git a/brillo/flag_helper.cc b/brillo/flag_helper.cc index bb51818..1c332d3 100644 --- a/brillo/flag_helper.cc +++ b/brillo/flag_helper.cc @@ -4,12 +4,14 @@ #include "brillo/flag_helper.h" -#include <memory> #include <stdio.h> #include <stdlib.h> -#include <string> #include <sysexits.h> +#include <memory> +#include <string> +#include <utility> + #include <base/base_switches.h> #include <base/command_line.h> #include <base/logging.h> @@ -84,6 +86,22 @@ const char* Int32Flag::GetType() const { return "int"; } +UInt32Flag::UInt32Flag(const char* name, + uint32_t* value, + const char* default_value, + const char* help, + bool visible) + : Flag(name, default_value, help, visible), value_(value) { +} + +bool UInt32Flag::SetValue(const std::string& value) { + return base::StringToUint(value, value_); +} + +const char* UInt32Flag::GetType() const { + return "uint32"; +} + Int64Flag::Int64Flag(const char* name, int64_t* value, const char* default_value, diff --git a/brillo/flag_helper.h b/brillo/flag_helper.h index 810a00c..c6d63cd 100644 --- a/brillo/flag_helper.h +++ b/brillo/flag_helper.h @@ -20,6 +20,7 @@ // // DEFINE_bool(name, default_value, help) // DEFINE_int32(name, default_value, help) +// DEFINE_uint32(name, default_value, help) // DEFINE_int64(name, default_value, help) // DEFINE_uint64(name, default_value, help) // DEFINE_double(name, default_value, help) @@ -118,6 +119,21 @@ class BRILLO_EXPORT Int32Flag final : public Flag { int* value_; }; +class BRILLO_EXPORT UInt32Flag final : public Flag { + public: + UInt32Flag(const char* name, + uint32_t* value, + const char* default_value, + const char* help, + bool visible); + bool SetValue(const std::string& value) override; + + const char* GetType() const override; + + private: + uint32_t* value_; +}; + class BRILLO_EXPORT Int64Flag final : public Flag { public: Int64Flag(const char* name, @@ -191,6 +207,8 @@ class BRILLO_EXPORT StringFlag final : public Flag { #define DEFINE_int32(name, value, help) \ DEFINE_type(int, Int32Flag, name, value, help) +#define DEFINE_uint32(name, value, help) \ + DEFINE_type(uint32_t, UInt32Flag, name, value, help) #define DEFINE_int64(name, value, help) \ DEFINE_type(int64_t, Int64Flag, name, value, help) #define DEFINE_uint64(name, value, help) \ diff --git a/brillo/flag_helper_unittest.cc b/brillo/flag_helper_test.cc index 29c6429..7c7164d 100644 --- a/brillo/flag_helper_unittest.cc +++ b/brillo/flag_helper_test.cc @@ -8,6 +8,7 @@ #include <base/command_line.h> #include <base/macros.h> +#include <base/stl_util.h> #include <brillo/flag_helper.h> #include <gtest/gtest.h> @@ -29,6 +30,8 @@ TEST_F(FlagHelperTest, Defaults) { DEFINE_int32(int32_1, INT32_MIN, "Test int32 flag"); DEFINE_int32(int32_2, 0, "Test int32 flag"); DEFINE_int32(int32_3, INT32_MAX, "Test int32 flag"); + DEFINE_uint32(uint32_1, 0, "Test uint32 flag"); + DEFINE_uint32(uint32_2, UINT32_MAX, "Test uint32 flag"); DEFINE_int64(int64_1, INT64_MIN, "Test int64 flag"); DEFINE_int64(int64_2, 0, "Test int64 flag"); DEFINE_int64(int64_3, INT64_MAX, "Test int64 flag"); @@ -41,17 +44,19 @@ TEST_F(FlagHelperTest, Defaults) { DEFINE_string(string_2, "value", "Test string flag"); const char* argv[] = {"test_program"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); - brillo::FlagHelper::Init(arraysize(argv), argv, "TestDefaultTrue"); + brillo::FlagHelper::Init(base::size(argv), argv, "TestDefaultTrue"); EXPECT_TRUE(FLAGS_bool1); EXPECT_FALSE(FLAGS_bool2); EXPECT_EQ(FLAGS_int32_1, INT32_MIN); EXPECT_EQ(FLAGS_int32_2, 0); EXPECT_EQ(FLAGS_int32_3, INT32_MAX); + EXPECT_EQ(FLAGS_uint32_1, 0); + EXPECT_EQ(FLAGS_uint32_2, UINT32_MAX); EXPECT_EQ(FLAGS_int64_1, INT64_MIN); EXPECT_EQ(FLAGS_int64_2, 0); EXPECT_EQ(FLAGS_int64_3, INT64_MAX); @@ -74,6 +79,8 @@ TEST_F(FlagHelperTest, SetValueDoubleDash) { DEFINE_int32(int32_1, 1, "Test int32 flag"); DEFINE_int32(int32_2, 1, "Test int32 flag"); DEFINE_int32(int32_3, 1, "Test int32 flag"); + DEFINE_uint32(uint32_1, 1, "Test uint32 flag"); + DEFINE_uint32(uint32_2, 1, "Test uint32 flag"); DEFINE_int64(int64_1, 1, "Test int64 flag"); DEFINE_int64(int64_2, 1, "Test int64 flag"); DEFINE_int64(int64_3, 1, "Test int64 flag"); @@ -93,6 +100,8 @@ TEST_F(FlagHelperTest, SetValueDoubleDash) { "--int32_1=-2147483648", "--int32_2=0", "--int32_3=2147483647", + "--uint32_1=0", + "--uint32_2=4294967295", "--int64_1=-9223372036854775808", "--int64_2=0", "--int64_3=9223372036854775807", @@ -103,11 +112,11 @@ TEST_F(FlagHelperTest, SetValueDoubleDash) { "--double_3=100.5", "--string_1=", "--string_2=value"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); - brillo::FlagHelper::Init(arraysize(argv), argv, "TestDefaultTrue"); + brillo::FlagHelper::Init(base::size(argv), argv, "TestDefaultTrue"); EXPECT_TRUE(FLAGS_bool1); EXPECT_FALSE(FLAGS_bool2); @@ -116,6 +125,8 @@ TEST_F(FlagHelperTest, SetValueDoubleDash) { EXPECT_EQ(FLAGS_int32_1, INT32_MIN); EXPECT_EQ(FLAGS_int32_2, 0); EXPECT_EQ(FLAGS_int32_3, INT32_MAX); + EXPECT_EQ(FLAGS_uint32_1, 0); + EXPECT_EQ(FLAGS_uint32_2, UINT32_MAX); EXPECT_EQ(FLAGS_int64_1, INT64_MIN); EXPECT_EQ(FLAGS_int64_2, 0); EXPECT_EQ(FLAGS_int64_3, INT64_MAX); @@ -136,6 +147,8 @@ TEST_F(FlagHelperTest, SetValueSingleDash) { DEFINE_int32(int32_1, 1, "Test int32 flag"); DEFINE_int32(int32_2, 1, "Test int32 flag"); DEFINE_int32(int32_3, 1, "Test int32 flag"); + DEFINE_uint64(uint32_1, 1, "Test uint32 flag"); + DEFINE_uint64(uint32_2, 1, "Test uint32 flag"); DEFINE_int64(int64_1, 1, "Test int64 flag"); DEFINE_int64(int64_2, 1, "Test int64 flag"); DEFINE_int64(int64_3, 1, "Test int64 flag"); @@ -153,6 +166,8 @@ TEST_F(FlagHelperTest, SetValueSingleDash) { "-int32_1=-2147483648", "-int32_2=0", "-int32_3=2147483647", + "-uint32_1=0", + "-uint32_2=4294967295", "-int64_1=-9223372036854775808", "-int64_2=0", "-int64_3=9223372036854775807", @@ -163,17 +178,19 @@ TEST_F(FlagHelperTest, SetValueSingleDash) { "-double_3=100.5", "-string_1=", "-string_2=value"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); - brillo::FlagHelper::Init(arraysize(argv), argv, "TestDefaultTrue"); + brillo::FlagHelper::Init(base::size(argv), argv, "TestDefaultTrue"); EXPECT_TRUE(FLAGS_bool1); EXPECT_FALSE(FLAGS_bool2); EXPECT_EQ(FLAGS_int32_1, INT32_MIN); EXPECT_EQ(FLAGS_int32_2, 0); EXPECT_EQ(FLAGS_int32_3, INT32_MAX); + EXPECT_EQ(FLAGS_uint32_1, 0); + EXPECT_EQ(FLAGS_uint32_2, UINT32_MAX); EXPECT_EQ(FLAGS_int64_1, INT64_MIN); EXPECT_EQ(FLAGS_int64_2, 0); EXPECT_EQ(FLAGS_int64_3, INT64_MAX); @@ -192,11 +209,11 @@ TEST_F(FlagHelperTest, DuplicateSetValue) { DEFINE_int32(int32_1, 0, "Test in32 flag"); const char* argv[] = {"test_program", "--int32_1=5", "--int32_1=10"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); - brillo::FlagHelper::Init(arraysize(argv), argv, "TestDuplicateSetvalue"); + brillo::FlagHelper::Init(base::size(argv), argv, "TestDuplicateSetvalue"); EXPECT_EQ(FLAGS_int32_1, 10); } @@ -206,11 +223,11 @@ TEST_F(FlagHelperTest, FlagTerminator) { DEFINE_int32(int32_1, 0, "Test int32 flag"); const char* argv[] = {"test_program", "--int32_1=5", "--", "--int32_1=10"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); - brillo::FlagHelper::Init(arraysize(argv), argv, "TestFlagTerminator"); + brillo::FlagHelper::Init(base::size(argv), argv, "TestFlagTerminator"); EXPECT_EQ(FLAGS_int32_1, 5); } @@ -220,13 +237,14 @@ TEST_F(FlagHelperTest, FlagTerminator) { TEST_F(FlagHelperTest, HelpMessage) { DEFINE_bool(bool_1, true, "Test bool flag"); DEFINE_int32(int_1, 0, "Test int flag"); + DEFINE_uint32(uint32_1, 0, "Test uint32 flag"); DEFINE_int64(int64_1, 0, "Test int64 flag"); DEFINE_uint64(uint64_1, 0, "Test uint64 flag"); DEFINE_double(double_1, 0, "Test double flag"); DEFINE_string(string_1, "", "Test string flag"); const char* argv[] = {"test_program", "--int_1=value", "--help"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); @@ -235,7 +253,7 @@ TEST_F(FlagHelperTest, HelpMessage) { stdout = stderr; ASSERT_EXIT( - brillo::FlagHelper::Init(arraysize(argv), argv, "TestHelpMessage"), + brillo::FlagHelper::Init(base::size(argv), argv, "TestHelpMessage"), ::testing::ExitedWithCode(EX_OK), "TestHelpMessage\n\n" " --bool_1 \\(Test bool flag\\) type: bool default: true\n" @@ -244,6 +262,7 @@ TEST_F(FlagHelperTest, HelpMessage) { " --int64_1 \\(Test int64 flag\\) type: int64 default: 0\n" " --int_1 \\(Test int flag\\) type: int default: 0\n" " --string_1 \\(Test string flag\\) type: string default: \"\"\n" + " --uint32_1 \\(Test uint32 flag\\) type: uint32 default: 0\n" " --uint64_1 \\(Test uint64 flag\\) type: uint64 default: 0\n"); stdout = orig; @@ -253,7 +272,7 @@ TEST_F(FlagHelperTest, HelpMessage) { // to exit with EX_USAGE error code and corresponding error message. TEST_F(FlagHelperTest, UnknownFlag) { const char* argv[] = {"test_program", "--flag=value"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); @@ -261,7 +280,7 @@ TEST_F(FlagHelperTest, UnknownFlag) { FILE* orig = stdout; stdout = stderr; - ASSERT_EXIT(brillo::FlagHelper::Init(arraysize(argv), argv, "TestIntExit"), + ASSERT_EXIT(brillo::FlagHelper::Init(base::size(argv), argv, "TestIntExit"), ::testing::ExitedWithCode(EX_USAGE), "ERROR: unknown command line flag 'flag'"); @@ -274,7 +293,7 @@ TEST_F(FlagHelperTest, BoolParseError) { DEFINE_bool(bool_1, 0, "Test bool flag"); const char* argv[] = {"test_program", "--bool_1=value"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); @@ -283,7 +302,7 @@ TEST_F(FlagHelperTest, BoolParseError) { stdout = stderr; ASSERT_EXIT( - brillo::FlagHelper::Init(arraysize(argv), argv, "TestBoolParseError"), + brillo::FlagHelper::Init(base::size(argv), argv, "TestBoolParseError"), ::testing::ExitedWithCode(EX_DATAERR), "ERROR: illegal value 'value' specified for bool flag 'bool_1'"); @@ -296,7 +315,7 @@ TEST_F(FlagHelperTest, Int32ParseError) { DEFINE_int32(int_1, 0, "Test int flag"); const char* argv[] = {"test_program", "--int_1=value"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); @@ -304,11 +323,57 @@ TEST_F(FlagHelperTest, Int32ParseError) { FILE* orig = stdout; stdout = stderr; - ASSERT_EXIT(brillo::FlagHelper::Init(arraysize(argv), - argv, - "TestInt32ParseError"), - ::testing::ExitedWithCode(EX_DATAERR), - "ERROR: illegal value 'value' specified for int flag 'int_1'"); + ASSERT_EXIT( + brillo::FlagHelper::Init(base::size(argv), argv, "TestInt32ParseError"), + ::testing::ExitedWithCode(EX_DATAERR), + "ERROR: illegal value 'value' specified for int flag 'int_1'"); + + stdout = orig; +} + +// Test that when passing an incorrect/unparsable type to a command line flag, +// the program exits with code EX_DATAERR and outputs a corresponding message. +TEST_F(FlagHelperTest, Uint32ParseErrorUppperBound) { + DEFINE_uint32(uint32_1, 0, "Test uint32 flag"); + + // test with UINT32_MAX + 1 + const char* argv[] = {"test_program", "--uint32_1=4294967296"}; + base::CommandLine command_line(base::size(argv), argv); + + brillo::FlagHelper::GetInstance()->set_command_line_for_testing( + &command_line); + + FILE* orig = stdout; + stdout = stderr; + + ASSERT_EXIT( + brillo::FlagHelper::Init(base::size(argv), argv, "TestUint32ParseError"), + ::testing::ExitedWithCode(EX_DATAERR), + "ERROR: illegal value '4294967296' specified for uint32 flag " + "'uint32_1'"); + + stdout = orig; +} + +// Test that when passing an incorrect/unparsable type to a command line flag, +// the program exits with code EX_DATAERR and outputs a corresponding message. +TEST_F(FlagHelperTest, Uint32ParseErrorNegativeValue) { + DEFINE_uint32(uint32_1, 0, "Test uint32 flag"); + + const char* argv[] = {"test_program", "--uint32_1=-1"}; + base::CommandLine command_line(base::size(argv), argv); + + brillo::FlagHelper::GetInstance()->set_command_line_for_testing( + &command_line); + + FILE* orig = stdout; + stdout = stderr; + + ASSERT_EXIT( + brillo::FlagHelper::Init(base::size(argv), argv, "TestUint32ParseError"), + ::testing::ExitedWithCode(EX_DATAERR), + "ERROR: illegal value '-1' specified for uint32 flag " + "'uint32_1'"); stdout = orig; } @@ -319,7 +384,7 @@ TEST_F(FlagHelperTest, Int64ParseError) { DEFINE_int64(int64_1, 0, "Test int64 flag"); const char* argv[] = {"test_program", "--int64_1=value"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); @@ -328,7 +393,7 @@ TEST_F(FlagHelperTest, Int64ParseError) { stdout = stderr; ASSERT_EXIT( - brillo::FlagHelper::Init(arraysize(argv), argv, "TestInt64ParseError"), + brillo::FlagHelper::Init(base::size(argv), argv, "TestInt64ParseError"), ::testing::ExitedWithCode(EX_DATAERR), "ERROR: illegal value 'value' specified for int64 flag " "'int64_1'"); @@ -342,7 +407,7 @@ TEST_F(FlagHelperTest, UInt64ParseError) { DEFINE_uint64(uint64_1, 0, "Test uint64 flag"); const char* argv[] = {"test_program", "--uint64_1=value"}; - base::CommandLine command_line(arraysize(argv), argv); + base::CommandLine command_line(base::size(argv), argv); brillo::FlagHelper::GetInstance()->set_command_line_for_testing( &command_line); @@ -351,7 +416,7 @@ TEST_F(FlagHelperTest, UInt64ParseError) { stdout = stderr; ASSERT_EXIT( - brillo::FlagHelper::Init(arraysize(argv), argv, "TestUInt64ParseError"), + brillo::FlagHelper::Init(base::size(argv), argv, "TestUInt64ParseError"), ::testing::ExitedWithCode(EX_DATAERR), "ERROR: illegal value 'value' specified for uint64 flag " "'uint64_1'"); diff --git a/brillo/glib/dbus.h b/brillo/glib/dbus.h index 7a28480..0e756bf 100644 --- a/brillo/glib/dbus.h +++ b/brillo/glib/dbus.h @@ -13,6 +13,7 @@ #include <algorithm> #include <string> +#include <utility> #include "base/logging.h" #include <brillo/brillo_export.h> diff --git a/brillo/glib/object.h b/brillo/glib/object.h index 15de52c..56d38a4 100644 --- a/brillo/glib/object.h +++ b/brillo/glib/object.h @@ -8,13 +8,14 @@ #include <glib-object.h> #include <stdint.h> -#include <base/logging.h> -#include <base/macros.h> - #include <algorithm> #include <cstddef> #include <memory> #include <string> +#include <utility> + +#include <base/logging.h> +#include <base/macros.h> namespace brillo { diff --git a/brillo/glib/object_unittest.cc b/brillo/glib/object_test.cc index a1ed408..a1ed408 100644 --- a/brillo/glib/object_unittest.cc +++ b/brillo/glib/object_test.cc diff --git a/brillo/http/http_connection_curl.cc b/brillo/http/http_connection_curl.cc index 3720330..6f1b3ed 100644 --- a/brillo/http/http_connection_curl.cc +++ b/brillo/http/http_connection_curl.cc @@ -4,6 +4,8 @@ #include <brillo/http/http_connection_curl.h> +#include <utility> + #include <base/logging.h> #include <brillo/http/http_request.h> #include <brillo/http/http_transport_curl.h> diff --git a/brillo/http/http_connection_curl.h b/brillo/http/http_connection_curl.h index c34de57..81008e1 100644 --- a/brillo/http/http_connection_curl.h +++ b/brillo/http/http_connection_curl.h @@ -6,6 +6,7 @@ #define LIBBRILLO_BRILLO_HTTP_HTTP_CONNECTION_CURL_H_ #include <map> +#include <memory> #include <string> #include <vector> diff --git a/brillo/http/http_connection_curl_unittest.cc b/brillo/http/http_connection_curl_test.cc index 90a5626..d908ac0 100644 --- a/brillo/http/http_connection_curl_unittest.cc +++ b/brillo/http/http_connection_curl_test.cc @@ -6,6 +6,7 @@ #include <algorithm> #include <set> +#include <utility> #include <base/callback.h> #include <brillo/http/http_request.h> diff --git a/brillo/http/http_connection_fake.cc b/brillo/http/http_connection_fake.cc index 15e5181..dbd9f90 100644 --- a/brillo/http/http_connection_fake.cc +++ b/brillo/http/http_connection_fake.cc @@ -4,8 +4,8 @@ #include <brillo/http/http_connection_fake.h> +#include <base/bind.h> #include <base/logging.h> -#include <brillo/bind_lambda.h> #include <brillo/http/http_request.h> #include <brillo/mime_utils.h> #include <brillo/streams/memory_stream.h> diff --git a/brillo/http/http_connection_fake.h b/brillo/http/http_connection_fake.h index a6ebeee..402d6f9 100644 --- a/brillo/http/http_connection_fake.h +++ b/brillo/http/http_connection_fake.h @@ -6,7 +6,9 @@ #define LIBBRILLO_BRILLO_HTTP_HTTP_CONNECTION_FAKE_H_ #include <map> +#include <memory> #include <string> +#include <utility> #include <vector> #include <base/macros.h> diff --git a/brillo/http/http_form_data.cc b/brillo/http/http_form_data.cc index 4d8f6f0..eb1d028 100644 --- a/brillo/http/http_form_data.cc +++ b/brillo/http/http_form_data.cc @@ -5,10 +5,12 @@ #include <brillo/http/http_form_data.h> #include <limits> +#include <utility> #include <base/format_macros.h> #include <base/rand_util.h> #include <base/strings/stringprintf.h> +#include <base/strings/string_util.h> #include <brillo/errors/error_codes.h> #include <brillo/http/http_transport.h> @@ -141,8 +143,18 @@ bool MultiPartFormField::ExtractDataStreams(std::vector<StreamPtr>* streams) { } std::string MultiPartFormField::GetContentType() const { + // Quote the boundary only if it has non-alphanumeric chars in it. + // https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html + bool use_quotes = false; + for (auto ch : boundary_) { + if (!base::IsAsciiAlpha(ch) && !base::IsAsciiDigit(ch)) { + use_quotes = true; + break; + } + } return base::StringPrintf( - "%s; boundary=\"%s\"", content_type_.c_str(), boundary_.c_str()); + use_quotes ? "%s; boundary=\"%s\"" : "%s; boundary=%s", + content_type_.c_str(), boundary_.c_str()); } void MultiPartFormField::AddCustomField(std::unique_ptr<FormField> field) { @@ -180,7 +192,7 @@ std::string MultiPartFormField::GetBoundaryStart() const { } std::string MultiPartFormField::GetBoundaryEnd() const { - return base::StringPrintf("--%s--", boundary_.c_str()); + return base::StringPrintf("--%s--\r\n", boundary_.c_str()); } FormData::FormData() : FormData{std::string{}} { diff --git a/brillo/http/http_form_data_fuzzer.cc b/brillo/http/http_form_data_fuzzer.cc new file mode 100644 index 0000000..f73a89f --- /dev/null +++ b/brillo/http/http_form_data_fuzzer.cc @@ -0,0 +1,128 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <stddef.h> +#include <stdint.h> + +#include <base/files/file_path.h> +#include <base/files/file_util.h> +#include <base/files/scoped_temp_dir.h> +#include <base/logging.h> +#include <brillo/http/http_form_data.h> +#include <brillo/streams/memory_stream.h> +#include <fuzzer/FuzzedDataProvider.h> + +namespace { +constexpr int kRandomDataMaxLength = 64; +constexpr int kMaxRecursionDepth = 256; + +std::unique_ptr<brillo::http::TextFormField> CreateTextFormField( + FuzzedDataProvider* data_provider) { + return std::make_unique<brillo::http::TextFormField>( + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength)); +} + +std::unique_ptr<brillo::http::FileFormField> CreateFileFormField( + FuzzedDataProvider* data_provider) { + brillo::StreamPtr mem_stream = brillo::MemoryStream::OpenCopyOf( + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), nullptr); + return std::make_unique<brillo::http::FileFormField>( + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + std::move(mem_stream), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength)); +} + +std::unique_ptr<brillo::http::MultiPartFormField> CreateMultipartFormField( + FuzzedDataProvider* data_provider, int depth) { + std::unique_ptr<brillo::http::MultiPartFormField> multipart_field = + std::make_unique<brillo::http::MultiPartFormField>( + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength), + data_provider->ConsumeRandomLengthString(kRandomDataMaxLength)); + + // Randomly add fields to this like we do the base FormData, but don't loop + // forever. + while (data_provider->ConsumeBool()) { + if (data_provider->ConsumeBool()) { + // Add a random text field to the form. + multipart_field->AddCustomField(CreateTextFormField(data_provider)); + } + if (data_provider->ConsumeBool()) { + // Add a random file field to the form. + multipart_field->AddCustomField(CreateFileFormField(data_provider)); + } + // Limit our recursion depth. We could make this part of our code iterative, + // but that won't help because in libbrillo we use recursion to generate the + // stream so we would hit a stack depth limit there as well. + if (depth < kMaxRecursionDepth && data_provider->ConsumeBool()) { + // Add a random multipart form field to the form. + multipart_field->AddCustomField( + CreateMultipartFormField(data_provider, depth + 1)); + } + } + + return multipart_field; +} + +} // namespace + +bool IgnoreLogging(int, const char*, int, size_t, const std::string&) { + return true; +} + +class Environment { + public: + Environment() { + // Disable logging. Normally this would be done with logging::SetMinLogLevel + // but that doesn't work for brillo::Error for because it's not using the + // LOG(ERROR) macro which is where the actual log level check occurs. + logging::SetLogMessageHandler(&IgnoreLogging); + } +}; + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + static Environment env; + FuzzedDataProvider data_provider(data, size); + // Randomly add a bunch of fields to the FormData and then when done extract + // and consume the data stream. + brillo::http::FormData form_data( + data_provider.ConsumeRandomLengthString(kRandomDataMaxLength)); + while (data_provider.remaining_bytes() > 0) { + if (data_provider.ConsumeBool()) { + // Add a random text field to the form. + form_data.AddCustomField(CreateTextFormField(&data_provider)); + } + if (data_provider.ConsumeBool()) { + // Add a random file field to the form. + form_data.AddCustomField(CreateFileFormField(&data_provider)); + } + if (data_provider.ConsumeBool()) { + // Add a random multipart form field to the form. + form_data.AddCustomField(CreateMultipartFormField(&data_provider, 0)); + } + } + + brillo::StreamPtr form_stream = form_data.ExtractDataStream(); + if (!form_stream) + return 0; + + // We need to use a decent sized buffer and call ReadAllBlocking to avoid + // excess overhead with reading here that can make the fuzzer timeout. + uint8_t buffer[32768]; + while (form_stream->GetRemainingSize() > 0) { + if (!form_stream->ReadAllBlocking(buffer, sizeof(buffer), nullptr)) { + // If there's an error reading from the stream, then bail since we'd + // likely just see repeated errors and never exit. + break; + } + } + + return 0; +} diff --git a/brillo/http/http_form_data_unittest.cc b/brillo/http/http_form_data_test.cc index 34288d0..80bf30a 100644 --- a/brillo/http/http_form_data_unittest.cc +++ b/brillo/http/http_form_data_test.cc @@ -5,6 +5,7 @@ #include <brillo/http/http_form_data.h> #include <set> +#include <utility> #include <base/files/file_util.h> #include <base/files/scoped_temp_dir.h> @@ -94,7 +95,7 @@ TEST(HttpFormData, MultiPartFormField) { nullptr)); const char expected_header[] = "Content-Disposition: form-data; name=\"foo\"\r\n" - "Content-Type: multipart/form-data; boundary=\"Delimiter\"\r\n" + "Content-Type: multipart/form-data; boundary=Delimiter\r\n" "\r\n"; EXPECT_EQ(expected_header, form_field.GetContentHeader()); const char expected_data[] = @@ -116,7 +117,7 @@ TEST(HttpFormData, MultiPartFormField) { "Content-Transfer-Encoding: binary\r\n" "\r\n" "\x01\x02\x03\x04\x05\r\n" - "--Delimiter--"; + "--Delimiter--\r\n"; EXPECT_EQ(expected_data, GetFormFieldData(&form_field)); } @@ -158,7 +159,7 @@ TEST(HttpFormData, FormData) { FormData form_data{"boundary1"}; form_data.AddTextField("name", "John Doe"); std::unique_ptr<MultiPartFormField> files{ - new MultiPartFormField{"files", "", "boundary2"}}; + new MultiPartFormField{"files", "", "boundary 2"}}; EXPECT_TRUE(files->AddFileField( "", filename1, content_disposition::kFile, mime::text::kPlain, nullptr)); EXPECT_TRUE(files->AddFileField("", @@ -167,7 +168,7 @@ TEST(HttpFormData, FormData) { mime::application::kOctet_stream, nullptr)); form_data.AddCustomField(std::move(files)); - EXPECT_EQ("multipart/form-data; boundary=\"boundary1\"", + EXPECT_EQ("multipart/form-data; boundary=boundary1", form_data.GetContentType()); StreamPtr stream = form_data.ExtractDataStream(); @@ -180,22 +181,22 @@ TEST(HttpFormData, FormData) { "John Doe\r\n" "--boundary1\r\n" "Content-Disposition: form-data; name=\"files\"\r\n" - "Content-Type: multipart/mixed; boundary=\"boundary2\"\r\n" + "Content-Type: multipart/mixed; boundary=\"boundary 2\"\r\n" "\r\n" - "--boundary2\r\n" + "--boundary 2\r\n" "Content-Disposition: file; filename=\"sample.txt\"\r\n" "Content-Type: text/plain\r\n" "Content-Transfer-Encoding: binary\r\n" "\r\n" "text line1\ntext line2\n\r\n" - "--boundary2\r\n" + "--boundary 2\r\n" "Content-Disposition: file; filename=\"test.bin\"\r\n" "Content-Type: application/octet-stream\r\n" "Content-Transfer-Encoding: binary\r\n" "\r\n" "\x01\x02\x03\x04\x05\r\n" - "--boundary2--\r\n" - "--boundary1--"; + "--boundary 2--\r\n\r\n" + "--boundary1--\r\n"; EXPECT_EQ(expected_data, (std::string{data.begin(), data.end()})); } } // namespace http diff --git a/brillo/http/http_proxy.cc b/brillo/http/http_proxy.cc index bf6a8af..b697518 100644 --- a/brillo/http/http_proxy.cc +++ b/brillo/http/http_proxy.cc @@ -6,6 +6,7 @@ #include <memory> #include <string> +#include <utility> #include <vector> #include <base/bind.h> diff --git a/brillo/http/http_proxy.h b/brillo/http/http_proxy.h index c142af2..46863b6 100644 --- a/brillo/http/http_proxy.h +++ b/brillo/http/http_proxy.h @@ -32,13 +32,13 @@ using GetChromeProxyServersCallback = // Even if this function returns false, it will still set |proxies_out| to be // just the direct proxy. This function will only return false if there is an // error in the D-Bus communication itself. -BRILLO_EXPORT bool GetChromeProxyServers(scoped_refptr<dbus::Bus> bus, +BRILLO_EXPORT bool GetChromeProxyServers(scoped_refptr<::dbus::Bus> bus, const std::string& url, std::vector<std::string>* proxies_out); // Async version of GetChromeProxyServers. BRILLO_EXPORT void GetChromeProxyServersAsync( - scoped_refptr<dbus::Bus> bus, + scoped_refptr<::dbus::Bus> bus, const std::string& url, const GetChromeProxyServersCallback& callback); diff --git a/brillo/http/http_proxy_unittest.cc b/brillo/http/http_proxy_test.cc index 4893a87..a0d1bfa 100644 --- a/brillo/http/http_proxy_unittest.cc +++ b/brillo/http/http_proxy_test.cc @@ -4,7 +4,9 @@ #include <brillo/http/http_proxy.h> +#include <memory> #include <string> +#include <utility> #include <vector> #include <base/bind.h> @@ -30,25 +32,25 @@ class HttpProxyTest : public testing::Test { public: void ResolveProxyHandlerAsync(dbus::MethodCall* method_call, int timeout_msec, - dbus::ObjectProxy::ResponseCallback callback) { + dbus::ObjectProxy::ResponseCallback* callback) { if (null_dbus_response_) { - callback.Run(nullptr); + std::move(*callback).Run(nullptr); return; } - callback.Run(CreateDBusResponse(method_call).get()); + std::move(*callback).Run(CreateDBusResponse(method_call).get()); } - dbus::Response* ResolveProxyHandler(dbus::MethodCall* method_call, - int timeout_msec) { + std::unique_ptr<dbus::Response> ResolveProxyHandler( + dbus::MethodCall* method_call, int timeout_msec) { if (null_dbus_response_) { - return nullptr; + return std::unique_ptr<dbus::Response>(); } - // The mock wraps this back into a std::unique_ptr in the function calling - // us. - return CreateDBusResponse(method_call).release(); + return CreateDBusResponse(method_call); } - MOCK_METHOD2(GetProxiesCallback, void(bool, const std::vector<std::string>&)); + MOCK_METHOD(void, + GetProxiesCallback, + (bool, const std::vector<std::string>&)); protected: HttpProxyTest() { @@ -97,7 +99,7 @@ class HttpProxyTest : public testing::Test { TEST_F(HttpProxyTest, DBusNullResponseFails) { std::vector<std::string> proxies; null_dbus_response_ = true; - EXPECT_CALL(*object_proxy_, MockCallMethodAndBlock(_, _)) + EXPECT_CALL(*object_proxy_, CallMethodAndBlock(_, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandler)); EXPECT_FALSE(GetChromeProxyServers(bus_, kTestUrl, &proxies)); } @@ -105,14 +107,14 @@ TEST_F(HttpProxyTest, DBusNullResponseFails) { TEST_F(HttpProxyTest, DBusInvalidResponseFails) { std::vector<std::string> proxies; invalid_dbus_response_ = true; - EXPECT_CALL(*object_proxy_, MockCallMethodAndBlock(_, _)) + EXPECT_CALL(*object_proxy_, CallMethodAndBlock(_, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandler)); EXPECT_FALSE(GetChromeProxyServers(bus_, kTestUrl, &proxies)); } TEST_F(HttpProxyTest, NoProxies) { std::vector<std::string> proxies; - EXPECT_CALL(*object_proxy_, MockCallMethodAndBlock(_, _)) + EXPECT_CALL(*object_proxy_, CallMethodAndBlock(_, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandler)); EXPECT_TRUE(GetChromeProxyServers(bus_, kTestUrl, &proxies)); EXPECT_THAT(proxies, ElementsAre(kDirectProxy)); @@ -121,7 +123,7 @@ TEST_F(HttpProxyTest, NoProxies) { TEST_F(HttpProxyTest, MultipleProxiesWithoutDirect) { proxy_info_ = "proxy example.com; socks5 foo.com;"; std::vector<std::string> proxies; - EXPECT_CALL(*object_proxy_, MockCallMethodAndBlock(_, _)) + EXPECT_CALL(*object_proxy_, CallMethodAndBlock(_, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandler)); EXPECT_TRUE(GetChromeProxyServers(bus_, kTestUrl, &proxies)); EXPECT_THAT(proxies, ElementsAre("http://example.com", "socks5://foo.com", @@ -132,7 +134,7 @@ TEST_F(HttpProxyTest, MultipleProxiesWithDirect) { proxy_info_ = "socks foo.com; Https example.com ; badproxy example2.com ; " "socks5 test.com ; proxy foobar.com; DIRECT "; std::vector<std::string> proxies; - EXPECT_CALL(*object_proxy_, MockCallMethodAndBlock(_, _)) + EXPECT_CALL(*object_proxy_, CallMethodAndBlock(_, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandler)); EXPECT_TRUE(GetChromeProxyServers(bus_, kTestUrl, &proxies)); EXPECT_THAT(proxies, ElementsAre("socks4://foo.com", "https://example.com", @@ -142,7 +144,7 @@ TEST_F(HttpProxyTest, MultipleProxiesWithDirect) { TEST_F(HttpProxyTest, DBusNullResponseFailsAsync) { null_dbus_response_ = true; - EXPECT_CALL(*object_proxy_, CallMethod(_, _, _)) + EXPECT_CALL(*object_proxy_, DoCallMethod(_, _, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandlerAsync)); EXPECT_CALL(*this, GetProxiesCallback(false, _)); GetChromeProxyServersAsync( @@ -152,7 +154,7 @@ TEST_F(HttpProxyTest, DBusNullResponseFailsAsync) { TEST_F(HttpProxyTest, DBusInvalidResponseFailsAsync) { invalid_dbus_response_ = true; - EXPECT_CALL(*object_proxy_, CallMethod(_, _, _)) + EXPECT_CALL(*object_proxy_, DoCallMethod(_, _, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandlerAsync)); EXPECT_CALL(*this, GetProxiesCallback(false, _)); GetChromeProxyServersAsync( @@ -168,7 +170,7 @@ TEST_F(HttpProxyTest, MultipleProxiesWithDirectAsync) { std::vector<std::string> expected = { "socks4://foo.com", "https://example.com", "socks5://test.com", "http://foobar.com", kDirectProxy}; - EXPECT_CALL(*object_proxy_, CallMethod(_, _, _)) + EXPECT_CALL(*object_proxy_, DoCallMethod(_, _, _)) .WillOnce(Invoke(this, &HttpProxyTest::ResolveProxyHandlerAsync)); EXPECT_CALL(*this, GetProxiesCallback(true, expected)); GetChromeProxyServersAsync( diff --git a/brillo/http/http_request_unittest.cc b/brillo/http/http_request_test.cc index 39ccc18..e0be38b 100644 --- a/brillo/http/http_request_unittest.cc +++ b/brillo/http/http_request_test.cc @@ -6,8 +6,8 @@ #include <string> +#include <base/bind.h> #include <base/callback.h> -#include <brillo/bind_lambda.h> #include <brillo/http/mock_connection.h> #include <brillo/http/mock_transport.h> #include <brillo/mime_utils.h> diff --git a/brillo/http/http_transport.cc b/brillo/http/http_transport.cc index 0c27489..d713e50 100644 --- a/brillo/http/http_transport.cc +++ b/brillo/http/http_transport.cc @@ -26,5 +26,32 @@ std::shared_ptr<Transport> Transport::CreateDefaultWithProxy( } } +base::FilePath Transport::CertificateToPath(Transport::Certificate cert) { + const char* str; + switch (cert) { + case Certificate::kDefault: + str = +#ifdef __ANDROID__ + "/system/etc/security/cacerts_google"; +#else + "/usr/share/chromeos-ca-certificates"; +#endif + break; + case Certificate::kHermesProd: + str = "/usr/share/hermes-ca-certificates/prod"; + break; + case Certificate::kHermesTest: + str = "/usr/share/hermes-ca-certificates/test"; + break; + case Certificate::kNss: + str = "/etc/ssl/certs"; + break; + default: + CHECK(false) << "Invalid certificate"; + break; + } + return base::FilePath(str); +} + } // namespace http } // namespace brillo diff --git a/brillo/http/http_transport.h b/brillo/http/http_transport.h index e00166c..76ff901 100644 --- a/brillo/http/http_transport.h +++ b/brillo/http/http_transport.h @@ -11,6 +11,7 @@ #include <vector> #include <base/callback_forward.h> +#include <base/files/file_path.h> #include <base/location.h> #include <base/macros.h> #include <base/time/time.h> @@ -38,10 +39,26 @@ using ErrorCallback = base::Callback<void(RequestID, const brillo::Error*)>; /////////////////////////////////////////////////////////////////////////////// // Transport is a base class for specific implementation of HTTP communication. // This class (and its underlying implementation) is used by http::Request and -// http::Response classes to provide HTTP functionality to the clients. +// http::Response classes to provide HTTP functionality to the clients. By +// default, this interface will use CA certificates that only allow secure +// (HTTPS) communication with Google services. /////////////////////////////////////////////////////////////////////////////// class BRILLO_EXPORT Transport : public std::enable_shared_from_this<Transport> { public: + enum class Certificate { + // Default certificate; only allows communication with Google services. + kDefault, + // Certificates for communicating only with production SM-DP+ and SM-DS + // servers. + kHermesProd, + // Certificates for communicating only with test SM-DP+ and SM-DS servers. + kHermesTest, + // The NSS certificate store, which the curl command-line tool and libcurl + // library use by default. This set of certificates does not restrict + // secure communication to only Google services. + kNss, + }; + Transport() = default; virtual ~Transport() = default; @@ -87,6 +104,28 @@ class BRILLO_EXPORT Transport : public std::enable_shared_from_this<Transport> { // Set the local IP address of requests virtual void SetLocalIpAddress(const std::string& ip_address) = 0; + // Use the default CA certificate for certificate verification. This + // means that clients are only allowed to communicate with Google services. + virtual void UseDefaultCertificate() {} + + // Set the CA certificate to use for certificate verification. + // + // This call can allow a client to securly communicate with a different subset + // of services than it can otherwise. However, setting a custom certificate + // should be done only when necessary, and should be done with careful control + // over the certificates that are contained in the relevant path. See + // https://chromium.googlesource.com/chromiumos/docs/+/master/ca_certs.md for + // more information on certificates in Chrome OS. + virtual void UseCustomCertificate(Transport::Certificate cert) {} + + // Appends host entry to DNS cache. curl can only do HTTPS request to a custom + // IP if it resolves an HTTPS hostname to that IP. This is useful in + // forcing a particular mapping for an HTTPS host. See CURLOPT_RESOLVE for + // more details. + virtual void ResolveHostToIp(const std::string& host, + uint16_t port, + const std::string& ip_address) {} + // Creates a default http::Transport (currently, using http::curl::Transport). static std::shared_ptr<Transport> CreateDefault(); @@ -97,6 +136,12 @@ class BRILLO_EXPORT Transport : public std::enable_shared_from_this<Transport> { static std::shared_ptr<Transport> CreateDefaultWithProxy( const std::string& proxy); + protected: + // Clears the forced DNS mappings created by ResolveHostToIp. + virtual void ClearHost() {} + + static base::FilePath CertificateToPath(Certificate cert); + private: DISALLOW_COPY_AND_ASSIGN(Transport); }; diff --git a/brillo/http/http_transport_curl.cc b/brillo/http/http_transport_curl.cc index 9affc2a..de6899a 100644 --- a/brillo/http/http_transport_curl.cc +++ b/brillo/http/http_transport_curl.cc @@ -7,23 +7,15 @@ #include <limits> #include <base/bind.h> +#include <base/files/file_descriptor_watcher_posix.h> +#include <base/files/file_util.h> #include <base/logging.h> -#include <base/message_loop/message_loop.h> +#include <base/strings/stringprintf.h> +#include <base/threading/thread_task_runner_handle.h> #include <brillo/http/http_connection_curl.h> #include <brillo/http/http_request.h> #include <brillo/strings/string_utils.h> -namespace { - -const char kCACertificatePath[] = -#ifdef __ANDROID__ - "/system/etc/security/cacerts_google"; -#else - "/usr/share/brillo-ca-certificates"; -#endif - -} // namespace - namespace brillo { namespace http { namespace curl { @@ -31,7 +23,7 @@ namespace curl { // This is a class that stores connection data on particular CURL socket // and provides file descriptor watcher to monitor read and/or write operations // on the socket's file descriptor. -class Transport::SocketPollData : public base::MessagePumpForIO::FdWatcher { +class Transport::SocketPollData { public: SocketPollData(const std::shared_ptr<CurlInterface>& curl_interface, CURLM* curl_multi_handle, @@ -40,27 +32,35 @@ class Transport::SocketPollData : public base::MessagePumpForIO::FdWatcher { : curl_interface_(curl_interface), curl_multi_handle_(curl_multi_handle), transport_(transport), - socket_fd_(socket_fd), - file_descriptor_watcher_(FROM_HERE) {} + socket_fd_(socket_fd) {} - // Returns the pointer for the socket-specific file descriptor watcher. - base::MessagePumpForIO::FdWatchController* GetWatcher() { - return &file_descriptor_watcher_; + void StopWatcher() { + read_watcher_ = nullptr; + write_watcher_ = nullptr; } - private: - // Overrides from base::MessagePumpForIO::Watcher. - void OnFileCanReadWithoutBlocking(int fd) override { - OnSocketReady(fd, CURL_CSELECT_IN); + bool WatchReadable() { + read_watcher_ = base::FileDescriptorWatcher::WatchReadable( + socket_fd_, + base::BindRepeating(&Transport::SocketPollData::OnSocketReady, + base::Unretained(this), + CURL_CSELECT_IN)); + return read_watcher_.get(); } - void OnFileCanWriteWithoutBlocking(int fd) override { - OnSocketReady(fd, CURL_CSELECT_OUT); + + bool WatchWritable() { + write_watcher_ = base::FileDescriptorWatcher::WatchWritable( + socket_fd_, + base::BindRepeating(&Transport::SocketPollData::OnSocketReady, + base::Unretained(this), + CURL_CSELECT_OUT)); + return write_watcher_.get(); } + private: // Data on the socket is available to be read from or written to. // Notify CURL of the action it needs to take on the socket file descriptor. - void OnSocketReady(int fd, int action) { - CHECK_EQ(socket_fd_, fd) << "Unexpected socket file descriptor"; + void OnSocketReady(int action) { int still_running_count = 0; CURLMcode code = curl_interface_->MultiSocketAction( curl_multi_handle_, socket_fd_, action, &still_running_count); @@ -79,8 +79,9 @@ class Transport::SocketPollData : public base::MessagePumpForIO::FdWatcher { Transport* transport_; // The socket file descriptor for the connection. curl_socket_t socket_fd_; - // File descriptor watcher to notify us of asynchronous I/O on the FD. - base::MessagePumpForIO::FdWatchController file_descriptor_watcher_; + + std::unique_ptr<base::FileDescriptorWatcher::Controller> read_watcher_; + std::unique_ptr<base::FileDescriptorWatcher::Controller> write_watcher_; DISALLOW_COPY_AND_ASSIGN(SocketPollData); }; @@ -101,15 +102,18 @@ struct Transport::AsyncRequestData { Transport::Transport(const std::shared_ptr<CurlInterface>& curl_interface) : curl_interface_{curl_interface} { VLOG(2) << "curl::Transport created"; + UseDefaultCertificate(); } Transport::Transport(const std::shared_ptr<CurlInterface>& curl_interface, const std::string& proxy) : curl_interface_{curl_interface}, proxy_{proxy} { VLOG(2) << "curl::Transport created with proxy " << proxy; + UseDefaultCertificate(); } Transport::~Transport() { + ClearHost(); ShutDownAsyncCurl(); VLOG(2) << "curl::Transport destroyed"; } @@ -134,8 +138,14 @@ std::shared_ptr<http::Connection> Transport::CreateConnection( CURLcode code = curl_interface_->EasySetOptStr(curl_handle, CURLOPT_URL, url); if (code == CURLE_OK) { + // CURLOPT_CAINFO is a string, but CurlApi::EasySetOptStr will never pass + // curl_easy_setopt a null pointer, so we use EasySetOptPtr instead. + code = curl_interface_->EasySetOptPtr(curl_handle, CURLOPT_CAINFO, nullptr); + } + if (code == CURLE_OK) { + CHECK(base::PathExists(certificate_path_)); code = curl_interface_->EasySetOptStr(curl_handle, CURLOPT_CAPATH, - kCACertificatePath); + certificate_path_.value()); } if (code == CURLE_OK) { code = @@ -169,6 +179,10 @@ std::shared_ptr<http::Connection> Transport::CreateConnection( code = curl_interface_->EasySetOptStr( curl_handle, CURLOPT_INTERFACE, ip_address_.c_str()); } + if (code == CURLE_OK && host_list_) { + code = curl_interface_->EasySetOptPtr(curl_handle, CURLOPT_RESOLVE, + host_list_); + } // Setup HTTP request method and optional request body. if (code == CURLE_OK) { @@ -208,8 +222,7 @@ std::shared_ptr<http::Connection> Transport::CreateConnection( void Transport::RunCallbackAsync(const base::Location& from_here, const base::Closure& callback) { - base::MessageLoopForIO::current()->task_runner()->PostTask( - from_here, callback); + base::ThreadTaskRunnerHandle::Get()->PostTask(from_here, callback); } RequestID Transport::StartAsyncTransfer(http::Connection* connection, @@ -274,6 +287,29 @@ void Transport::SetLocalIpAddress(const std::string& ip_address) { ip_address_ = "host!" + ip_address; } +void Transport::UseDefaultCertificate() { + UseCustomCertificate(Certificate::kDefault); +} + +void Transport::UseCustomCertificate(Transport::Certificate cert) { + certificate_path_ = CertificateToPath(cert); + CHECK(base::PathExists(certificate_path_)); +} + +void Transport::ResolveHostToIp(const std::string& host, + uint16_t port, + const std::string& ip_address) { + host_list_ = curl_slist_append( + host_list_, + base::StringPrintf("%s:%d:%s", host.c_str(), port, ip_address.c_str()) + .c_str()); +} + +void Transport::ClearHost() { + curl_slist_free_all(host_list_); + host_list_ = nullptr; +} + void Transport::AddEasyCurlError(brillo::ErrorPtr* error, const base::Location& location, CURLcode code, @@ -359,42 +395,22 @@ int Transport::MultiSocketCallback(CURL* easy, // Make sure we stop watching the socket file descriptor now, before // we schedule the SocketPollData for deletion. - poll_data->GetWatcher()->StopWatchingFileDescriptor(); + poll_data->StopWatcher(); // This method can be called indirectly from SocketPollData::OnSocketReady, // so delay destruction of SocketPollData object till the next loop cycle. - base::MessageLoopForIO::current()->task_runner()->DeleteSoon(FROM_HERE, - poll_data); + base::ThreadTaskRunnerHandle::Get()->DeleteSoon(FROM_HERE, poll_data); return 0; } - base::MessagePumpForIO::Mode watch_mode = base::MessagePumpForIO::WATCH_READ; - switch (what) { - case CURL_POLL_IN: - watch_mode = base::MessagePumpForIO::WATCH_READ; - break; - case CURL_POLL_OUT: - watch_mode = base::MessagePumpForIO::WATCH_WRITE; - break; - case CURL_POLL_INOUT: - watch_mode = base::MessagePumpForIO::WATCH_READ_WRITE; - break; - default: - LOG(FATAL) << "Unknown CURL socket action: " << what; - break; - } + poll_data->StopWatcher(); + + bool success = true; + if (what == CURL_POLL_IN || what == CURL_POLL_INOUT) + success = poll_data->WatchReadable() && success; + if (what == CURL_POLL_OUT || what == CURL_POLL_INOUT) + success = poll_data->WatchWritable() && success; - // WatchFileDescriptor() can be called with the same controller object - // (watcher) to amend the watch mode, however this has cumulative effect. - // For example, if we were watching a file descriptor for READ operations - // and now call it to watch for WRITE, it will end up watching for both - // READ and WRITE. This is not what we want here, so stop watching the - // file descriptor on previous controller before starting with a different - // mode. - if (!poll_data->GetWatcher()->StopWatchingFileDescriptor()) - LOG(WARNING) << "Failed to stop watching the previous socket descriptor"; - CHECK(base::MessageLoopForIO::current()->WatchFileDescriptor( - s, true, watch_mode, poll_data->GetWatcher(), poll_data)) - << "Failed to watch the CURL socket."; + CHECK(success) << "Failed to watch the CURL socket."; return 0; } @@ -406,11 +422,11 @@ int Transport::MultiTimerCallback(CURLM* /* multi */, // Cancel any previous timer callbacks. transport->weak_ptr_factory_for_timer_.InvalidateWeakPtrs(); if (timeout_ms >= 0) { - base::MessageLoopForIO::current()->task_runner()->PostDelayedTask( - FROM_HERE, - base::Bind(&Transport::OnTimer, - transport->weak_ptr_factory_for_timer_.GetWeakPtr()), - base::TimeDelta::FromMilliseconds(timeout_ms)); + base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( + FROM_HERE, + base::Bind(&Transport::OnTimer, + transport->weak_ptr_factory_for_timer_.GetWeakPtr()), + base::TimeDelta::FromMilliseconds(timeout_ms)); } return 0; } diff --git a/brillo/http/http_transport_curl.h b/brillo/http/http_transport_curl.h index 175a675..5af2c61 100644 --- a/brillo/http/http_transport_curl.h +++ b/brillo/http/http_transport_curl.h @@ -6,9 +6,11 @@ #define LIBBRILLO_BRILLO_HTTP_HTTP_TRANSPORT_CURL_H_ #include <map> +#include <memory> #include <string> #include <utility> +#include <base/location.h> #include <base/memory/weak_ptr.h> #include <brillo/brillo_export.h> #include <brillo/http/curl_api.h> @@ -61,6 +63,14 @@ class BRILLO_EXPORT Transport : public http::Transport { void SetLocalIpAddress(const std::string& ip_address) override; + void UseDefaultCertificate() override; + + void UseCustomCertificate(Certificate cert) override; + + void ResolveHostToIp(const std::string& host, + uint16_t port, + const std::string& ip_address) override; + // Helper methods to convert CURL error codes (CURLcode and CURLMcode) // into brillo::Error object. static void AddEasyCurlError(brillo::ErrorPtr* error, @@ -73,6 +83,9 @@ class BRILLO_EXPORT Transport : public http::Transport { CURLMcode code, CurlInterface* curl_interface); + protected: + void ClearHost() override; + private: // Forward-declaration of internal implementation structures. struct AsyncRequestData; @@ -130,6 +143,8 @@ class BRILLO_EXPORT Transport : public http::Transport { // The connection timeout for the requests made. base::TimeDelta connection_timeout_; std::string ip_address_; + base::FilePath certificate_path_; + curl_slist* host_list_{nullptr}; base::WeakPtrFactory<Transport> weak_ptr_factory_for_timer_{this}; base::WeakPtrFactory<Transport> weak_ptr_factory_{this}; diff --git a/brillo/http/http_transport_curl_unittest.cc b/brillo/http/http_transport_curl_test.cc index c05c81a..6e94978 100644 --- a/brillo/http/http_transport_curl_unittest.cc +++ b/brillo/http/http_transport_curl_test.cc @@ -5,9 +5,10 @@ #include <brillo/http/http_transport_curl.h> #include <base/at_exit.h> +#include <base/bind.h> #include <base/message_loop/message_loop.h> #include <base/run_loop.h> -#include <brillo/bind_lambda.h> +#include <base/threading/thread_task_runner_handle.h> #include <brillo/http/http_connection_curl.h> #include <brillo/http/http_request.h> #include <brillo/http/mock_curl_api.h> @@ -33,6 +34,8 @@ class HttpCurlTransportTest : public testing::Test { transport_ = std::make_shared<Transport>(curl_api_); handle_ = reinterpret_cast<CURL*>(100); // Mock handle value. EXPECT_CALL(*curl_api_, EasyInit()).WillOnce(Return(handle_)); + EXPECT_CALL(*curl_api_, EasySetOptPtr(handle_, CURLOPT_CAINFO, _)) + .WillOnce(Return(CURLE_OK)); EXPECT_CALL(*curl_api_, EasySetOptStr(handle_, CURLOPT_CAPATH, _)) .WillOnce(Return(CURLE_OK)); EXPECT_CALL(*curl_api_, EasySetOptInt(handle_, CURLOPT_SSL_VERIFYPEER, 1)) @@ -197,6 +200,8 @@ class HttpCurlTransportAsyncTest : public testing::Test { curl_api_ = std::make_shared<MockCurlInterface>(); transport_ = std::make_shared<Transport>(curl_api_); EXPECT_CALL(*curl_api_, EasyInit()).WillOnce(Return(handle_)); + EXPECT_CALL(*curl_api_, EasySetOptPtr(handle_, CURLOPT_CAINFO, _)) + .WillOnce(Return(CURLE_OK)); EXPECT_CALL(*curl_api_, EasySetOptStr(handle_, CURLOPT_CAPATH, _)) .WillOnce(Return(CURLE_OK)); EXPECT_CALL(*curl_api_, EasySetOptInt(handle_, CURLOPT_SSL_VERIFYPEER, 1)) @@ -238,7 +243,7 @@ TEST_F(HttpCurlTransportAsyncTest, StartAsyncTransfer) { auto success_callback = base::Bind([]( int* success_call_count, const base::Closure& quit_closure, RequestID /* request_id */, std::unique_ptr<http::Response> /* resp */) { - base::MessageLoop::current()->task_runner()->PostTask( + base::ThreadTaskRunnerHandle::Get()->PostTask( FROM_HERE, quit_closure); (*success_call_count)++; }, &success_call_count, run_loop.QuitClosure()); @@ -333,6 +338,23 @@ TEST_F(HttpCurlTransportTest, RequestGetTimeout) { connection.reset(); } +TEST_F(HttpCurlTransportTest, RequestGetResolveHost) { + transport_->ResolveHostToIp("foo.bar", 80, "127.0.0.1"); + EXPECT_CALL(*curl_api_, + EasySetOptStr(handle_, CURLOPT_URL, "http://foo.bar/get")) + .WillOnce(Return(CURLE_OK)); + EXPECT_CALL(*curl_api_, EasySetOptPtr(handle_, CURLOPT_RESOLVE, _)) + .WillOnce(Return(CURLE_OK)); + EXPECT_CALL(*curl_api_, EasySetOptInt(handle_, CURLOPT_HTTPGET, 1)) + .WillOnce(Return(CURLE_OK)); + auto connection = transport_->CreateConnection( + "http://foo.bar/get", request_type::kGet, {}, "", "", nullptr); + EXPECT_NE(nullptr, connection.get()); + + EXPECT_CALL(*curl_api_, EasyCleanup(handle_)).Times(1); + connection.reset(); +} + } // namespace curl } // namespace http } // namespace brillo diff --git a/brillo/http/http_transport_fake.cc b/brillo/http/http_transport_fake.cc index 224b5de..c4757f9 100644 --- a/brillo/http/http_transport_fake.cc +++ b/brillo/http/http_transport_fake.cc @@ -6,10 +6,10 @@ #include <utility> +#include <base/bind.h> #include <base/json/json_reader.h> #include <base/json/json_writer.h> #include <base/logging.h> -#include <brillo/bind_lambda.h> #include <brillo/http/http_connection_fake.h> #include <brillo/http/http_request.h> #include <brillo/mime_utils.h> diff --git a/brillo/http/http_transport_fake.h b/brillo/http/http_transport_fake.h index 0a2fe90..56351ec 100644 --- a/brillo/http/http_transport_fake.h +++ b/brillo/http/http_transport_fake.h @@ -6,12 +6,15 @@ #define LIBBRILLO_BRILLO_HTTP_HTTP_TRANSPORT_FAKE_H_ #include <map> +#include <memory> #include <queue> #include <string> #include <type_traits> +#include <utility> #include <vector> #include <base/callback.h> +#include <base/location.h> #include <base/values.h> #include <brillo/http/http_transport.h> #include <brillo/http/http_utils.h> @@ -104,6 +107,13 @@ class Transport : public http::Transport { void SetLocalIpAddress(const std::string& /* ip_address */) override {} + void ResolveHostToIp(const std::string& host, + uint16_t port, + const std::string& ip_address) override {} + + protected: + void ClearHost() override {} + private: // A list of user-supplied request handlers. std::map<std::string, HandlerCallback> handlers_; diff --git a/brillo/http/http_utils.h b/brillo/http/http_utils.h index e09bab8..0d4d109 100644 --- a/brillo/http/http_utils.h +++ b/brillo/http/http_utils.h @@ -5,6 +5,7 @@ #ifndef LIBBRILLO_BRILLO_HTTP_HTTP_UTILS_H_ #define LIBBRILLO_BRILLO_HTTP_HTTP_UTILS_H_ +#include <memory> #include <string> #include <utility> #include <vector> diff --git a/brillo/http/http_utils_unittest.cc b/brillo/http/http_utils_test.cc index 376ba53..409282c 100644 --- a/brillo/http/http_utils_unittest.cc +++ b/brillo/http/http_utils_test.cc @@ -6,8 +6,8 @@ #include <string> #include <vector> +#include <base/bind.h> #include <base/values.h> -#include <brillo/bind_lambda.h> #include <brillo/http/http_transport_fake.h> #include <brillo/http/http_utils.h> #include <brillo/mime_utils.h> @@ -366,7 +366,7 @@ TEST(HttpUtils, PostMultipartFormData) { "Content-Disposition: form-data; name=\"key2\"\r\n" "\r\n" "value2\r\n" - "--boundary123--"; + "--boundary123--\r\n"; EXPECT_EQ(expected_value, response->ExtractDataAsString()); } diff --git a/brillo/http/mock_connection.h b/brillo/http/mock_connection.h index 0796a7e..1810824 100644 --- a/brillo/http/mock_connection.h +++ b/brillo/http/mock_connection.h @@ -19,17 +19,22 @@ class MockConnection : public Connection { public: using Connection::Connection; - MOCK_METHOD2(SendHeaders, bool(const HeaderList&, ErrorPtr*)); - MOCK_METHOD2(MockSetRequestData, bool(Stream*, ErrorPtr*)); - MOCK_METHOD1(MockSetResponseData, void(Stream*)); - MOCK_METHOD1(FinishRequest, bool(ErrorPtr*)); - MOCK_METHOD2(FinishRequestAsync, - RequestID(const SuccessCallback&, const ErrorCallback&)); - MOCK_CONST_METHOD0(GetResponseStatusCode, int()); - MOCK_CONST_METHOD0(GetResponseStatusText, std::string()); - MOCK_CONST_METHOD0(GetProtocolVersion, std::string()); - MOCK_CONST_METHOD1(GetResponseHeader, std::string(const std::string&)); - MOCK_CONST_METHOD1(MockExtractDataStream, Stream*(brillo::ErrorPtr*)); + MOCK_METHOD(bool, SendHeaders, (const HeaderList&, ErrorPtr*), (override)); + MOCK_METHOD(bool, MockSetRequestData, (Stream*, ErrorPtr*)); + MOCK_METHOD(void, MockSetResponseData, (Stream*)); + MOCK_METHOD(bool, FinishRequest, (ErrorPtr*), (override)); + MOCK_METHOD(RequestID, + FinishRequestAsync, + (const SuccessCallback&, const ErrorCallback&), + (override)); + MOCK_METHOD(int, GetResponseStatusCode, (), (const, override)); + MOCK_METHOD(std::string, GetResponseStatusText, (), (const, override)); + MOCK_METHOD(std::string, GetProtocolVersion, (), (const, override)); + MOCK_METHOD(std::string, + GetResponseHeader, + (const std::string&), + (const, override)); + MOCK_METHOD(Stream*, MockExtractDataStream, (brillo::ErrorPtr*), (const)); private: bool SetRequestData(StreamPtr stream, brillo::ErrorPtr* error) override { diff --git a/brillo/http/mock_curl_api.h b/brillo/http/mock_curl_api.h index 32b6e0d..daac8c2 100644 --- a/brillo/http/mock_curl_api.h +++ b/brillo/http/mock_curl_api.h @@ -20,34 +20,67 @@ class MockCurlInterface : public CurlInterface { public: MockCurlInterface() = default; - MOCK_METHOD0(EasyInit, CURL*()); - MOCK_METHOD1(EasyCleanup, void(CURL*)); - MOCK_METHOD3(EasySetOptInt, CURLcode(CURL*, CURLoption, int)); - MOCK_METHOD3(EasySetOptStr, CURLcode(CURL*, CURLoption, const std::string&)); - MOCK_METHOD3(EasySetOptPtr, CURLcode(CURL*, CURLoption, void*)); - MOCK_METHOD3(EasySetOptCallback, CURLcode(CURL*, CURLoption, intptr_t)); - MOCK_METHOD3(EasySetOptOffT, CURLcode(CURL*, CURLoption, curl_off_t)); - MOCK_METHOD1(EasyPerform, CURLcode(CURL*)); - MOCK_CONST_METHOD3(EasyGetInfoInt, CURLcode(CURL*, CURLINFO, int*)); - MOCK_CONST_METHOD3(EasyGetInfoDbl, CURLcode(CURL*, CURLINFO, double*)); - MOCK_CONST_METHOD3(EasyGetInfoStr, CURLcode(CURL*, CURLINFO, std::string*)); - MOCK_CONST_METHOD3(EasyGetInfoPtr, CURLcode(CURL*, CURLINFO, void**)); - MOCK_CONST_METHOD1(EasyStrError, std::string(CURLcode)); - MOCK_METHOD0(MultiInit, CURLM*()); - MOCK_METHOD1(MultiCleanup, CURLMcode(CURLM*)); - MOCK_METHOD2(MultiInfoRead, CURLMsg*(CURLM*, int*)); - MOCK_METHOD2(MultiAddHandle, CURLMcode(CURLM*, CURL*)); - MOCK_METHOD2(MultiRemoveHandle, CURLMcode(CURLM*, CURL*)); - MOCK_METHOD3(MultiSetSocketCallback, - CURLMcode(CURLM*, curl_socket_callback, void*)); - MOCK_METHOD3(MultiSetTimerCallback, - CURLMcode(CURLM*, curl_multi_timer_callback, void*)); - MOCK_METHOD3(MultiAssign, CURLMcode(CURLM*, curl_socket_t, void*)); - MOCK_METHOD4(MultiSocketAction, CURLMcode(CURLM*, curl_socket_t, int, int*)); - MOCK_CONST_METHOD1(MultiStrError, std::string(CURLMcode)); - MOCK_METHOD2(MultiPerform, CURLMcode(CURLM*, int*)); - MOCK_METHOD5(MultiWait, - CURLMcode(CURLM*, curl_waitfd[], unsigned int, int, int*)); + MOCK_METHOD(CURL*, EasyInit, (), (override)); + MOCK_METHOD(void, EasyCleanup, (CURL*), (override)); + MOCK_METHOD(CURLcode, EasySetOptInt, (CURL*, CURLoption, int), (override)); + MOCK_METHOD(CURLcode, + EasySetOptStr, + (CURL*, CURLoption, const std::string&), + (override)); + MOCK_METHOD(CURLcode, EasySetOptPtr, (CURL*, CURLoption, void*), (override)); + MOCK_METHOD(CURLcode, + EasySetOptCallback, + (CURL*, CURLoption, intptr_t), + (override)); + MOCK_METHOD(CURLcode, + EasySetOptOffT, + (CURL*, CURLoption, curl_off_t), + (override)); + MOCK_METHOD(CURLcode, EasyPerform, (CURL*), (override)); + MOCK_METHOD(CURLcode, + EasyGetInfoInt, + (CURL*, CURLINFO, int*), + (const, override)); + MOCK_METHOD(CURLcode, + EasyGetInfoDbl, + (CURL*, CURLINFO, double*), + (const, override)); + MOCK_METHOD(CURLcode, + EasyGetInfoStr, + (CURL*, CURLINFO, std::string*), + (const, override)); + MOCK_METHOD(CURLcode, + EasyGetInfoPtr, + (CURL*, CURLINFO, void**), + (const, override)); + MOCK_METHOD(std::string, EasyStrError, (CURLcode), (const, override)); + MOCK_METHOD(CURLM*, MultiInit, (), (override)); + MOCK_METHOD(CURLMcode, MultiCleanup, (CURLM*), (override)); + MOCK_METHOD(CURLMsg*, MultiInfoRead, (CURLM*, int*), (override)); + MOCK_METHOD(CURLMcode, MultiAddHandle, (CURLM*, CURL*), (override)); + MOCK_METHOD(CURLMcode, MultiRemoveHandle, (CURLM*, CURL*), (override)); + MOCK_METHOD(CURLMcode, + MultiSetSocketCallback, + (CURLM*, curl_socket_callback, void*), + (override)); + MOCK_METHOD(CURLMcode, + MultiSetTimerCallback, + (CURLM*, curl_multi_timer_callback, void*), + (override)); + MOCK_METHOD(CURLMcode, + MultiAssign, + (CURLM*, curl_socket_t, void*), + (override)); + MOCK_METHOD(CURLMcode, + MultiSocketAction, + (CURLM*, curl_socket_t, int, int*), + (override)); + MOCK_METHOD(std::string, MultiStrError, (CURLMcode), (const, override)); + MOCK_METHOD(CURLMcode, MultiPerform, (CURLM*, int*), (override)); + MOCK_METHOD(CURLMcode, + MultiWait, + (CURLM*, curl_waitfd[], unsigned int, int, int*), + (override)); private: DISALLOW_COPY_AND_ASSIGN(MockCurlInterface); diff --git a/brillo/http/mock_transport.h b/brillo/http/mock_transport.h index 7504266..a9f5d46 100644 --- a/brillo/http/mock_transport.h +++ b/brillo/http/mock_transport.h @@ -8,6 +8,7 @@ #include <memory> #include <string> +#include <base/location.h> #include <base/macros.h> #include <brillo/http/http_transport.h> #include <gmock/gmock.h> @@ -19,21 +20,35 @@ class MockTransport : public Transport { public: MockTransport() = default; - MOCK_METHOD6(CreateConnection, - std::shared_ptr<Connection>(const std::string&, - const std::string&, - const HeaderList&, - const std::string&, - const std::string&, - brillo::ErrorPtr*)); - MOCK_METHOD2(RunCallbackAsync, - void(const base::Location&, const base::Closure&)); - MOCK_METHOD3(StartAsyncTransfer, RequestID(Connection*, - const SuccessCallback&, - const ErrorCallback&)); - MOCK_METHOD1(CancelRequest, bool(RequestID)); - MOCK_METHOD1(SetDefaultTimeout, void(base::TimeDelta)); - MOCK_METHOD1(SetLocalIpAddress, void(const std::string&)); + MOCK_METHOD(std::shared_ptr<Connection>, + CreateConnection, + (const std::string&, + const std::string&, + const HeaderList&, + const std::string&, + const std::string&, + brillo::ErrorPtr*), + (override)); + MOCK_METHOD(void, + RunCallbackAsync, + (const base::Location&, const base::Closure&), + (override)); + MOCK_METHOD(RequestID, + StartAsyncTransfer, + (Connection*, const SuccessCallback&, const ErrorCallback&), + (override)); + MOCK_METHOD(bool, CancelRequest, (RequestID), (override)); + MOCK_METHOD(void, SetDefaultTimeout, (base::TimeDelta), (override)); + MOCK_METHOD(void, SetLocalIpAddress, (const std::string&), (override)); + MOCK_METHOD(void, UseDefaultCertificate, (), (override)); + MOCK_METHOD(void, UseCustomCertificate, (Certificate), (override)); + MOCK_METHOD(void, + ResolveHostToIp, + (const std::string&, uint16_t, const std::string&), + (override)); + + protected: + MOCK_METHOD(void, ClearHost, (), (override)); private: DISALLOW_COPY_AND_ASSIGN(MockTransport); diff --git a/brillo/imageloader/manifest.cc b/brillo/imageloader/manifest.cc deleted file mode 100644 index 92789df..0000000 --- a/brillo/imageloader/manifest.cc +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2018 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <brillo/imageloader/manifest.h> - -#include <memory> -#include <utility> - -#include <base/json/json_string_value_serializer.h> -#include <base/strings/string_number_conversions.h> - -namespace brillo { -namespace imageloader { - -namespace { -// The current version of the manifest file. -constexpr int kCurrentManifestVersion = 1; -// The name of the version field in the manifest. -constexpr char kManifestVersionField[] = "manifest-version"; -// The name of the component version field in the manifest. -constexpr char kVersionField[] = "version"; -// The name of the field containing the image hash. -constexpr char kImageHashField[] = "image-sha256-hash"; -// The name of the bool field indicating whether component is removable. -constexpr char kIsRemovableField[] = "is-removable"; -// The name of the metadata field. -constexpr char kMetadataField[] = "metadata"; -// The name of the field containing the table hash. -constexpr char kTableHashField[] = "table-sha256-hash"; -// The name of the optional field containing the file system type. -constexpr char kFSType[] = "fs-type"; - -bool GetSHA256FromString(const std::string& hash_str, - std::vector<uint8_t>* bytes) { - if (!base::HexStringToBytes(hash_str, bytes)) - return false; - return bytes->size() == 32; -} - -// Ensure the metadata entry is a dictionary mapping strings to strings and -// parse it into |out_metadata| and return true if so. -bool ParseMetadata(const base::Value* metadata_element, - std::map<std::string, std::string>* out_metadata) { - DCHECK(out_metadata); - - const base::DictionaryValue* metadata_dict = nullptr; - if (!metadata_element->GetAsDictionary(&metadata_dict)) - return false; - - base::DictionaryValue::Iterator it(*metadata_dict); - for (; !it.IsAtEnd(); it.Advance()) { - std::string parsed_value; - if (!it.value().GetAsString(&parsed_value)) { - LOG(ERROR) << "Key \"" << it.key() << "\" did not map to string value"; - return false; - } - - (*out_metadata)[it.key()] = std::move(parsed_value); - } - - return true; -} - -} // namespace - -Manifest::Manifest() {} - -bool Manifest::ParseManifest(const std::string& manifest_raw) { - // Now deserialize the manifest json and read out the rest of the component. - int error_code; - std::string error_message; - JSONStringValueDeserializer deserializer(manifest_raw); - std::unique_ptr<base::Value> value = - deserializer.Deserialize(&error_code, &error_message); - - if (!value) { - LOG(ERROR) << "Could not deserialize the manifest file. Error " - << error_code << ": " << error_message; - return false; - } - - base::DictionaryValue* manifest_dict = nullptr; - if (!value->GetAsDictionary(&manifest_dict)) { - LOG(ERROR) << "Could not parse manifest file as JSON."; - return false; - } - - // This will have to be changed if the manifest version is bumped. - int version; - if (!manifest_dict->GetInteger(kManifestVersionField, &version)) { - LOG(ERROR) << "Could not parse manifest version field from manifest."; - return false; - } - if (version != kCurrentManifestVersion) { - LOG(ERROR) << "Unsupported version of the manifest."; - return false; - } - manifest_version_ = version; - - std::string image_hash_str; - if (!manifest_dict->GetString(kImageHashField, &image_hash_str)) { - LOG(ERROR) << "Could not parse image hash from manifest."; - return false; - } - - if (!GetSHA256FromString(image_hash_str, &(image_sha256_))) { - LOG(ERROR) << "Could not convert image hash to bytes."; - return false; - } - - std::string table_hash_str; - if (!manifest_dict->GetString(kTableHashField, &table_hash_str)) { - LOG(ERROR) << "Could not parse table hash from manifest."; - return false; - } - - if (!GetSHA256FromString(table_hash_str, &(table_sha256_))) { - LOG(ERROR) << "Could not convert table hash to bytes."; - return false; - } - - if (!manifest_dict->GetString(kVersionField, &(version_))) { - LOG(ERROR) << "Could not parse component version from manifest."; - return false; - } - - // The fs_type field is optional, and squashfs by default. - fs_type_ = FileSystem::kSquashFS; - std::string fs_type; - if (manifest_dict->GetString(kFSType, &fs_type)) { - if (fs_type == "ext4") { - fs_type_ = FileSystem::kExt4; - } else if (fs_type == "squashfs") { - fs_type_ = FileSystem::kSquashFS; - } else { - LOG(ERROR) << "Unsupported file system type: " << fs_type; - return false; - } - } - - if (!manifest_dict->GetBoolean(kIsRemovableField, &(is_removable_))) { - // If is_removable field does not exist, by default it is false. - is_removable_ = false; - } - - // Copy out the metadata, if it's there. - const base::Value* metadata = nullptr; - if (manifest_dict->Get(kMetadataField, &metadata)) { - if (!ParseMetadata(metadata, &(metadata_))) { - LOG(ERROR) << "Manifest metadata was malformed"; - return false; - } - } - - return true; -} - -int Manifest::manifest_version() const { - return manifest_version_; -} - -const std::vector<uint8_t>& Manifest::image_sha256() const { - return image_sha256_; -} - -const std::vector<uint8_t>& Manifest::table_sha256() const { - return table_sha256_; -} - -const std::string& Manifest::version() const { - return version_; -} - -FileSystem Manifest::fs_type() const { - return fs_type_; -} - -bool Manifest::is_removable() const { - return is_removable_; -} - -const std::map<std::string, std::string> Manifest::metadata() const { - return metadata_; -} - -} // namespace imageloader -} // namespace brillo diff --git a/brillo/imageloader/manifest.h b/brillo/imageloader/manifest.h deleted file mode 100644 index cfd7c3a..0000000 --- a/brillo/imageloader/manifest.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2018 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef LIBBRILLO_BRILLO_IMAGELOADER_MANIFEST_H_ -#define LIBBRILLO_BRILLO_IMAGELOADER_MANIFEST_H_ - -#include <map> -#include <string> -#include <vector> - -#include <base/macros.h> -#include <brillo/brillo_export.h> - -namespace brillo { -namespace imageloader { - -// The supported file systems for images. -enum class FileSystem { kExt4, kSquashFS }; - -// A class to parse and store imageloader.json manifest. -class BRILLO_EXPORT Manifest { - public: - Manifest(); - // Parse the manifest raw string. Return true if successful. - bool ParseManifest(const std::string& manifest_raw); - // Getters for manifest fields: - int manifest_version() const; - const std::vector<uint8_t>& image_sha256() const; - const std::vector<uint8_t>& table_sha256() const; - const std::string& version() const; - FileSystem fs_type() const; - bool is_removable() const; - const std::map<std::string, std::string> metadata() const; - - private: - // Manifest fields: - int manifest_version_; - std::vector<uint8_t> image_sha256_; - std::vector<uint8_t> table_sha256_; - std::string version_; - FileSystem fs_type_; - bool is_removable_; - std::map<std::string, std::string> metadata_; - - DISALLOW_COPY_AND_ASSIGN(Manifest); -}; - -} // namespace imageloader -} // namespace brillo - -#endif // LIBBRILLO_BRILLO_IMAGELOADER_MANIFEST_H_ diff --git a/brillo/imageloader/manifest_unittest.cc b/brillo/imageloader/manifest_unittest.cc deleted file mode 100644 index bca7e8b..0000000 --- a/brillo/imageloader/manifest_unittest.cc +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2018 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <gtest/gtest.h> - -#include <brillo/imageloader/manifest.h> - -namespace brillo { -namespace imageloader { - -class ManifestTest : public testing::Test {}; - -TEST_F(ManifestTest, ParseManifest) { - const std::string fs_type = R"("ext4")"; - const std::string is_removable = R"(true)"; - const std::string image_sha256_hash = - R"("4CF41BD11362CCB4707FB93939DBB5AC48745EDFC9DC8D7702852FFAA81B3B3F")"; - const std::string table_sha256_hash = - R"("0E11DA3D7140C6B95496787F50D15152434EBA22B60443BFA7E054FF4C799276")"; - const std::string version = R"("9824.0.4")"; - const std::string manifest_version = R"(1)"; - const std::string manifest_raw = std::string() + R"( - { - "fs-type":)" + fs_type + R"(, - "is-removable":)" + is_removable + - R"(, - "image-sha256-hash":)" + image_sha256_hash + - R"(, - "table-sha256-hash":)" + table_sha256_hash + - R"(, - "version":)" + version + R"(, - "manifest-version":)" + manifest_version + - R"( - } - )"; - brillo::imageloader::Manifest manifest; - // Parse the manifest raw string. - ASSERT_TRUE(manifest.ParseManifest(manifest_raw)); - EXPECT_EQ(manifest.fs_type(), FileSystem::kExt4); - EXPECT_EQ(manifest.is_removable(), true); - EXPECT_NE(manifest.image_sha256().size(), 0); - EXPECT_NE(manifest.table_sha256().size(), 0); - EXPECT_NE(manifest.version().size(), 0); - EXPECT_EQ(manifest.manifest_version(), 1); -} - -} // namespace imageloader -} // namespace brillo diff --git a/brillo/key_value_store.cc b/brillo/key_value_store.cc index 7840427..46c1d5c 100644 --- a/brillo/key_value_store.cc +++ b/brillo/key_value_store.cc @@ -4,7 +4,6 @@ #include "brillo/key_value_store.h" -#include <map> #include <string> #include <vector> @@ -15,7 +14,6 @@ #include <brillo/strings/string_utils.h> #include <brillo/map_utils.h> -using std::map; using std::string; using std::vector; @@ -37,6 +35,11 @@ string TrimKey(const string& key) { } // namespace +KeyValueStore::KeyValueStore() = default; +KeyValueStore::~KeyValueStore() = default; +KeyValueStore::KeyValueStore(KeyValueStore&&) = default; +KeyValueStore& KeyValueStore::operator=(KeyValueStore&&) = default; + bool KeyValueStore::Load(const base::FilePath& path) { string file_data; if (!base::ReadFileToString(path, &file_data)) @@ -89,6 +92,10 @@ string KeyValueStore::SaveToString() const { return data; } +void KeyValueStore::Clear() { + store_.clear(); +} + bool KeyValueStore::GetString(const string& key, string* value) const { const auto key_value = store_.find(TrimKey(key)); if (key_value == store_.end()) diff --git a/brillo/key_value_store.h b/brillo/key_value_store.h index cc5fa40..0c8e614 100644 --- a/brillo/key_value_store.h +++ b/brillo/key_value_store.h @@ -21,15 +21,26 @@ namespace brillo { class BRILLO_EXPORT KeyValueStore { public: // Creates an empty KeyValueStore. - KeyValueStore() = default; - virtual ~KeyValueStore() = default; + KeyValueStore(); + virtual ~KeyValueStore(); + // Copying is expensive; disallow accidental copies. + KeyValueStore(const KeyValueStore&) = delete; + KeyValueStore& operator=(const KeyValueStore&) = delete; + KeyValueStore(KeyValueStore&&); + KeyValueStore& operator=(KeyValueStore&&); // Loads the key=value pairs from the given |path|. Lines starting with '#' // and empty lines are ignored, and whitespace around keys is trimmed. // Trailing backslashes may be used to extend values across multiple lines. // Adds all the read key=values to the store, overriding those already defined - // but persisting the ones that aren't present on the passed file. Returns - // whether reading the file succeeded. + // but persisting the ones that aren't present on the passed file. + // + // Returns true, if the entire file is loaded successfully. If an error occurs + // while loading, keeps the pairs that were loaded before the error, and + // returns false. + // + // This function does not clear its internal state before loading. To clear + // the internal state, call Clear(). bool Load(const base::FilePath& path); // Loads the key=value pairs parsing the text passed in |data|. See Load() for @@ -48,6 +59,9 @@ class BRILLO_EXPORT KeyValueStore { // these values will be rewritten on single lines), comments or empty lines. std::string SaveToString() const; + // Clears all the key-value pairs currently stored. + void Clear(); + // Getter for the given key. Returns whether the key was found on the store. bool GetString(const std::string& key, std::string* value) const; @@ -67,8 +81,6 @@ class BRILLO_EXPORT KeyValueStore { private: // The map storing all the key-value pairs. std::map<std::string, std::string> store_; - - DISALLOW_COPY_AND_ASSIGN(KeyValueStore); }; } // namespace brillo diff --git a/brillo/key_value_store_unittest.cc b/brillo/key_value_store_test.cc index 68875ef..ceb8df6 100644 --- a/brillo/key_value_store_unittest.cc +++ b/brillo/key_value_store_test.cc @@ -6,6 +6,7 @@ #include <map> #include <string> +#include <utility> #include <vector> #include <base/files/file_util.h> @@ -37,6 +38,35 @@ class KeyValueStoreTest : public ::testing::Test { KeyValueStore store_; // KeyValueStore under test. }; +TEST_F(KeyValueStoreTest, MoveConstructor) { + store_.SetBoolean("a_boolean", true); + store_.SetString("a_string", "a_value"); + + KeyValueStore moved_to(std::move(store_)); + bool b_value = false; + EXPECT_TRUE(moved_to.GetBoolean("a_boolean", &b_value)); + EXPECT_TRUE(b_value); + + std::string s_value; + EXPECT_TRUE(moved_to.GetString("a_string", &s_value)); + EXPECT_EQ(s_value, "a_value"); +} + +TEST_F(KeyValueStoreTest, MoveAssignmentOperator) { + store_.SetBoolean("a_boolean", true); + store_.SetString("a_string", "a_value"); + + KeyValueStore moved_to; + moved_to = std::move(store_); + bool b_value = false; + EXPECT_TRUE(moved_to.GetBoolean("a_boolean", &b_value)); + EXPECT_TRUE(b_value); + + std::string s_value; + EXPECT_TRUE(moved_to.GetString("a_string", &s_value)); + EXPECT_EQ(s_value, "a_value"); +} + TEST_F(KeyValueStoreTest, LoadAndSaveFromFile) { base::ScopedTempDir temp_dir_; CHECK(temp_dir_.CreateUniqueTempDir()); @@ -96,6 +126,26 @@ TEST_F(KeyValueStoreTest, LoadAndReloadTest) { } } +TEST_F(KeyValueStoreTest, MultipleLoads) { + // The internal state is not cleared before loading. + EXPECT_TRUE(store_.LoadFromString("A=B\n")); + EXPECT_TRUE(store_.LoadFromString("B=C\n")); + EXPECT_EQ(2, store_.GetKeys().size()); +} + +TEST_F(KeyValueStoreTest, PartialLoad) { + // The 2nd line is broken, but the pair from the first line should be kept. + EXPECT_FALSE(store_.LoadFromString("A=B\n=\n")); + EXPECT_EQ(1, store_.GetKeys().size()); +} + +TEST_F(KeyValueStoreTest, Clear) { + EXPECT_TRUE(store_.LoadFromString("A=B\n")); + EXPECT_EQ(1, store_.GetKeys().size()); + store_.Clear(); + EXPECT_EQ(0, store_.GetKeys().size()); +} + TEST_F(KeyValueStoreTest, SimpleBooleanTest) { bool result; EXPECT_FALSE(store_.GetBoolean("A", &result)); diff --git a/brillo/map_utils_unittest.cc b/brillo/map_utils_test.cc index 19bda1d..19bda1d 100644 --- a/brillo/map_utils_unittest.cc +++ b/brillo/map_utils_test.cc diff --git a/brillo/message_loops/base_message_loop.cc b/brillo/message_loops/base_message_loop.cc index 08465d7..c3499ab 100644 --- a/brillo/message_loops/base_message_loop.cc +++ b/brillo/message_loops/base_message_loop.cc @@ -6,6 +6,7 @@ #include <fcntl.h> #include <sys/stat.h> +#include <sys/sysmacros.h> #include <sys/types.h> #include <unistd.h> @@ -19,6 +20,7 @@ #include <linux/major.h> #endif +#include <utility> #include <vector> #include <base/bind.h> @@ -28,12 +30,11 @@ #include <base/run_loop.h> #include <base/strings/string_number_conversions.h> #include <base/strings/string_split.h> +#include <base/threading/thread_task_runner_handle.h> #include <brillo/location_logging.h> #include <brillo/strings/string_utils.h> -using base::Closure; - namespace { const char kMiscMinorPath[] = "/proc/misc"; @@ -47,24 +48,19 @@ const int BaseMessageLoop::kInvalidMinor = -1; const int BaseMessageLoop::kUninitializedMinor = -2; BaseMessageLoop::BaseMessageLoop() { - CHECK(!base::MessageLoop::current()) + CHECK(!base::ThreadTaskRunnerHandle::IsSet()) << "You can't create a base::MessageLoopForIO when another " "base::MessageLoop is already created for this thread."; - owned_base_loop_.reset(new base::MessageLoopForIO); + owned_base_loop_.reset(new base::MessageLoopForIO()); base_loop_ = owned_base_loop_.get(); + watcher_ = std::make_unique<base::FileDescriptorWatcher>(base_loop_); } BaseMessageLoop::BaseMessageLoop(base::MessageLoopForIO* base_loop) - : base_loop_(base_loop) {} + : base_loop_(base_loop), + watcher_(std::make_unique<base::FileDescriptorWatcher>(base_loop_)) {} BaseMessageLoop::~BaseMessageLoop() { - for (auto& io_task : io_tasks_) { - DVLOG_LOC(io_task.second.location(), 1) - << "Removing file descriptor watcher task_id " << io_task.first - << " leaked on BaseMessageLoop, scheduled from this location."; - io_task.second.StopWatching(); - } - // Note all pending canceled delayed tasks when destroying the message loop. size_t lazily_deleted_tasks = 0; for (const auto& delayed_task : delayed_tasks_) { @@ -83,14 +79,13 @@ BaseMessageLoop::~BaseMessageLoop() { MessageLoop::TaskId BaseMessageLoop::PostDelayedTask( const base::Location& from_here, - const Closure &task, + base::OnceClosure task, base::TimeDelta delay) { TaskId task_id = NextTaskId(); bool base_scheduled = base_loop_->task_runner()->PostDelayedTask( from_here, - base::Bind(&BaseMessageLoop::OnRanPostedTask, - weak_ptr_factory_.GetWeakPtr(), - task_id), + base::BindOnce(&BaseMessageLoop::OnRanPostedTask, + weak_ptr_factory_.GetWeakPtr(), task_id), delay); DVLOG_LOC(from_here, 1) << "Scheduling delayed task_id " << task_id << " to run in " << delay << "."; @@ -102,81 +97,13 @@ MessageLoop::TaskId BaseMessageLoop::PostDelayedTask( return task_id; } -MessageLoop::TaskId BaseMessageLoop::WatchFileDescriptor( - const base::Location& from_here, - int fd, - WatchMode mode, - bool persistent, - const Closure &task) { - // base::MessageLoopForIO CHECKS that "fd >= 0", so we handle that case here. - if (fd < 0) - return MessageLoop::kTaskIdNull; - - base::MessagePumpForIO::Mode base_mode = base::MessagePumpForIO::WATCH_READ; - switch (mode) { - case MessageLoop::kWatchRead: - base_mode = base::MessagePumpForIO::WATCH_READ; - break; - case MessageLoop::kWatchWrite: - base_mode = base::MessagePumpForIO::WATCH_WRITE; - break; - default: - return MessageLoop::kTaskIdNull; - } - - TaskId task_id = NextTaskId(); - auto it_bool = io_tasks_.emplace( - std::piecewise_construct, - std::forward_as_tuple(task_id), - std::forward_as_tuple( - from_here, this, task_id, fd, base_mode, persistent, task)); - // This should always insert a new element. - DCHECK(it_bool.second); - bool scheduled = it_bool.first->second.StartWatching(); - DVLOG_LOC(from_here, 1) - << "Watching fd " << fd << " for " - << (mode == MessageLoop::kWatchRead ? "reading" : "writing") - << (persistent ? " persistently" : " just once") - << " as task_id " << task_id - << (scheduled ? " successfully" : " failed."); - - if (!scheduled) { - io_tasks_.erase(task_id); - return MessageLoop::kTaskIdNull; - } - -#ifndef __ANDROID_HOST__ - // Determine if the passed fd is the binder file descriptor. For that, we need - // to check that is a special char device and that the major and minor device - // numbers match. The binder file descriptor can't be removed and added back - // to an epoll group when there's work available to be done by the file - // descriptor due to bugs in the binder driver (b/26524111) when used with - // epoll. Therefore, we flag the binder fd and never attempt to remove it. - // This may cause the binder file descriptor to be attended with higher - // priority and cause starvation of other events. - struct stat buf; - if (fstat(fd, &buf) == 0 && - S_ISCHR(buf.st_mode) && - major(buf.st_rdev) == MISC_MAJOR && - minor(buf.st_rdev) == GetBinderMinor()) { - it_bool.first->second.RunImmediately(); - } -#endif - - return task_id; -} - bool BaseMessageLoop::CancelTask(TaskId task_id) { if (task_id == kTaskIdNull) return false; auto delayed_task_it = delayed_tasks_.find(task_id); - if (delayed_task_it == delayed_tasks_.end()) { - // This might be an IOTask then. - auto io_task_it = io_tasks_.find(task_id); - if (io_task_it == io_tasks_.end()) - return false; - return io_task_it->second.CancelTask(); - } + if (delayed_task_it == delayed_tasks_.end()) + return false; + // A DelayedTask was found for this task_id at this point. // Check if the callback was already canceled but we have the entry in @@ -186,10 +113,10 @@ bool BaseMessageLoop::CancelTask(TaskId task_id) { DVLOG_LOC(delayed_task_it->second.location, 1) << "Removing task_id " << task_id << " scheduled from this location."; - // We reset to closure to a null Closure to release all the resources + // We reset to closure to a null OnceClosure to release all the resources // used by this closure at this point, but we don't remove the task_id from // delayed_tasks_ since we can't tell base::MessageLoopForIO to not run it. - delayed_task_it->second.closure = Closure(); + delayed_task_it->second.closure.Reset(); return true; } @@ -226,7 +153,7 @@ void BaseMessageLoop::BreakLoop() { base_run_loop_->Quit(); } -Closure BaseMessageLoop::QuitClosure() const { +base::RepeatingClosure BaseMessageLoop::QuitClosure() const { if (base_run_loop_ == nullptr) return base::DoNothing(); return base_run_loop_->QuitClosure(); @@ -238,8 +165,7 @@ MessageLoop::TaskId BaseMessageLoop::NextTaskId() { res = ++last_id_; // We would run out of memory before we run out of task ids. } while (!res || - delayed_tasks_.find(res) != delayed_tasks_.end() || - io_tasks_.find(res) != io_tasks_.end()); + delayed_tasks_.find(res) != delayed_tasks_.end()); return res; } @@ -252,9 +178,7 @@ void BaseMessageLoop::OnRanPostedTask(MessageLoop::TaskId task_id) { << " scheduled from this location."; // Mark the task as canceled while we are running it so CancelTask returns // false. - Closure closure = std::move(task_it->second.closure); - task_it->second.closure = Closure(); - closure.Run(); + std::move(task_it->second.closure).Run(); // If the |run_once_| flag is set, it is because we are instructed to run // only once callback. @@ -266,15 +190,6 @@ void BaseMessageLoop::OnRanPostedTask(MessageLoop::TaskId task_id) { delayed_tasks_.erase(task_it); } -void BaseMessageLoop::OnFileReadyPostedTask(MessageLoop::TaskId task_id) { - auto task_it = io_tasks_.find(task_id); - // Even if this task was canceled while we were waiting in the message loop - // for this method to run, the entry in io_tasks_ should still be present, but - // won't do anything. - DCHECK(task_it != io_tasks_.end()); - task_it->second.OnFileReadyPostedTask(); -} - int BaseMessageLoop::ParseBinderMinor( const std::string& file_contents) { int result = kInvalidMinor; @@ -308,141 +223,4 @@ unsigned int BaseMessageLoop::GetBinderMinor() { return binder_minor_; } -BaseMessageLoop::IOTask::IOTask(const base::Location& location, - BaseMessageLoop* loop, - MessageLoop::TaskId task_id, - int fd, - base::MessagePumpForIO::Mode base_mode, - bool persistent, - const Closure& task) - : location_(location), loop_(loop), task_id_(task_id), - fd_(fd), base_mode_(base_mode), persistent_(persistent), closure_(task), - fd_watcher_(FROM_HERE) {} - -bool BaseMessageLoop::IOTask::StartWatching() { - // Please see MessagePumpLibevent for definition. - static_assert(std::is_same<base::MessagePumpForIO, base::MessagePumpLibevent>::value, - "MessagePumpForIO::WatchFileDescriptor is not supported " - "when MessagePumpForIO is not a MessagePumpLibevent."); - - return static_cast<base::MessagePumpLibevent*>( - loop_->base_loop_->pump_.get())->WatchFileDescriptor( - fd_, persistent_, base_mode_, &fd_watcher_, this); -} - -void BaseMessageLoop::IOTask::StopWatching() { - // This is safe to call even if we are not watching for it. - fd_watcher_.StopWatchingFileDescriptor(); -} - -void BaseMessageLoop::IOTask::OnFileCanReadWithoutBlocking(int /* fd */) { - OnFileReady(); -} - -void BaseMessageLoop::IOTask::OnFileCanWriteWithoutBlocking(int /* fd */) { - OnFileReady(); -} - -void BaseMessageLoop::IOTask::OnFileReady() { - // For file descriptors marked with the immediate_run flag, we don't call - // StopWatching() and wait, instead we dispatch the callback immediately. - if (immediate_run_) { - posted_task_pending_ = true; - OnFileReadyPostedTask(); - return; - } - - // When the file descriptor becomes available we stop watching for it and - // schedule a task to run the callback from the main loop. The callback will - // run using the same scheduler used to run other delayed tasks, avoiding - // starvation of the available posted tasks if there are file descriptors - // always available. The new posted task will use the same TaskId as the - // current file descriptor watching task an could be canceled in either state, - // when waiting for the file descriptor or waiting in the main loop. - StopWatching(); - bool base_scheduled = loop_->base_loop_->task_runner()->PostTask( - location_, - base::Bind(&BaseMessageLoop::OnFileReadyPostedTask, - loop_->weak_ptr_factory_.GetWeakPtr(), - task_id_)); - posted_task_pending_ = true; - if (base_scheduled) { - DVLOG_LOC(location_, 1) - << "Dispatching task_id " << task_id_ << " for " - << (base_mode_ == base::MessagePumpForIO::WATCH_READ ? - "reading" : "writing") - << " file descriptor " << fd_ << ", scheduled from this location."; - } else { - // In the rare case that PostTask() fails, we fall back to run it directly. - // This would indicate a bigger problem with the message loop setup. - LOG(ERROR) << "Error on base::MessageLoopForIO::PostTask()."; - OnFileReadyPostedTask(); - } -} - -void BaseMessageLoop::IOTask::OnFileReadyPostedTask() { - // We can't access |this| after running the |closure_| since it could call - // CancelTask on its own task_id, so we copy the members we need now. - BaseMessageLoop* loop_ptr = loop_; - DCHECK(posted_task_pending_ = true); - posted_task_pending_ = false; - - // If this task was already canceled, the closure will be null and there is - // nothing else to do here. This execution doesn't count a step for RunOnce() - // unless we have a callback to run. - if (closure_.is_null()) { - loop_->io_tasks_.erase(task_id_); - return; - } - - DVLOG_LOC(location_, 1) - << "Running task_id " << task_id_ << " for " - << (base_mode_ == base::MessagePumpForIO::WATCH_READ ? - "reading" : "writing") - << " file descriptor " << fd_ << ", scheduled from this location."; - - if (persistent_) { - // In the persistent case we just run the callback. If this callback cancels - // the task id, we can't access |this| anymore, so we re-start watching the - // file descriptor before running the callback, unless this is a fd where - // we didn't stop watching the file descriptor when it became available. - if (!immediate_run_) - StartWatching(); - closure_.Run(); - } else { - // This will destroy |this|, the fd_watcher and therefore stop watching this - // file descriptor. - Closure closure_copy = std::move(closure_); - loop_->io_tasks_.erase(task_id_); - // Run the closure from the local copy we just made. - closure_copy.Run(); - } - - if (loop_ptr->run_once_) { - loop_ptr->run_once_ = false; - loop_ptr->BreakLoop(); - } -} - -bool BaseMessageLoop::IOTask::CancelTask() { - if (closure_.is_null()) - return false; - - DVLOG_LOC(location_, 1) - << "Removing task_id " << task_id_ << " scheduled from this location."; - - if (!posted_task_pending_) { - // Destroying the FileDescriptorWatcher implicitly stops watching the file - // descriptor. This will delete our instance. - loop_->io_tasks_.erase(task_id_); - return true; - } - // The IOTask is waiting for the message loop to run its delayed task, so - // it is not watching for the file descriptor. We release the closure - // resources now but keep the IOTask instance alive while we wait for the - // callback to run and delete the IOTask. - closure_ = Closure(); - return true; -} - } // namespace brillo diff --git a/brillo/message_loops/base_message_loop.h b/brillo/message_loops/base_message_loop.h index 163ea4f..75e4361 100644 --- a/brillo/message_loops/base_message_loop.h +++ b/brillo/message_loops/base_message_loop.h @@ -16,6 +16,7 @@ #include <memory> #include <string> +#include <base/files/file_descriptor_watcher_posix.h> #include <base/location.h> #include <base/memory/weak_ptr.h> #include <base/message_loop/message_loop.h> @@ -41,15 +42,9 @@ class BRILLO_EXPORT BaseMessageLoop : public MessageLoop { // MessageLoop overrides. TaskId PostDelayedTask(const base::Location& from_here, - const base::Closure& task, + base::OnceClosure task, base::TimeDelta delay) override; using MessageLoop::PostDelayedTask; - TaskId WatchFileDescriptor(const base::Location& from_here, - int fd, - WatchMode mode, - bool persistent, - const base::Closure& task) override; - using MessageLoop::WatchFileDescriptor; bool CancelTask(TaskId task_id) override; bool RunOnce(bool may_block) override; void Run() override; @@ -57,7 +52,7 @@ class BRILLO_EXPORT BaseMessageLoop : public MessageLoop { // Returns a callback that will quit the current message loop. If the message // loop is not running, an empty (null) callback is returned. - base::Closure QuitClosure() const; + base::RepeatingClosure QuitClosure() const; private: FRIEND_TEST(BaseMessageLoopTest, ParseBinderMinor); @@ -74,12 +69,6 @@ class BRILLO_EXPORT BaseMessageLoop : public MessageLoop { // scheduled with Post*Task() of id |task_id|, even if it was canceled. void OnRanPostedTask(MessageLoop::TaskId task_id); - // Called from the message loop when the IOTask should run the scheduled - // callback. This is a simple wrapper of IOTask::OnFileReadyPostedTask() - // posted from the BaseMessageLoop so it is deleted when the BaseMessageLoop - // goes out of scope since we can't cancel the callback otherwise. - void OnFileReadyPostedTask(MessageLoop::TaskId task_id); - // Return a new unused task_id. TaskId NextTaskId(); @@ -90,69 +79,7 @@ class BRILLO_EXPORT BaseMessageLoop : public MessageLoop { base::Location location; MessageLoop::TaskId task_id; - base::Closure closure; - }; - - class IOTask : public base::MessagePumpForIO::FdWatcher { - public: - IOTask(const base::Location& location, - BaseMessageLoop* loop, - MessageLoop::TaskId task_id, - int fd, - base::MessagePumpForIO::Mode base_mode, - bool persistent, - const base::Closure& task); - - const base::Location& location() const { return location_; } - - // Used to start/stop watching the file descriptor while keeping the - // IOTask entry available. - bool StartWatching(); - void StopWatching(); - - // Called from the message loop as a PostTask() when the file descriptor is - // available, scheduled to run from OnFileReady(). - void OnFileReadyPostedTask(); - - // Cancel the IOTask and returns whether it was actually canceled, with the - // same semantics as MessageLoop::CancelTask(). - bool CancelTask(); - - // Sets the closure to be run immediately whenever the file descriptor - // becomes ready. - void RunImmediately() { immediate_run_= true; } - - private: - base::Location location_; - BaseMessageLoop* loop_; - - // These are the arguments passed in the constructor, basically forwarding - // all the arguments passed to WatchFileDescriptor() plus the assigned - // TaskId for this task. - MessageLoop::TaskId task_id_; - int fd_; - base::MessagePumpForIO::Mode base_mode_; - bool persistent_; - base::Closure closure_; - - base::MessagePumpForIO::FdWatchController fd_watcher_; - - // Tells whether there is a pending call to OnFileReadPostedTask(). - bool posted_task_pending_{false}; - - // Whether the registered callback should be running immediately when the - // file descriptor is ready, as opposed to posting a task to the main loop - // to prevent starvation. - bool immediate_run_{false}; - - // base::MessageLoopForIO::Watcher overrides: - void OnFileCanReadWithoutBlocking(int fd) override; - void OnFileCanWriteWithoutBlocking(int fd) override; - - // Common implementation for both the read and write case. - void OnFileReady(); - - DISALLOW_COPY_AND_ASSIGN(IOTask); + base::OnceClosure closure; }; // The base::MessageLoopForIO instance owned by this class, if any. This @@ -162,9 +89,6 @@ class BRILLO_EXPORT BaseMessageLoop : public MessageLoop { // Tasks blocked on a timeout. std::map<MessageLoop::TaskId, DelayedTask> delayed_tasks_; - // Tasks blocked on I/O. - std::map<MessageLoop::TaskId, IOTask> io_tasks_; - // Flag to mark that we should run the message loop only one iteration. bool run_once_{false}; @@ -178,6 +102,9 @@ class BRILLO_EXPORT BaseMessageLoop : public MessageLoop { // point to that instance. base::MessageLoopForIO* base_loop_; + // FileDescriptorWatcher for |base_loop_|. This is used in AlarmTimer. + std::unique_ptr<base::FileDescriptorWatcher> watcher_; + // The RunLoop instance used to run the main loop from Run(). base::RunLoop* base_run_loop_{nullptr}; diff --git a/brillo/message_loops/base_message_loop_unittest.cc b/brillo/message_loops/base_message_loop_test.cc index 9e052a8..9e052a8 100644 --- a/brillo/message_loops/base_message_loop_unittest.cc +++ b/brillo/message_loops/base_message_loop_test.cc diff --git a/brillo/message_loops/fake_message_loop.cc b/brillo/message_loops/fake_message_loop.cc index 41f5b51..185b20c 100644 --- a/brillo/message_loops/fake_message_loop.cc +++ b/brillo/message_loops/fake_message_loop.cc @@ -15,7 +15,7 @@ FakeMessageLoop::FakeMessageLoop(base::SimpleTestClock* clock) MessageLoop::TaskId FakeMessageLoop::PostDelayedTask( const base::Location& from_here, - const base::Closure& task, + base::OnceClosure task, base::TimeDelta delay) { // If no SimpleTestClock was provided, we use the last time we fired a // callback. In this way, tasks scheduled from a Closure will have the right @@ -25,7 +25,7 @@ MessageLoop::TaskId FakeMessageLoop::PostDelayedTask( MessageLoop::TaskId current_id = ++last_id_; // FakeMessageLoop is limited to only 2^64 tasks. That should be enough. CHECK(current_id); - tasks_.emplace(current_id, ScheduledTask{from_here, false, task}); + tasks_.emplace(current_id, ScheduledTask{from_here, std::move(task)}); fire_order_.push(std::make_pair(current_time_ + delay, current_id)); VLOG_LOC(from_here, 1) << "Scheduling delayed task_id " << current_id << " to run at " << current_time_ + delay @@ -33,20 +33,6 @@ MessageLoop::TaskId FakeMessageLoop::PostDelayedTask( return current_id; } -MessageLoop::TaskId FakeMessageLoop::WatchFileDescriptor( - const base::Location& from_here, - int fd, - WatchMode mode, - bool persistent, - const base::Closure& task) { - MessageLoop::TaskId current_id = ++last_id_; - // FakeMessageLoop is limited to only 2^64 tasks. That should be enough. - CHECK(current_id); - tasks_.emplace(current_id, ScheduledTask{from_here, persistent, task}); - fds_watched_.emplace(std::make_pair(fd, mode), current_id); - return current_id; -} - bool FakeMessageLoop::CancelTask(TaskId task_id) { if (task_id == MessageLoop::kTaskIdNull) return false; @@ -58,36 +44,6 @@ bool FakeMessageLoop::CancelTask(TaskId task_id) { bool FakeMessageLoop::RunOnce(bool may_block) { if (test_clock_) current_time_ = test_clock_->Now(); - // Try to fire ready file descriptors first. - for (const auto& fd_mode : fds_ready_) { - const auto& fd_watched = fds_watched_.find(fd_mode); - if (fd_watched == fds_watched_.end()) - continue; - // The fd_watched->second task might have been canceled and we never removed - // it from the fds_watched_, so we fix that now. - const auto& scheduled_task_ref = tasks_.find(fd_watched->second); - if (scheduled_task_ref == tasks_.end()) { - fds_watched_.erase(fd_watched); - continue; - } - VLOG_LOC(scheduled_task_ref->second.location, 1) - << "Running task_id " << fd_watched->second - << " for watching file descriptor " << fd_mode.first << " for " - << (fd_mode.second == MessageLoop::kWatchRead ? "reading" : "writing") - << (scheduled_task_ref->second.persistent ? - " persistently" : " just once") - << " scheduled from this location."; - if (scheduled_task_ref->second.persistent) { - scheduled_task_ref->second.callback.Run(); - } else { - base::Closure callback = std::move(scheduled_task_ref->second.callback); - tasks_.erase(scheduled_task_ref); - fds_watched_.erase(fd_watched); - callback.Run(); - } - return true; - } - // Try to fire time-based callbacks afterwards. while (!fire_order_.empty() && (may_block || fire_order_.top().first <= current_time_)) { @@ -108,32 +64,22 @@ bool FakeMessageLoop::RunOnce(bool may_block) { // Move the Closure out of the map before delete it. We need to delete the // entry from the map before we call the callback, since calling CancelTask // for the task you are running now should fail and return false. - base::Closure callback = std::move(scheduled_task_ref->second.callback); + base::OnceClosure callback = std::move(scheduled_task_ref->second.callback); VLOG_LOC(scheduled_task_ref->second.location, 1) << "Running task_id " << task_ref.second << " at time " << current_time_ << " from this location."; tasks_.erase(scheduled_task_ref); - callback.Run(); + std::move(callback).Run(); return true; } return false; } -void FakeMessageLoop::SetFileDescriptorReadiness(int fd, - WatchMode mode, - bool ready) { - if (ready) - fds_ready_.emplace(fd, mode); - else - fds_ready_.erase(std::make_pair(fd, mode)); -} - bool FakeMessageLoop::PendingTasks() { for (const auto& task : tasks_) { VLOG_LOC(task.second.location, 1) - << "Pending " << (task.second.persistent ? "persistent " : "") - << "task_id " << task.first << " scheduled from here."; + << "Pending task_id " << task.first << " scheduled from here."; } return !tasks_.empty(); } diff --git a/brillo/message_loops/fake_message_loop.h b/brillo/message_loops/fake_message_loop.h index 4b6e8ac..783af1b 100644 --- a/brillo/message_loops/fake_message_loop.h +++ b/brillo/message_loops/fake_message_loop.h @@ -37,26 +37,14 @@ class BRILLO_EXPORT FakeMessageLoop : public MessageLoop { ~FakeMessageLoop() override = default; TaskId PostDelayedTask(const base::Location& from_here, - const base::Closure& task, + base::OnceClosure task, base::TimeDelta delay) override; using MessageLoop::PostDelayedTask; - TaskId WatchFileDescriptor(const base::Location& from_here, - int fd, - WatchMode mode, - bool persistent, - const base::Closure& task) override; - using MessageLoop::WatchFileDescriptor; bool CancelTask(TaskId task_id) override; bool RunOnce(bool may_block) override; // FakeMessageLoop methods: - // Pretend, for the purpose of the FakeMessageLoop watching for file - // descriptors, that the file descriptor |fd| readiness to perform the - // operation described by |mode| is |ready|. Initially, no file descriptor - // is ready for any operation. - void SetFileDescriptorReadiness(int fd, WatchMode mode, bool ready); - // Return whether there are peding tasks. Useful to check that no // callbacks were leaked. bool PendingTasks(); @@ -64,8 +52,7 @@ class BRILLO_EXPORT FakeMessageLoop : public MessageLoop { private: struct ScheduledTask { base::Location location; - bool persistent; - base::Closure callback; + base::OnceClosure callback; }; // The sparse list of scheduled pending callbacks. @@ -79,13 +66,6 @@ class BRILLO_EXPORT FakeMessageLoop : public MessageLoop { std::vector<std::pair<base::Time, MessageLoop::TaskId>>, std::greater<std::pair<base::Time, MessageLoop::TaskId>>> fire_order_; - // The bag of watched (fd, mode) pair associated with the TaskId that's - // watching them. - std::multimap<std::pair<int, WatchMode>, MessageLoop::TaskId> fds_watched_; - - // The set of (fd, mode) pairs that are faked as ready. - std::set<std::pair<int, WatchMode>> fds_ready_; - base::SimpleTestClock* test_clock_ = nullptr; base::Time current_time_ = base::Time::FromDoubleT(1246996800.); diff --git a/brillo/message_loops/fake_message_loop_unittest.cc b/brillo/message_loops/fake_message_loop_test.cc index 18f0b4b..a5d0607 100644 --- a/brillo/message_loops/fake_message_loop_unittest.cc +++ b/brillo/message_loops/fake_message_loop_test.cc @@ -13,10 +13,9 @@ #include <base/test/simple_test_clock.h> #include <gtest/gtest.h> -#include <brillo/bind_lambda.h> #include <brillo/message_loops/message_loop.h> -using base::Bind; +using base::BindOnce; using base::Time; using base::TimeDelta; using std::vector; @@ -46,17 +45,18 @@ TEST_F(FakeMessageLoopTest, CancelTaskInvalidValuesTest) { TEST_F(FakeMessageLoopTest, PostDelayedTaskRunsInOrder) { vector<int> order; - auto callback = [](std::vector<int>* order, int value) { - order->push_back(value); - }; - loop_->PostDelayedTask(Bind(callback, base::Unretained(&order), 1), - TimeDelta::FromSeconds(1)); - loop_->PostDelayedTask(Bind(callback, base::Unretained(&order), 4), - TimeDelta::FromSeconds(4)); - loop_->PostDelayedTask(Bind(callback, base::Unretained(&order), 3), - TimeDelta::FromSeconds(3)); - loop_->PostDelayedTask(Bind(callback, base::Unretained(&order), 2), - TimeDelta::FromSeconds(2)); + loop_->PostDelayedTask( + BindOnce([](vector<int>* order) { order->push_back(1); }, &order), + TimeDelta::FromSeconds(1)); + loop_->PostDelayedTask( + BindOnce([](vector<int>* order) { order->push_back(4); }, &order), + TimeDelta::FromSeconds(4)); + loop_->PostDelayedTask( + BindOnce([](vector<int>* order) { order->push_back(3); }, &order), + TimeDelta::FromSeconds(3)); + loop_->PostDelayedTask( + BindOnce([](vector<int>* order) { order->push_back(2); }, &order), + TimeDelta::FromSeconds(2)); // Run until all the tasks are run. loop_->Run(); EXPECT_EQ((vector<int>{1, 2, 3, 4}), order); @@ -85,34 +85,6 @@ TEST_F(FakeMessageLoopTest, PostDelayedTaskAdvancesTheTime) { EXPECT_EQ(start + TimeDelta::FromSeconds(3), clock_.Now()); } -TEST_F(FakeMessageLoopTest, WatchFileDescriptorWaits) { - int fd = 1234; - // We will simulate this situation. At the beginning, we will watch for a - // file descriptor that won't trigger for 10s. Then we will pretend it is - // ready after 10s and expect its callback to run just once. - int called = 0; - TaskId task_id = loop_->WatchFileDescriptor( - FROM_HERE, fd, MessageLoop::kWatchRead, false, - Bind([](int* called) { (*called)++; }, base::Unretained(&called))); - EXPECT_NE(MessageLoop::kTaskIdNull, task_id); - - EXPECT_NE(MessageLoop::kTaskIdNull, - loop_->PostDelayedTask(Bind(&FakeMessageLoop::BreakLoop, - base::Unretained(loop_.get())), - TimeDelta::FromSeconds(10))); - EXPECT_NE(MessageLoop::kTaskIdNull, - loop_->PostDelayedTask(Bind(&FakeMessageLoop::BreakLoop, - base::Unretained(loop_.get())), - TimeDelta::FromSeconds(20))); - loop_->Run(); - EXPECT_EQ(0, called); - - loop_->SetFileDescriptorReadiness(fd, MessageLoop::kWatchRead, true); - loop_->Run(); - EXPECT_EQ(1, called); - EXPECT_FALSE(loop_->CancelTask(task_id)); -} - TEST_F(FakeMessageLoopTest, PendingTasksTest) { loop_->PostDelayedTask(base::DoNothing(), TimeDelta::FromSeconds(1)); EXPECT_TRUE(loop_->PendingTasks()); diff --git a/brillo/message_loops/message_loop.h b/brillo/message_loops/message_loop.h index 1f65d96..13c4dc2 100644 --- a/brillo/message_loops/message_loop.h +++ b/brillo/message_loops/message_loop.h @@ -6,6 +6,7 @@ #define LIBBRILLO_BRILLO_MESSAGE_LOOPS_MESSAGE_LOOP_H_ #include <string> +#include <utility> #include <base/callback.h> #include <base/location.h> @@ -50,49 +51,20 @@ class BRILLO_EXPORT MessageLoop { // at a later point. // This methond can only be called from the same thread running the main loop. virtual TaskId PostDelayedTask(const base::Location& from_here, - const base::Closure& task, + base::OnceClosure task, base::TimeDelta delay) = 0; // Variant without the Location for easier usage. - TaskId PostDelayedTask(const base::Closure& task, base::TimeDelta delay) { - return PostDelayedTask(base::Location(), task, delay); + TaskId PostDelayedTask(base::OnceClosure task, base::TimeDelta delay) { + return PostDelayedTask(base::Location(), std::move(task), delay); } // A convenience method to schedule a call with no delay. // This methond can only be called from the same thread running the main loop. - TaskId PostTask(const base::Closure& task) { - return PostDelayedTask(task, base::TimeDelta()); + TaskId PostTask(base::OnceClosure task) { + return PostDelayedTask(std::move(task), base::TimeDelta()); } - TaskId PostTask(const base::Location& from_here, - const base::Closure& task) { - return PostDelayedTask(from_here, task, base::TimeDelta()); - } - - // Watch mode flag used to watch for file descriptors. - enum WatchMode { - kWatchRead, - kWatchWrite, - }; - - // Watch a file descriptor |fd| for it to be ready to perform the operation - // passed in |mode| without blocking. When that happens, the |task| closure - // will be executed. If |persistent| is true, the file descriptor will - // continue to be watched and |task| will continue to be called until the task - // is canceled with CancelTask(). - // Returns the TaskId describing this task. In case of error, returns - // kTaskIdNull. - virtual TaskId WatchFileDescriptor(const base::Location& from_here, - int fd, - WatchMode mode, - bool persistent, - const base::Closure& task) = 0; - - // Convenience function to call WatchFileDescriptor() without a location. - TaskId WatchFileDescriptor(int fd, - WatchMode mode, - bool persistent, - const base::Closure& task) { - return WatchFileDescriptor( - base::Location(), fd, mode, persistent, task); + TaskId PostTask(const base::Location& from_here, base::OnceClosure task) { + return PostDelayedTask(from_here, std::move(task), base::TimeDelta()); } // Cancel a scheduled task. Returns whether the task was canceled. For diff --git a/brillo/message_loops/message_loop_test.cc b/brillo/message_loops/message_loop_test.cc new file mode 100644 index 0000000..86c41a6 --- /dev/null +++ b/brillo/message_loops/message_loop_test.cc @@ -0,0 +1,139 @@ +// Copyright 2015 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/message_loops/message_loop.h> + +// These are the common tests for all the brillo::MessageLoop implementations +// that should conform to this interface's contracts. For extra +// implementation-specific tests see the particular implementation unittests in +// the *_test.cc files. + +#include <memory> +#include <vector> + +#include <base/bind.h> +#include <base/bind_helpers.h> +#include <base/location.h> +#include <base/posix/eintr_wrapper.h> +#include <gtest/gtest.h> + +#include <brillo/message_loops/base_message_loop.h> +#include <brillo/message_loops/message_loop_utils.h> +#include <brillo/unittest_utils.h> + +using base::BindOnce; +using base::BindRepeating; +using base::TimeDelta; + +namespace { + +// Convenience functions for passing to base::Bind{Once,Repeating}. +void SetToTrue(bool* b) { + *b = true; +} + +bool ReturnBool(bool *b) { + return *b; +} + +} // namespace + +namespace brillo { + +using TaskId = MessageLoop::TaskId; + +template <typename T> +class MessageLoopTest : public ::testing::Test { + protected: + void SetUp() override { + MessageLoopSetUp(); + EXPECT_TRUE(this->loop_.get()); + } + + std::unique_ptr<base::MessageLoopForIO> base_loop_; + + std::unique_ptr<MessageLoop> loop_; + + private: + // These MessageLoopSetUp() methods are used to setup each MessageLoop + // according to its constructor requirements. + void MessageLoopSetUp(); +}; + +template <> +void MessageLoopTest<BaseMessageLoop>::MessageLoopSetUp() { + base_loop_.reset(new base::MessageLoopForIO()); + loop_.reset(new BaseMessageLoop(base_loop_.get())); + loop_->SetAsCurrent(); +} + +// This setups gtest to run each one of the following TYPED_TEST test cases on +// on each implementation. +typedef ::testing::Types<BaseMessageLoop> MessageLoopTypes; +TYPED_TEST_CASE(MessageLoopTest, MessageLoopTypes); + + +TYPED_TEST(MessageLoopTest, CancelTaskInvalidValuesTest) { + EXPECT_FALSE(this->loop_->CancelTask(MessageLoop::kTaskIdNull)); + EXPECT_FALSE(this->loop_->CancelTask(1234)); +} + +TYPED_TEST(MessageLoopTest, PostTaskTest) { + bool called = false; + TaskId task_id = + this->loop_->PostTask(FROM_HERE, BindOnce(&SetToTrue, &called)); + EXPECT_NE(MessageLoop::kTaskIdNull, task_id); + MessageLoopRunMaxIterations(this->loop_.get(), 100); + EXPECT_TRUE(called); +} + +// Tests that we can cancel tasks right after we schedule them. +TYPED_TEST(MessageLoopTest, PostTaskCancelledTest) { + bool called = false; + TaskId task_id = + this->loop_->PostTask(FROM_HERE, BindOnce(&SetToTrue, &called)); + EXPECT_TRUE(this->loop_->CancelTask(task_id)); + MessageLoopRunMaxIterations(this->loop_.get(), 100); + EXPECT_FALSE(called); + // Can't remove a task you already removed. + EXPECT_FALSE(this->loop_->CancelTask(task_id)); +} + +TYPED_TEST(MessageLoopTest, PostDelayedTaskRunsEventuallyTest) { + bool called = false; + TaskId task_id = + this->loop_->PostDelayedTask(FROM_HERE, BindOnce(&SetToTrue, &called), + TimeDelta::FromMilliseconds(50)); + EXPECT_NE(MessageLoop::kTaskIdNull, task_id); + MessageLoopRunUntil(this->loop_.get(), TimeDelta::FromSeconds(10), + BindRepeating(&ReturnBool, &called)); + // Check that the main loop finished before the 10 seconds timeout, so it + // finished due to the callback being called and not due to the timeout. + EXPECT_TRUE(called); +} + +// Test that you can call the overloaded version of PostDelayedTask from +// MessageLoop. This is important because only one of the two methods is +// virtual, so you need to unhide the other when overriding the virtual one. +TYPED_TEST(MessageLoopTest, PostDelayedTaskWithoutLocation) { + this->loop_->PostDelayedTask(base::DoNothing(), TimeDelta()); + EXPECT_EQ(1, MessageLoopRunMaxIterations(this->loop_.get(), 100)); +} + +// Test that we can cancel the task we are running, and should just fail. +TYPED_TEST(MessageLoopTest, DeleteTaskFromSelf) { + bool cancel_result = true; // We would expect this to be false. + TaskId task_id; + task_id = this->loop_->PostTask( + FROM_HERE, + BindOnce( + [](bool* cancel_result, MessageLoop* loop, TaskId* task_id) { + *cancel_result = loop->CancelTask(*task_id); + }, + &cancel_result, this->loop_.get(), &task_id)); + EXPECT_EQ(1, MessageLoopRunMaxIterations(this->loop_.get(), 100)); + EXPECT_FALSE(cancel_result); +} + +} // namespace brillo diff --git a/brillo/message_loops/message_loop_unittest.cc b/brillo/message_loops/message_loop_unittest.cc deleted file mode 100644 index bda3336..0000000 --- a/brillo/message_loops/message_loop_unittest.cc +++ /dev/null @@ -1,372 +0,0 @@ -// Copyright 2015 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <brillo/message_loops/message_loop.h> - -// These are the common tests for all the brillo::MessageLoop implementations -// that should conform to this interface's contracts. For extra -// implementation-specific tests see the particular implementation unittests in -// the *_unittest.cc files. - -#include <memory> -#include <vector> - -#include <base/bind.h> -#include <base/bind_helpers.h> -#include <base/location.h> -#include <base/posix/eintr_wrapper.h> -#include <gtest/gtest.h> - -#include <brillo/bind_lambda.h> -#include <brillo/unittest_utils.h> -#include <brillo/message_loops/base_message_loop.h> -#include <brillo/message_loops/message_loop_utils.h> - -using base::Bind; -using base::TimeDelta; - -namespace { - -// Convenience functions for passing to base::Bind. -void SetToTrue(bool* b) { - *b = true; -} - -bool ReturnBool(bool *b) { - return *b; -} - -void Increment(int* i) { - (*i)++; -} - -} // namespace - -namespace brillo { - -using TaskId = MessageLoop::TaskId; - -template <typename T> -class MessageLoopTest : public ::testing::Test { - protected: - void SetUp() override { - MessageLoopSetUp(); - EXPECT_TRUE(this->loop_.get()); - } - - std::unique_ptr<base::MessageLoopForIO> base_loop_; - - std::unique_ptr<MessageLoop> loop_; - - private: - // These MessageLoopSetUp() methods are used to setup each MessageLoop - // according to its constructor requirements. - void MessageLoopSetUp(); -}; - -template <> -void MessageLoopTest<BaseMessageLoop>::MessageLoopSetUp() { - base_loop_.reset(new base::MessageLoopForIO()); - loop_.reset(new BaseMessageLoop(base::MessageLoopForIO::current())); -} - -// This setups gtest to run each one of the following TYPED_TEST test cases on -// on each implementation. -typedef ::testing::Types<BaseMessageLoop> MessageLoopTypes; -TYPED_TEST_CASE(MessageLoopTest, MessageLoopTypes); - - -TYPED_TEST(MessageLoopTest, CancelTaskInvalidValuesTest) { - EXPECT_FALSE(this->loop_->CancelTask(MessageLoop::kTaskIdNull)); - EXPECT_FALSE(this->loop_->CancelTask(1234)); -} - -TYPED_TEST(MessageLoopTest, PostTaskTest) { - bool called = false; - TaskId task_id = this->loop_->PostTask(FROM_HERE, Bind(&SetToTrue, &called)); - EXPECT_NE(MessageLoop::kTaskIdNull, task_id); - MessageLoopRunMaxIterations(this->loop_.get(), 100); - EXPECT_TRUE(called); -} - -// Tests that we can cancel tasks right after we schedule them. -TYPED_TEST(MessageLoopTest, PostTaskCancelledTest) { - bool called = false; - TaskId task_id = this->loop_->PostTask(FROM_HERE, Bind(&SetToTrue, &called)); - EXPECT_TRUE(this->loop_->CancelTask(task_id)); - MessageLoopRunMaxIterations(this->loop_.get(), 100); - EXPECT_FALSE(called); - // Can't remove a task you already removed. - EXPECT_FALSE(this->loop_->CancelTask(task_id)); -} - -TYPED_TEST(MessageLoopTest, PostDelayedTaskRunsEventuallyTest) { - bool called = false; - TaskId task_id = this->loop_->PostDelayedTask( - FROM_HERE, Bind(&SetToTrue, &called), TimeDelta::FromMilliseconds(50)); - EXPECT_NE(MessageLoop::kTaskIdNull, task_id); - MessageLoopRunUntil(this->loop_.get(), - TimeDelta::FromSeconds(10), - Bind(&ReturnBool, &called)); - // Check that the main loop finished before the 10 seconds timeout, so it - // finished due to the callback being called and not due to the timeout. - EXPECT_TRUE(called); -} - -// Test that you can call the overloaded version of PostDelayedTask from -// MessageLoop. This is important because only one of the two methods is -// virtual, so you need to unhide the other when overriding the virtual one. -TYPED_TEST(MessageLoopTest, PostDelayedTaskWithoutLocation) { - this->loop_->PostDelayedTask(base::DoNothing(), TimeDelta()); - EXPECT_EQ(1, MessageLoopRunMaxIterations(this->loop_.get(), 100)); -} - -TYPED_TEST(MessageLoopTest, WatchForInvalidFD) { - bool called = false; - EXPECT_EQ(MessageLoop::kTaskIdNull, this->loop_->WatchFileDescriptor( - FROM_HERE, -1, MessageLoop::kWatchRead, true, - Bind(&SetToTrue, &called))); - EXPECT_EQ(MessageLoop::kTaskIdNull, this->loop_->WatchFileDescriptor( - FROM_HERE, -1, MessageLoop::kWatchWrite, true, - Bind(&SetToTrue, &called))); - EXPECT_EQ(0, MessageLoopRunMaxIterations(this->loop_.get(), 100)); - EXPECT_FALSE(called); -} - -TYPED_TEST(MessageLoopTest, CancelWatchedFileDescriptor) { - ScopedPipe pipe; - bool called = false; - TaskId task_id = this->loop_->WatchFileDescriptor( - FROM_HERE, pipe.reader, MessageLoop::kWatchRead, true, - Bind(&SetToTrue, &called)); - EXPECT_NE(MessageLoop::kTaskIdNull, task_id); - // The reader end is blocked because we didn't write anything to the writer - // end. - EXPECT_EQ(0, MessageLoopRunMaxIterations(this->loop_.get(), 100)); - EXPECT_FALSE(called); - EXPECT_TRUE(this->loop_->CancelTask(task_id)); -} - -// When you watch a file descriptor for reading, the guaranties are that a -// blocking call to read() on that file descriptor will not block. This should -// include the case when the other end of a pipe is closed or the file is empty. -TYPED_TEST(MessageLoopTest, WatchFileDescriptorTriggersWhenPipeClosed) { - ScopedPipe pipe; - bool called = false; - EXPECT_EQ(0, HANDLE_EINTR(close(pipe.writer))); - pipe.writer = -1; - TaskId task_id = this->loop_->WatchFileDescriptor( - FROM_HERE, pipe.reader, MessageLoop::kWatchRead, true, - Bind(&SetToTrue, &called)); - EXPECT_NE(MessageLoop::kTaskIdNull, task_id); - // The reader end is not blocked because we closed the writer end so a read on - // the reader end would return 0 bytes read. - EXPECT_NE(0, MessageLoopRunMaxIterations(this->loop_.get(), 10)); - EXPECT_TRUE(called); - EXPECT_TRUE(this->loop_->CancelTask(task_id)); -} - -// When a WatchFileDescriptor task is scheduled with |persistent| = true, we -// should keep getting a call whenever the file descriptor is ready. -TYPED_TEST(MessageLoopTest, WatchFileDescriptorPersistently) { - ScopedPipe pipe; - EXPECT_EQ(1, HANDLE_EINTR(write(pipe.writer, "a", 1))); - - int called = 0; - TaskId task_id = this->loop_->WatchFileDescriptor( - FROM_HERE, pipe.reader, MessageLoop::kWatchRead, true, - Bind(&Increment, &called)); - EXPECT_NE(MessageLoop::kTaskIdNull, task_id); - // We let the main loop run for 20 iterations to give it enough iterations to - // verify that our callback was called more than one. We only check that our - // callback is called more than once. - EXPECT_EQ(20, MessageLoopRunMaxIterations(this->loop_.get(), 20)); - EXPECT_LT(1, called); - EXPECT_TRUE(this->loop_->CancelTask(task_id)); -} - -TYPED_TEST(MessageLoopTest, WatchFileDescriptorNonPersistent) { - ScopedPipe pipe; - EXPECT_EQ(1, HANDLE_EINTR(write(pipe.writer, "a", 1))); - - int called = 0; - TaskId task_id = this->loop_->WatchFileDescriptor( - FROM_HERE, pipe.reader, MessageLoop::kWatchRead, false, - Bind(&Increment, &called)); - EXPECT_NE(MessageLoop::kTaskIdNull, task_id); - // We let the main loop run for 20 iterations but we just expect it to run - // at least once. The callback should be called exactly once since we - // scheduled it non-persistently. After it ran, we shouldn't be able to cancel - // this task. - EXPECT_LT(0, MessageLoopRunMaxIterations(this->loop_.get(), 20)); - EXPECT_EQ(1, called); - EXPECT_FALSE(this->loop_->CancelTask(task_id)); -} - -TYPED_TEST(MessageLoopTest, WatchFileDescriptorForReadAndWriteSimultaneously) { - ScopedSocketPair socks; - EXPECT_EQ(1, HANDLE_EINTR(write(socks.right, "a", 1))); - // socks.left should be able to read this "a" and should also be able to write - // without blocking since the kernel has some buffering for it. - - TaskId read_task_id = this->loop_->WatchFileDescriptor( - FROM_HERE, socks.left, MessageLoop::kWatchRead, true, - Bind([] (MessageLoop* loop, TaskId* read_task_id) { - EXPECT_TRUE(loop->CancelTask(*read_task_id)) - << "task_id" << *read_task_id; - }, this->loop_.get(), &read_task_id)); - EXPECT_NE(MessageLoop::kTaskIdNull, read_task_id); - - TaskId write_task_id = this->loop_->WatchFileDescriptor( - FROM_HERE, socks.left, MessageLoop::kWatchWrite, true, - Bind([] (MessageLoop* loop, TaskId* write_task_id) { - EXPECT_TRUE(loop->CancelTask(*write_task_id)); - }, this->loop_.get(), &write_task_id)); - EXPECT_NE(MessageLoop::kTaskIdNull, write_task_id); - - EXPECT_LT(0, MessageLoopRunMaxIterations(this->loop_.get(), 20)); - - EXPECT_FALSE(this->loop_->CancelTask(read_task_id)); - EXPECT_FALSE(this->loop_->CancelTask(write_task_id)); -} - -// Test that we can cancel the task we are running, and should just fail. -TYPED_TEST(MessageLoopTest, DeleteTaskFromSelf) { - bool cancel_result = true; // We would expect this to be false. - TaskId task_id; - task_id = this->loop_->PostTask( - FROM_HERE, - Bind([](bool* cancel_result, MessageLoop* loop, TaskId* task_id) { - *cancel_result = loop->CancelTask(*task_id); - }, &cancel_result, this->loop_.get(), &task_id)); - EXPECT_EQ(1, MessageLoopRunMaxIterations(this->loop_.get(), 100)); - EXPECT_FALSE(cancel_result); -} - -// Test that we can cancel a non-persistent file descriptor watching callback, -// which should fail. -TYPED_TEST(MessageLoopTest, DeleteNonPersistenIOTaskFromSelf) { - ScopedPipe pipe; - TaskId task_id = this->loop_->WatchFileDescriptor( - FROM_HERE, pipe.writer, MessageLoop::kWatchWrite, false /* persistent */, - Bind([](MessageLoop* loop, TaskId* task_id) { - EXPECT_FALSE(loop->CancelTask(*task_id)); - *task_id = MessageLoop::kTaskIdNull; - }, this->loop_.get(), &task_id)); - EXPECT_NE(MessageLoop::kTaskIdNull, task_id); - EXPECT_EQ(1, MessageLoopRunMaxIterations(this->loop_.get(), 100)); - EXPECT_EQ(MessageLoop::kTaskIdNull, task_id); -} - -// Test that we can cancel a persistent file descriptor watching callback from -// the same callback. -TYPED_TEST(MessageLoopTest, DeletePersistenIOTaskFromSelf) { - ScopedPipe pipe; - TaskId task_id = this->loop_->WatchFileDescriptor( - FROM_HERE, pipe.writer, MessageLoop::kWatchWrite, true /* persistent */, - Bind([](MessageLoop* loop, TaskId* task_id) { - EXPECT_TRUE(loop->CancelTask(*task_id)); - *task_id = MessageLoop::kTaskIdNull; - }, this->loop_.get(), &task_id)); - EXPECT_NE(MessageLoop::kTaskIdNull, task_id); - EXPECT_EQ(1, MessageLoopRunMaxIterations(this->loop_.get(), 100)); - EXPECT_EQ(MessageLoop::kTaskIdNull, task_id); -} - -// Test that we can cancel several persistent file descriptor watching callbacks -// from a scheduled callback. In the BaseMessageLoop implementation, this code -// will cause us to cancel an IOTask that has a pending delayed task, but -// otherwise is a valid test case on all implementations. -TYPED_TEST(MessageLoopTest, DeleteAllPersistenIOTaskFromSelf) { - const int kNumTasks = 5; - ScopedPipe pipes[kNumTasks]; - TaskId task_ids[kNumTasks]; - - for (int i = 0; i < kNumTasks; ++i) { - task_ids[i] = this->loop_->WatchFileDescriptor( - FROM_HERE, pipes[i].writer, MessageLoop::kWatchWrite, - true /* persistent */, - Bind([] (MessageLoop* loop, TaskId* task_ids) { - for (int j = 0; j < kNumTasks; ++j) { - // Once we cancel all the tasks, none should run, so this code runs - // only once from one callback. - EXPECT_TRUE(loop->CancelTask(task_ids[j])); - task_ids[j] = MessageLoop::kTaskIdNull; - } - }, this->loop_.get(), task_ids)); - } - MessageLoopRunMaxIterations(this->loop_.get(), 100); - for (int i = 0; i < kNumTasks; ++i) { - EXPECT_EQ(MessageLoop::kTaskIdNull, task_ids[i]); - } -} - -// Test that if there are several tasks watching for file descriptors to be -// available or simply waiting in the message loop are fairly scheduled to run. -// In other words, this test ensures that having a file descriptor always -// available doesn't prevent other file descriptors watching tasks or delayed -// tasks to be dispatched, causing starvation. -TYPED_TEST(MessageLoopTest, AllTasksAreEqual) { - int total_calls = 0; - - // First, schedule a repeating timeout callback to run from the main loop. - int timeout_called = 0; - base::Closure timeout_callback; - MessageLoop::TaskId timeout_task; - timeout_callback = base::Bind( - [](MessageLoop* loop, int* timeout_called, int* total_calls, - base::Closure* timeout_callback, MessageLoop::TaskId* timeout_task) { - (*timeout_called)++; - (*total_calls)++; - *timeout_task = loop->PostTask(FROM_HERE, *timeout_callback); - if (*total_calls > 100) - loop->BreakLoop(); - }, - this->loop_.get(), &timeout_called, &total_calls, &timeout_callback, - &timeout_task); - timeout_task = this->loop_->PostTask(FROM_HERE, timeout_callback); - - // Second, schedule several file descriptor watchers. - const int kNumTasks = 3; - ScopedPipe pipes[kNumTasks]; - MessageLoop::TaskId tasks[kNumTasks]; - - int reads[kNumTasks] = {}; - base::Callback<void(int)> fd_callback = base::Bind( - [](MessageLoop* loop, ScopedPipe* pipes, int* reads, - int* total_calls, int i) { - reads[i]++; - (*total_calls)++; - char c; - EXPECT_EQ(1, HANDLE_EINTR(read(pipes[i].reader, &c, 1))); - if (*total_calls > 100) - loop->BreakLoop(); - }, this->loop_.get(), pipes, reads, &total_calls); - - for (int i = 0; i < kNumTasks; ++i) { - tasks[i] = this->loop_->WatchFileDescriptor( - FROM_HERE, pipes[i].reader, MessageLoop::kWatchRead, - true /* persistent */, - Bind(fd_callback, i)); - // Make enough bytes available on each file descriptor. This should not - // block because we set the size of the file descriptor buffer when - // creating it. - std::vector<char> blob(1000, 'a'); - EXPECT_EQ(blob.size(), - HANDLE_EINTR(write(pipes[i].writer, blob.data(), blob.size()))); - } - this->loop_->Run(); - EXPECT_GT(total_calls, 100); - // We run the loop up 100 times and expect each callback to run at least 10 - // times. A good scheduler should balance these callbacks. - EXPECT_GE(timeout_called, 10); - EXPECT_TRUE(this->loop_->CancelTask(timeout_task)); - for (int i = 0; i < kNumTasks; ++i) { - EXPECT_GE(reads[i], 10) << "Reading from pipes[" << i << "], fd " - << pipes[i].reader; - EXPECT_TRUE(this->loop_->CancelTask(tasks[i])); - } -} - -} // namespace brillo diff --git a/brillo/message_loops/message_loop_utils.cc b/brillo/message_loops/message_loop_utils.cc index 9ebe865..0f3214b 100644 --- a/brillo/message_loops/message_loop_utils.cc +++ b/brillo/message_loops/message_loop_utils.cc @@ -4,20 +4,19 @@ #include <brillo/message_loops/message_loop_utils.h> +#include <base/bind.h> #include <base/location.h> -#include <brillo/bind_lambda.h> namespace brillo { -void MessageLoopRunUntil( - MessageLoop* loop, - base::TimeDelta timeout, - base::Callback<bool()> terminate) { +void MessageLoopRunUntil(MessageLoop* loop, + base::TimeDelta timeout, + base::RepeatingCallback<bool()> terminate) { bool timeout_called = false; MessageLoop::TaskId task_id = loop->PostDelayedTask( FROM_HERE, - base::Bind([](bool* timeout_called) { *timeout_called = true; }, - base::Unretained(&timeout_called)), + base::BindOnce([](bool* timeout_called) { *timeout_called = true; }, + &timeout_called), timeout); while (!timeout_called && (terminate.is_null() || !terminate.Run())) loop->RunOnce(true); diff --git a/brillo/message_loops/message_loop_utils.h b/brillo/message_loops/message_loop_utils.h index d49ebdf..7384ddb 100644 --- a/brillo/message_loops/message_loop_utils.h +++ b/brillo/message_loops/message_loop_utils.h @@ -18,7 +18,7 @@ namespace brillo { BRILLO_EXPORT void MessageLoopRunUntil( MessageLoop* loop, base::TimeDelta timeout, - base::Callback<bool()> terminate); + base::RepeatingCallback<bool()> terminate); // Run the MessageLoop |loop| for up to |iterations| times without blocking. // Return the number of tasks run. diff --git a/brillo/message_loops/mock_message_loop.h b/brillo/message_loops/mock_message_loop.h index 9f9a1e4..357ec24 100644 --- a/brillo/message_loops/mock_message_loop.h +++ b/brillo/message_loops/mock_message_loop.h @@ -37,17 +37,9 @@ class BRILLO_EXPORT MockMessageLoop : public MessageLoop { &fake_loop_, static_cast<TaskId(FakeMessageLoop::*)( const base::Location&, - const base::Closure&, + base::OnceClosure, base::TimeDelta)>( &FakeMessageLoop::PostDelayedTask))); - ON_CALL(*this, WatchFileDescriptor( - ::testing::_, ::testing::_, ::testing::_, ::testing::_, ::testing::_)) - .WillByDefault(::testing::Invoke( - &fake_loop_, - static_cast<TaskId(FakeMessageLoop::*)( - const base::Location&, int, WatchMode, bool, - const base::Closure&)>( - &FakeMessageLoop::WatchFileDescriptor))); ON_CALL(*this, CancelTask(::testing::_)) .WillByDefault(::testing::Invoke(&fake_loop_, &FakeMessageLoop::CancelTask)); @@ -57,20 +49,13 @@ class BRILLO_EXPORT MockMessageLoop : public MessageLoop { } ~MockMessageLoop() override = default; - MOCK_METHOD3(PostDelayedTask, - TaskId(const base::Location& from_here, - const base::Closure& task, - base::TimeDelta delay)); + MOCK_METHOD(TaskId, + PostDelayedTask, + (const base::Location&, base::OnceClosure, base::TimeDelta), + (override)); using MessageLoop::PostDelayedTask; - MOCK_METHOD5(WatchFileDescriptor, - TaskId(const base::Location& from_here, - int fd, - WatchMode mode, - bool persistent, - const base::Closure& task)); - using MessageLoop::WatchFileDescriptor; - MOCK_METHOD1(CancelTask, bool(TaskId task_id)); - MOCK_METHOD1(RunOnce, bool(bool may_block)); + MOCK_METHOD(bool, CancelTask, (TaskId), (override)); + MOCK_METHOD(bool, RunOnce, (bool), (override)); // Returns the actual FakeMessageLoop instance so default actions can be // override with other actions or call diff --git a/brillo/mime_utils_unittest.cc b/brillo/mime_utils_test.cc index a7595dc..a7595dc 100644 --- a/brillo/mime_utils_unittest.cc +++ b/brillo/mime_utils_test.cc diff --git a/brillo/minijail/minijail.cc b/brillo/minijail/minijail.cc index 305f073..9f88585 100644 --- a/brillo/minijail/minijail.cc +++ b/brillo/minijail/minijail.cc @@ -11,6 +11,9 @@ using std::vector; namespace brillo { +static base::LazyInstance<Minijail>::DestructorAtExit g_minijail + = LAZY_INSTANCE_INITIALIZER; + Minijail::Minijail() {} Minijail::~Minijail() {} @@ -65,6 +68,14 @@ void Minijail::ResetSignalMask(struct minijail* jail) { minijail_reset_signal_mask(jail); } +void Minijail::CloseOpenFds(struct minijail* jail) { + minijail_close_open_fds(jail); +} + +void Minijail::PreserveFd(struct minijail* jail, int parent_fd, int child_fd) { + minijail_preserve_fd(jail, parent_fd, child_fd); +} + void Minijail::Enter(struct minijail* jail) { minijail_enter(jail); } @@ -110,6 +121,23 @@ bool Minijail::RunPipes(struct minijail* jail, #endif // __ANDROID__ } +bool Minijail::RunEnvPipes(struct minijail* jail, + vector<char*> args, + vector<char*> env, + pid_t* pid, + int* stdin, + int* stdout, + int* stderr) { +#if defined(__ANDROID__) + return minijail_run_env_pid_pipes_no_preload(jail, args[0], args.data(), + env.data(), pid, stdin, stdout, + stderr) == 0; +#else + return minijail_run_env_pid_pipes(jail, args[0], args.data(), env.data(), pid, + stdin, stdout, stderr) == 0; +#endif // __ANDROID__ +} + bool Minijail::RunAndDestroy(struct minijail* jail, vector<char*> args, pid_t* pid) { @@ -146,4 +174,16 @@ bool Minijail::RunPipesAndDestroy(struct minijail* jail, return res; } +bool Minijail::RunEnvPipesAndDestroy(struct minijail* jail, + vector<char*> args, + vector<char*> env, + pid_t* pid, + int* stdin, + int* stdout, + int* stderr) { + bool res = RunEnvPipes(jail, args, env, pid, stdin, stdout, stderr); + Destroy(jail); + return res; +} + } // namespace brillo diff --git a/brillo/minijail/minijail.h b/brillo/minijail/minijail.h index 15167cf..6cdc7ad 100644 --- a/brillo/minijail/minijail.h +++ b/brillo/minijail/minijail.h @@ -12,6 +12,9 @@ extern "C" { #include <sys/types.h> } +#include <base/lazy_instance.h> +#include <brillo/brillo_export.h> + #include <libminijail.h> #include "base/macros.h" @@ -19,7 +22,7 @@ extern "C" { namespace brillo { // A Minijail abstraction allowing Minijail mocking in tests. -class Minijail { +class BRILLO_EXPORT Minijail { public: virtual ~Minijail(); @@ -55,6 +58,12 @@ class Minijail { // minijail_reset_signal_mask virtual void ResetSignalMask(struct minijail* jail); + // minijail_close_open_fds + virtual void CloseOpenFds(struct minijail* jail); + + // minijail_preserve_fd + virtual void PreserveFd(struct minijail* jail, int parent_fd, int child_fd); + // minijail_enter virtual void Enter(struct minijail* jail); @@ -80,6 +89,14 @@ class Minijail { int* stdout, int* stderr); + // minijail_run_env_pid_pipes + virtual bool RunEnvPipes(struct minijail* jail, + std::vector<char*> args, + std::vector<char*> env, + pid_t* pid, + int* stdin, + int* stdout, + int* stderr); // Run() and Destroy() virtual bool RunAndDestroy(struct minijail* jail, std::vector<char*> args, @@ -104,10 +121,21 @@ class Minijail { int* stdout, int* stderr); + // RunEnvPipes() and Destroy() + virtual bool RunEnvPipesAndDestroy(struct minijail* jail, + std::vector<char*> args, + std::vector<char*> env, + pid_t* pid, + int* stdin, + int* stdout, + int* stderr); + protected: Minijail(); private: + friend base::LazyInstanceTraitsBase<Minijail>; + DISALLOW_COPY_AND_ASSIGN(Minijail); }; diff --git a/brillo/minijail/mock_minijail.h b/brillo/minijail/mock_minijail.h index a855632..21e7ad7 100644 --- a/brillo/minijail/mock_minijail.h +++ b/brillo/minijail/mock_minijail.h @@ -19,45 +19,70 @@ class MockMinijail : public brillo::Minijail { MockMinijail() {} virtual ~MockMinijail() {} - MOCK_METHOD0(New, struct minijail*()); - MOCK_METHOD1(Destroy, void(struct minijail*)); + MOCK_METHOD(struct minijail*, New, (), (override)); + MOCK_METHOD(void, Destroy, (struct minijail*), (override)); - MOCK_METHOD3(DropRoot, - bool(struct minijail* jail, - const char* user, - const char* group)); - MOCK_METHOD2(UseSeccompFilter, void(struct minijail* jail, const char* path)); - MOCK_METHOD2(UseCapabilities, void(struct minijail* jail, uint64_t capmask)); - MOCK_METHOD1(ResetSignalMask, void(struct minijail* jail)); - MOCK_METHOD1(Enter, void(struct minijail* jail)); - MOCK_METHOD3(Run, - bool(struct minijail* jail, - std::vector<char*> args, - pid_t* pid)); - MOCK_METHOD3(RunSync, - bool(struct minijail* jail, - std::vector<char*> args, - int* status)); - MOCK_METHOD3(RunAndDestroy, - bool(struct minijail* jail, - std::vector<char*> args, - pid_t* pid)); - MOCK_METHOD3(RunSyncAndDestroy, - bool(struct minijail* jail, - std::vector<char*> args, - int* status)); - MOCK_METHOD4(RunPipeAndDestroy, - bool(struct minijail* jail, - std::vector<char*> args, - pid_t* pid, - int* stdin)); - MOCK_METHOD6(RunPipesAndDestroy, - bool(struct minijail* jail, - std::vector<char*> args, - pid_t* pid, - int* stdin, - int* stdout, - int* stderr)); + MOCK_METHOD(bool, + DropRoot, + (struct minijail*, const char*, const char*), + (override)); + MOCK_METHOD(void, + UseSeccompFilter, + (struct minijail*, const char*), + (override)); + MOCK_METHOD(void, UseCapabilities, (struct minijail*, uint64_t), (override)); + MOCK_METHOD(void, ResetSignalMask, (struct minijail*), (override)); + MOCK_METHOD(void, CloseOpenFds, (struct minijail*), (override)); + MOCK_METHOD(void, PreserveFd, (struct minijail*, int, int), (override)); + MOCK_METHOD(void, Enter, (struct minijail*), (override)); + MOCK_METHOD(bool, + Run, + (struct minijail*, std::vector<char*>, pid_t*), + (override)); + MOCK_METHOD(bool, + RunSync, + (struct minijail*, std::vector<char*>, int*), + (override)); + MOCK_METHOD(bool, + RunPipes, + (struct minijail*, std::vector<char*>, pid_t*, int*, int*, int*), + (override)); + MOCK_METHOD(bool, + RunEnvPipes, + (struct minijail*, + std::vector<char*>, + std::vector<char*>, + pid_t*, + int*, + int*, + int*), + (override)); + MOCK_METHOD(bool, + RunAndDestroy, + (struct minijail*, std::vector<char*>, pid_t*), + (override)); + MOCK_METHOD(bool, + RunSyncAndDestroy, + (struct minijail*, std::vector<char*>, int*), + (override)); + MOCK_METHOD(bool, + RunPipeAndDestroy, + (struct minijail*, std::vector<char*>, pid_t*, int*), + (override)); + MOCK_METHOD(bool, + RunPipesAndDestroy, + (struct minijail*, std::vector<char*>, pid_t*, int*, int*, int*), + (override)); + MOCK_METHOD(bool, + RunEnvPipesAndDestroy, + (struct minijail*, + std::vector<char*>, + std::vector<char*>, + pid_t*, + int*, + int*, + int*), + (override)); private: DISALLOW_COPY_AND_ASSIGN(MockMinijail); diff --git a/brillo/namespaces/OWNERS b/brillo/namespaces/OWNERS new file mode 100644 index 0000000..e96f211 --- /dev/null +++ b/brillo/namespaces/OWNERS @@ -0,0 +1,2 @@ +betuls@chromium.org +jorgelo@chromium.org diff --git a/brillo/namespaces/mock_platform.h b/brillo/namespaces/mock_platform.h new file mode 100644 index 0000000..1b96b46 --- /dev/null +++ b/brillo/namespaces/mock_platform.h @@ -0,0 +1,37 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_NAMESPACES_MOCK_PLATFORM_H_ +#define LIBBRILLO_BRILLO_NAMESPACES_MOCK_PLATFORM_H_ + +#include "brillo/namespaces/platform.h" + +#include <string> + +#include <base/files/file_path.h> +#include <gmock/gmock.h> + +namespace brillo { + +class MockPlatform : public Platform { + public: + MockPlatform() {} + virtual ~MockPlatform() {} + + MOCK_METHOD(bool, Unmount, (const base::FilePath&, bool, bool*), (override)); + MOCK_METHOD(pid_t, Fork, (), (override)); + MOCK_METHOD(pid_t, Waitpid, (pid_t, int*), (override)); + MOCK_METHOD(int, + Mount, + (const std::string&, + const std::string&, + const std::string&, + uint64_t, + const void*), + (override)); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_NAMESPACES_MOCK_PLATFORM_H_ diff --git a/brillo/namespaces/mount_namespace.cc b/brillo/namespaces/mount_namespace.cc new file mode 100644 index 0000000..1944983 --- /dev/null +++ b/brillo/namespaces/mount_namespace.cc @@ -0,0 +1,114 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Contains the implementation of class MountNamespace for libbrillo. + +#include "brillo/namespaces/mount_namespace.h" + +#include <sched.h> +#include <sys/mount.h> +#include <sys/types.h> + +#include <string> + +#include <base/files/file_path.h> +#include <base/files/file_util.h> +#include <base/logging.h> +#include <base/strings/stringprintf.h> +#include <brillo/namespaces/platform.h> + +namespace brillo { +MountNamespace::MountNamespace(const base::FilePath& ns_path, + Platform* platform) + : ns_path_(ns_path), platform_(platform), exists_(false) {} + +MountNamespace::~MountNamespace() { + if (exists_) + Destroy(); +} + +bool MountNamespace::Create() { + if (platform_->FileSystemIsNsfs(ns_path_)) { + LOG(ERROR) << "Mount namespace at " << ns_path_.value() + << " already exists."; + return false; + } + int fd_mounted[2]; + int fd_unshared[2]; + char byte = '\0'; + if (pipe(fd_mounted) != 0) { + PLOG(ERROR) << "Cannot create mount signalling pipe"; + return false; + } + if (pipe(fd_unshared) != 0) { + PLOG(ERROR) << "Cannot create unshare signalling pipe"; + return false; + } + pid_t pid = platform_->Fork(); + if (pid < 0) { + PLOG(ERROR) << "Fork failed"; + } else if (pid == 0) { + // Child. + close(fd_mounted[1]); + close(fd_unshared[0]); + if (unshare(CLONE_NEWNS) != 0) { + PLOG(ERROR) << "unshare(CLONE_NEWNS) failed"; + exit(1); + } + base::WriteFileDescriptor(fd_unshared[1], &byte, 1); + base::ReadFromFD(fd_mounted[0], &byte, 1); + exit(0); + } else { + // Parent. + close(fd_mounted[0]); + close(fd_unshared[1]); + std::string proc_ns_path = base::StringPrintf("/proc/%d/ns/mnt", pid); + bool mount_success = true; + base::ReadFromFD(fd_unshared[0], &byte, 1); + if (platform_->Mount(proc_ns_path, ns_path_.value(), "", MS_BIND) != 0) { + PLOG(ERROR) << "Mount(" << proc_ns_path << ", " << ns_path_.value() + << ", MS_BIND) failed"; + mount_success = false; + } + base::WriteFileDescriptor(fd_mounted[1], &byte, 1); + + int status; + if (platform_->Waitpid(pid, &status) < 0) { + PLOG(ERROR) << "waitpid(" << pid << ") failed"; + return false; + } + if (!WIFEXITED(status)) { + LOG(ERROR) << "Child process did not exit normally."; + } else if (WEXITSTATUS(status) != 0) { + LOG(ERROR) << "Child process failed."; + } else { + exists_ = mount_success; + } + } + return exists_; +} + +bool MountNamespace::Destroy() { + if (!exists_) { + LOG(ERROR) << "Mount namespace at " << ns_path_.value() + << "does not exist, cannot destroy"; + return false; + } + bool was_busy; + if (!platform_->Unmount(ns_path_, false /*lazy*/, &was_busy)) { + PLOG(ERROR) << "Failed to unmount " << ns_path_.value(); + if (was_busy) { + LOG(ERROR) << ns_path_.value().c_str() << " was busy"; + } + // If Unmount() fails, keep the object valid by keeping |exists_| + // set to true. + return false; + } else { + VLOG(1) << "Unmounted namespace at " << ns_path_.value(); + } + exists_ = false; + return true; +} + +} // namespace brillo diff --git a/brillo/namespaces/mount_namespace.h b/brillo/namespaces/mount_namespace.h new file mode 100644 index 0000000..bfadff0 --- /dev/null +++ b/brillo/namespaces/mount_namespace.h @@ -0,0 +1,70 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_NAMESPACES_MOUNT_NAMESPACE_H_ +#define LIBBRILLO_BRILLO_NAMESPACES_MOUNT_NAMESPACE_H_ + +#include "brillo/namespaces/platform.h" + +#include <base/files/file_path.h> +#include <base/macros.h> +#include <brillo/brillo_export.h> + +namespace brillo { + +class BRILLO_EXPORT MountNamespaceInterface { + // An interface declaring the basic functionality of a mount namespace bound + // to a specific path. This basic functionality consists of reporting the + // namespace path. + public: + virtual ~MountNamespaceInterface() = default; + + virtual const base::FilePath& path() const = 0; +}; + +class BRILLO_EXPORT UnownedMountNamespace : public MountNamespaceInterface { + // A class to store and retrieve the path of a persistent namespace. This + // class doesn't create nor destroy the namespace. + public: + explicit UnownedMountNamespace(const base::FilePath& ns_path) + : ns_path_(ns_path) {} + + ~UnownedMountNamespace() override; + + const base::FilePath& path() const override { return ns_path_; } + + private: + base::FilePath ns_path_; + + DISALLOW_COPY_AND_ASSIGN(UnownedMountNamespace); +}; + +class BRILLO_EXPORT MountNamespace : public MountNamespaceInterface { + // A class to create a persistent mount namespace bound to a specific path. + // A new mount namespace is unshared from the mount namespace of the calling + // process when Create() is called; the namespace of the calling process + // remains unchanged. Recurring creation on a path is not allowed. + // + // Given that we cannot ensure that creation always succeeds this class is not + // fully RAII, but once the namespace is created (with Create()), it will be + // destroyed when the object goes out of scope. + public: + MountNamespace(const base::FilePath& ns_path, Platform* platform); + ~MountNamespace() override; + + bool Create(); + bool Destroy(); + const base::FilePath& path() const override { return ns_path_; } + + private: + base::FilePath ns_path_; + Platform* platform_; + bool exists_; + + DISALLOW_COPY_AND_ASSIGN(MountNamespace); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_NAMESPACES_MOUNT_NAMESPACE_H_ diff --git a/brillo/namespaces/mount_namespace_test.cc b/brillo/namespaces/mount_namespace_test.cc new file mode 100644 index 0000000..1bfa038 --- /dev/null +++ b/brillo/namespaces/mount_namespace_test.cc @@ -0,0 +1,92 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/namespaces/mock_platform.h" +#include "brillo/namespaces/mount_namespace.h" +#include "brillo/namespaces/platform.h" + +#include <unistd.h> + +#include <memory> + +#include <base/files/file_path.h> +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +using ::testing::_; +using ::testing::DoAll; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SetArgPointee; + +namespace brillo { + +class MountNamespaceTest : public ::testing::Test { + public: + MountNamespaceTest() {} + ~MountNamespaceTest() {} + void SetUp() {} + + void TearDown() {} + + protected: + NiceMock<MockPlatform> platform_; + + private: + DISALLOW_COPY_AND_ASSIGN(MountNamespaceTest); +}; + +TEST_F(MountNamespaceTest, CreateNamespace) { + std::unique_ptr<MountNamespace> ns = + std::make_unique<MountNamespace>(base::FilePath(), &platform_); + EXPECT_CALL(platform_, Fork()).WillOnce(Return(1)); + EXPECT_CALL(platform_, Mount(_, _, _, _, _)).WillOnce(Return(0)); + EXPECT_CALL(platform_, Waitpid(_, _)) + .WillOnce(DoAll(SetArgPointee<1>(0x00000000), Return(0))); + EXPECT_TRUE(ns->Create()); + EXPECT_CALL(platform_, Unmount(ns->path(), _, _)).WillOnce(Return(true)); +} + +TEST_F(MountNamespaceTest, CreateNamespaceFailedOnWaitpid) { + std::unique_ptr<MountNamespace> ns = + std::make_unique<MountNamespace>(base::FilePath(), &platform_); + EXPECT_CALL(platform_, Fork()).WillOnce(Return(1)); + EXPECT_CALL(platform_, Mount(_, _, _, _, _)).WillOnce(Return(0)); + EXPECT_CALL(platform_, Waitpid(_, _)).WillOnce(Return(-1)); + EXPECT_FALSE(ns->Create()); +} + +TEST_F(MountNamespaceTest, CreateNamespaceFailedOnMount) { + std::unique_ptr<MountNamespace> ns = + std::make_unique<MountNamespace>(base::FilePath(), &platform_); + EXPECT_CALL(platform_, Fork()).WillOnce(Return(1)); + EXPECT_CALL(platform_, Mount(_, _, _, _, _)).WillOnce(Return(-1)); + EXPECT_FALSE(ns->Create()); +} + +TEST_F(MountNamespaceTest, CreateNamespaceFailedOnStatus) { + std::unique_ptr<MountNamespace> ns = + std::make_unique<MountNamespace>(base::FilePath(), &platform_); + EXPECT_CALL(platform_, Fork()).WillOnce(Return(1)); + EXPECT_CALL(platform_, Mount(_, _, _, _, _)).WillOnce(Return(0)); + EXPECT_CALL(platform_, Waitpid(_, _)) + .WillOnce(DoAll(SetArgPointee<1>(0xFFFFFFFF), Return(0))); + EXPECT_FALSE(ns->Create()); +} + +TEST_F(MountNamespaceTest, DestroyAfterUnmountFailsAndUnmountSucceeds) { + std::unique_ptr<MountNamespace> ns = + std::make_unique<MountNamespace>(base::FilePath(), &platform_); + EXPECT_CALL(platform_, Fork()).WillOnce(Return(1)); + EXPECT_CALL(platform_, Mount(_, _, _, _, _)).WillOnce(Return(0)); + EXPECT_CALL(platform_, Waitpid(_, _)) + .WillOnce(DoAll(SetArgPointee<1>(0x00000000), Return(0))); + EXPECT_TRUE(ns->Create()); + EXPECT_CALL(platform_, Unmount(ns->path(), _, _)).WillOnce(Return(false)); + EXPECT_FALSE(ns->Destroy()); + EXPECT_CALL(platform_, Unmount(ns->path(), _, _)).WillOnce(Return(true)); + EXPECT_TRUE(ns->Destroy()); +} + +} // namespace brillo diff --git a/brillo/namespaces/platform.cc b/brillo/namespaces/platform.cc new file mode 100644 index 0000000..5fe9140 --- /dev/null +++ b/brillo/namespaces/platform.cc @@ -0,0 +1,75 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Contains the implementation of class Platform for libbrillo. + +#include "brillo/namespaces/platform.h" + +#include <errno.h> +#include <linux/magic.h> +#include <stdio.h> +#include <sys/mount.h> +#include <sys/stat.h> +#include <sys/vfs.h> +#include <sys/wait.h> + +#include <memory> + +#include <base/files/file_path.h> + +using base::FilePath; + +namespace brillo { + +Platform::Platform() {} + +Platform::~Platform() {} + +bool Platform::FileSystemIsNsfs(const FilePath& ns_path) { + struct statfs buff; + if (statfs(ns_path.value().c_str(), &buff) < 0) { + PLOG(ERROR) << "Statfs() error for " << ns_path.value(); + return false; + } + if ((uint64_t)buff.f_type == NSFS_MAGIC) { + return true; + } + return false; +} + +bool Platform::Unmount(const FilePath& path, bool lazy, bool* was_busy) { + int flags = 0; + if (lazy) { + flags = MNT_DETACH; + } + if (umount2(path.value().c_str(), flags) != 0) { + if (was_busy) { + *was_busy = (errno == EBUSY); + } + return false; + } + if (was_busy) { + *was_busy = false; + } + return true; +} + +int Platform::Mount(const std::string& source, + const std::string& target, + const std::string& fs_type, + uint64_t mount_flags, + const void* data) { + return mount(source.c_str(), target.c_str(), fs_type.c_str(), mount_flags, + data); +} + +pid_t Platform::Fork() { + return fork(); +} + +pid_t Platform::Waitpid(pid_t pid, int* status) { + return waitpid(pid, status, 0); +} + +} // namespace brillo diff --git a/brillo/namespaces/platform.h b/brillo/namespaces/platform.h new file mode 100644 index 0000000..6ef6a73 --- /dev/null +++ b/brillo/namespaces/platform.h @@ -0,0 +1,71 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_NAMESPACES_PLATFORM_H_ +#define LIBBRILLO_BRILLO_NAMESPACES_PLATFORM_H_ + +#include <sys/types.h> + +#include <memory> +#include <string> + +#include <base/files/file_path.h> +#include <base/macros.h> +#include <brillo/brillo_export.h> + +namespace brillo { +// Platform specific routines abstraction layer. +// Also helps us to be able to mock them in tests. +class BRILLO_EXPORT Platform { + public: + Platform(); + + virtual ~Platform(); + // Calls the platform fork() function and returns the pid returned + // by fork(). + virtual pid_t Fork(); + + // Calls the platform unmount. + // + // Parameters + // path - The path to unmount + // lazy - Whether to call a lazy unmount + // was_busy (OUT) - Set to true on return if the mount point was busy + virtual bool Unmount(const base::FilePath& path, bool lazy, bool* was_busy); + + // Calls the platform mount. + // + // Parameters + // source - The path to mount from + // target - The path to mount to + // fs_type - File system type of the mount + // mount_flags - Flags spesifying the type of the mount operation + // data - Mount options + virtual int Mount(const std::string& source, + const std::string& target, + const std::string& fs_type, + uint64_t mount_flags, + const void* = nullptr); + + // Checks the file system type of the |path| and returns true if the + // filesystem type is nsfs. + // + // Parameters + // path - The path to check the file system type + virtual bool FileSystemIsNsfs(const base::FilePath& path); + + // Calls the platform waitpid() function and returns the value returned by + // waitpid(). + // + // Parameters + // pid - The child pid to be waited on + // status (OUT)- Termination status of a child process. + virtual pid_t Waitpid(pid_t pid, int* status); + + DISALLOW_COPY_AND_ASSIGN(Platform); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_NAMESPACES_PLATFORM_H_ diff --git a/brillo/osrelease_reader.cc b/brillo/osrelease_reader.cc index c8f660e..7e533c0 100644 --- a/brillo/osrelease_reader.cc +++ b/brillo/osrelease_reader.cc @@ -52,4 +52,9 @@ void OsReleaseReader::Load(const base::FilePath& root_dir) { initialized_ = true; } +std::vector<std::string> OsReleaseReader::GetKeys() const { + CHECK(initialized_) << "OsReleaseReader.Load() must be called first."; + return store_.GetKeys(); +} + } // namespace brillo diff --git a/brillo/osrelease_reader.h b/brillo/osrelease_reader.h index f29c14d..372d3a1 100644 --- a/brillo/osrelease_reader.h +++ b/brillo/osrelease_reader.h @@ -10,6 +10,7 @@ #define LIBBRILLO_BRILLO_OSRELEASE_READER_H_ #include <string> +#include <vector> #include <brillo/brillo_export.h> #include <brillo/key_value_store.h> @@ -36,6 +37,9 @@ class BRILLO_EXPORT OsReleaseReader final { // Getter for the given key. Returns whether the key was found on the store. bool GetString(const std::string& key, std::string* value) const; + // Getter for all the keys in /etc/os-release. + std::vector<std::string> GetKeys() const; + private: // The map storing all the key-value pairs. KeyValueStore store_; diff --git a/brillo/osrelease_reader_unittest.cc b/brillo/osrelease_reader_test.cc index 9381367..9381367 100644 --- a/brillo/osrelease_reader_unittest.cc +++ b/brillo/osrelease_reader_test.cc diff --git a/brillo/process.cc b/brillo/process.cc index ead6f20..54e91f0 100644 --- a/brillo/process.cc +++ b/brillo/process.cc @@ -195,7 +195,7 @@ void ProcessImpl::CloseUnusedFileDescriptors() { // Since we're just trying to close anything we can find, // ignore any error return values of close(). IGNORE_EINTR(close(fd)); - } + } } bool ProcessImpl::Start() { @@ -309,7 +309,7 @@ bool ProcessImpl::Start() { } else { execv(argv[0], &argv[0]); } - PLOG(ERROR) << "Exec of " << argv[0] << " failed:"; + PLOG(ERROR) << "Exec of " << argv[0] << " failed"; _exit(kErrorExitStatus); } else { // Still executing inside the parent process with known child pid. diff --git a/brillo/process_information.h b/brillo/process_information.h index 3f0a2c9..13134bd 100644 --- a/brillo/process_information.h +++ b/brillo/process_information.h @@ -31,8 +31,8 @@ class BRILLO_EXPORT ProcessInformation { const std::vector<std::string>& get_cmd_line() { return cmd_line_; } - // Set the command line array. This method DOES swap out the contents of - // |value|. The caller should expect an empty set on return. + // Set the collection of open files. This method DOES swap out the contents + // of |value|. The caller should expect an empty set on return. void set_open_files(std::set<std::string>* value) { open_files_.clear(); open_files_.swap(*value); @@ -40,8 +40,8 @@ class BRILLO_EXPORT ProcessInformation { const std::set<std::string>& get_open_files() { return open_files_; } - // Set the command line array. This method DOES swap out the contents of - // |value|. The caller should expect an empty string on return. + // Set the current working directory. This method DOES swap out the contents + // of |value|. The caller should expect an empty string on return. void set_cwd(std::string* value) { cwd_.clear(); cwd_.swap(*value); diff --git a/brillo/process_mock.h b/brillo/process_mock.h index 92ffa0a..cc33681 100644 --- a/brillo/process_mock.h +++ b/brillo/process_mock.h @@ -19,29 +19,29 @@ class ProcessMock : public Process { ProcessMock() {} virtual ~ProcessMock() {} - MOCK_METHOD1(AddArg, void(const std::string& arg)); - MOCK_METHOD1(RedirectInput, void(const std::string& input_file)); - MOCK_METHOD1(RedirectOutput, void(const std::string& output_file)); - MOCK_METHOD2(RedirectUsingPipe, void(int child_fd, bool is_input)); - MOCK_METHOD2(BindFd, void(int parent_fd, int child_fd)); - MOCK_METHOD1(SetUid, void(uid_t)); - MOCK_METHOD1(SetGid, void(gid_t)); - MOCK_METHOD1(SetCapabilities, void(uint64_t capmask)); - MOCK_METHOD1(ApplySyscallFilter, void(const std::string& path)); - MOCK_METHOD0(EnterNewPidNamespace, void()); - MOCK_METHOD1(SetInheritParentSignalMask, void(bool)); - MOCK_METHOD1(SetPreExecCallback, void(const PreExecCallback&)); - MOCK_METHOD1(SetSearchPath, void(bool)); - MOCK_METHOD1(GetPipe, int(int child_fd)); - MOCK_METHOD0(Start, bool()); - MOCK_METHOD0(Wait, int()); - MOCK_METHOD0(Run, int()); - MOCK_METHOD0(pid, pid_t()); - MOCK_METHOD2(Kill, bool(int signal, int timeout)); - MOCK_METHOD1(Reset, void(pid_t)); - MOCK_METHOD1(ResetPidByFile, bool(const std::string& pid_file)); - MOCK_METHOD0(Release, pid_t()); - MOCK_METHOD1(SetCloseUnusedFileDescriptors, void(bool close_unused_fds)); + MOCK_METHOD(void, AddArg, (const std::string&), (override)); + MOCK_METHOD(void, RedirectInput, (const std::string&), (override)); + MOCK_METHOD(void, RedirectOutput, (const std::string&), (override)); + MOCK_METHOD(void, RedirectUsingPipe, (int, bool), (override)); + MOCK_METHOD(void, BindFd, (int, int), (override)); + MOCK_METHOD(void, SetUid, (uid_t), (override)); + MOCK_METHOD(void, SetGid, (gid_t), (override)); + MOCK_METHOD(void, SetCapabilities, (uint64_t), (override)); + MOCK_METHOD(void, ApplySyscallFilter, (const std::string&), (override)); + MOCK_METHOD(void, EnterNewPidNamespace, (), (override)); + MOCK_METHOD(void, SetInheritParentSignalMask, (bool), (override)); + MOCK_METHOD(void, SetPreExecCallback, (const PreExecCallback&), (override)); + MOCK_METHOD(void, SetSearchPath, (bool), (override)); + MOCK_METHOD(int, GetPipe, (int), (override)); + MOCK_METHOD(bool, Start, (), (override)); + MOCK_METHOD(int, Wait, (), (override)); + MOCK_METHOD(int, Run, (), (override)); + MOCK_METHOD(pid_t, pid, (), (override)); + MOCK_METHOD(bool, Kill, (int, int), (override)); + MOCK_METHOD(void, Reset, (pid_t), (override)); + MOCK_METHOD(bool, ResetPidByFile, (const std::string&), (override)); + MOCK_METHOD(pid_t, Release, (), (override)); + MOCK_METHOD(void, SetCloseUnusedFileDescriptors, (bool), (override)); }; } // namespace brillo diff --git a/brillo/process_reaper.cc b/brillo/process_reaper.cc index 0da3b5d..82e3f56 100644 --- a/brillo/process_reaper.cc +++ b/brillo/process_reaper.cc @@ -8,6 +8,8 @@ #include <sys/types.h> #include <sys/wait.h> +#include <utility> + #include <base/bind.h> #include <base/posix/eintr_wrapper.h> #include <brillo/asynchronous_signal_handler.h> @@ -37,10 +39,11 @@ void ProcessReaper::Unregister() { bool ProcessReaper::WatchForChild(const base::Location& from_here, pid_t pid, - const ChildCallback& callback) { + ChildCallback callback) { if (watched_processes_.find(pid) != watched_processes_.end()) return false; - watched_processes_.emplace(pid, WatchedProcess{from_here, callback}); + watched_processes_.emplace( + pid, WatchedProcess{from_here, std::move(callback)}); return true; } @@ -79,7 +82,7 @@ bool ProcessReaper::HandleSIGCHLD( << info.si_status << " (code = " << info.si_code << ")"; ChildCallback callback = std::move(proc->second.callback); watched_processes_.erase(proc); - callback.Run(info); + std::move(callback).Run(info); } } diff --git a/brillo/process_reaper.h b/brillo/process_reaper.h index 7b70a8d..4e348a3 100644 --- a/brillo/process_reaper.h +++ b/brillo/process_reaper.h @@ -19,7 +19,7 @@ namespace brillo { class BRILLO_EXPORT ProcessReaper final { public: // The callback called when a child exits. - using ChildCallback = base::Callback<void(const siginfo_t&)>; + using ChildCallback = base::OnceCallback<void(const siginfo_t&)>; ProcessReaper() = default; ~ProcessReaper(); @@ -41,7 +41,7 @@ class BRILLO_EXPORT ProcessReaper final { // as a siginfo_t. See wait(2) for details about siginfo_t. bool WatchForChild(const base::Location& from_here, pid_t pid, - const ChildCallback& callback); + ChildCallback callback); // Stop watching child process |pid|. This is useful in situations // where the child process may have been reaped outside of the signal diff --git a/brillo/process_reaper_unittest.cc b/brillo/process_reaper_test.cc index 98498f7..7b68236 100644 --- a/brillo/process_reaper_unittest.cc +++ b/brillo/process_reaper_test.cc @@ -12,7 +12,6 @@ #include <base/location.h> #include <base/message_loop/message_loop.h> #include <brillo/asynchronous_signal_handler.h> -#include <brillo/bind_lambda.h> #include <brillo/message_loops/base_message_loop.h> #include <gtest/gtest.h> @@ -74,7 +73,7 @@ TEST_F(ProcessReaperTest, UnregisterAndReregister) { TEST_F(ProcessReaperTest, ReapExitedChild) { pid_t pid = ForkChildAndExit(123); - EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::Bind( + EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::BindOnce( [](MessageLoop* loop, const siginfo_t& info) { EXPECT_EQ(CLD_EXITED, info.si_code); EXPECT_EQ(123, info.si_status); @@ -91,7 +90,7 @@ TEST_F(ProcessReaperTest, ReapedChildrenMatchCallbacks) { // Different processes will have different exit values. int exit_value = 1 + i; pid_t pid = ForkChildAndExit(exit_value); - EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::Bind( + EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::BindOnce( [](MessageLoop* loop, int exit_value, int* running_children, const siginfo_t& info) { EXPECT_EQ(CLD_EXITED, info.si_code); @@ -110,7 +109,7 @@ TEST_F(ProcessReaperTest, ReapedChildrenMatchCallbacks) { TEST_F(ProcessReaperTest, ReapKilledChild) { pid_t pid = ForkChildAndKill(SIGKILL); - EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::Bind( + EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::BindOnce( [](MessageLoop* loop, const siginfo_t& info) { EXPECT_EQ(CLD_KILLED, info.si_code); EXPECT_EQ(SIGKILL, info.si_status); @@ -121,7 +120,7 @@ TEST_F(ProcessReaperTest, ReapKilledChild) { TEST_F(ProcessReaperTest, ReapKilledAndForgottenChild) { pid_t pid = ForkChildAndExit(0); - EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::Bind( + EXPECT_TRUE(process_reaper_.WatchForChild(FROM_HERE, pid, base::BindOnce( [](MessageLoop* loop, const siginfo_t& /* info */) { ADD_FAILURE() << "Child process was still tracked."; loop->BreakLoop(); diff --git a/brillo/process_unittest.cc b/brillo/process_test.cc index f65cf34..533a8f0 100644 --- a/brillo/process_unittest.cc +++ b/brillo/process_test.cc @@ -12,8 +12,8 @@ #include <gtest/gtest.h> #include "brillo/process_mock.h" -#include "brillo/unittest_utils.h" #include "brillo/test_helpers.h" +#include "brillo/unittest_utils.h" using base::FilePath; diff --git a/brillo/proto_file_io.cc b/brillo/proto_file_io.cc new file mode 100644 index 0000000..47f3413 --- /dev/null +++ b/brillo/proto_file_io.cc @@ -0,0 +1,40 @@ +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/proto_file_io.h" + +#include <utility> + +#include <base/files/file.h> +#include <base/macros.h> +#include <google/protobuf/io/zero_copy_stream_impl.h> +#include <google/protobuf/text_format.h> + +namespace brillo { + +bool ReadTextProtobuf(const base::FilePath& proto_file, + google::protobuf::Message* out_proto) { + DCHECK(out_proto); + + base::File file(proto_file, base::File::FLAG_OPEN | base::File::FLAG_READ); + if (!file.IsValid()) { + DLOG(ERROR) << "Could not open \"" << proto_file.value() + << "\": " << base::File::ErrorToString(file.error_details()); + return false; + } + + return ReadTextProtobuf(file.GetPlatformFile(), out_proto); +} + +bool ReadTextProtobuf(int fd, google::protobuf::Message* out_proto) { + google::protobuf::io::FileInputStream input_stream(fd); + return google::protobuf::TextFormat::Parse(&input_stream, out_proto); +} + +bool WriteTextProtobuf(int fd, const google::protobuf::Message& proto) { + google::protobuf::io::FileOutputStream output_stream(fd); + return google::protobuf::TextFormat::Print(proto, &output_stream); +} + +} // namespace brillo diff --git a/brillo/proto_file_io.h b/brillo/proto_file_io.h new file mode 100644 index 0000000..77051cc --- /dev/null +++ b/brillo/proto_file_io.h @@ -0,0 +1,29 @@ +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_PROTO_FILE_IO_H_ +#define LIBBRILLO_BRILLO_PROTO_FILE_IO_H_ + +#include <base/files/file_path.h> +#include <brillo/brillo_export.h> +#include <google/protobuf/message.h> + +namespace brillo { + +// Simple utilities for serializing and deserializing protobufs in +// text format. For an example of the format, see the docs at +// https://developers.google.com/protocol-buffers/docs/overview#whynotxml + +BRILLO_EXPORT bool ReadTextProtobuf(const base::FilePath& proto_file, + google::protobuf::Message* out_proto); + +BRILLO_EXPORT bool ReadTextProtobuf(int fd, + google::protobuf::Message* out_proto); + +BRILLO_EXPORT bool WriteTextProtobuf(int fd, + const google::protobuf::Message& proto); + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_PROTO_FILE_IO_H_ diff --git a/brillo/scoped_mount_namespace.cc b/brillo/scoped_mount_namespace.cc new file mode 100644 index 0000000..09f3c75 --- /dev/null +++ b/brillo/scoped_mount_namespace.cc @@ -0,0 +1,66 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/scoped_mount_namespace.h" + +#include <fcntl.h> +#include <sched.h> +#include <sys/stat.h> +#include <sys/types.h> + +#include <string> +#include <utility> + +#include <base/posix/eintr_wrapper.h> +#include <base/strings/stringprintf.h> + +namespace { +constexpr char kCurrentMountNamespacePath[] = "/proc/self/ns/mnt"; +} // anonymous namespace + +namespace brillo { + +ScopedMountNamespace::ScopedMountNamespace(base::ScopedFD mount_namespace_fd) + : mount_namespace_fd_(std::move(mount_namespace_fd)) {} + +ScopedMountNamespace::~ScopedMountNamespace() { + PLOG_IF(ERROR, setns(mount_namespace_fd_.get(), CLONE_NEWNS) != 0) + << "Ignoring failure to restore original mount namespace"; +} + +// static +std::unique_ptr<ScopedMountNamespace> ScopedMountNamespace::CreateForPid( + pid_t pid) { + std::string ns_path = base::StringPrintf("/proc/%d/ns/mnt", pid); + return CreateFromPath(base::FilePath(ns_path)); +} + +// static +std::unique_ptr<ScopedMountNamespace> ScopedMountNamespace::CreateFromPath( + const base::FilePath& ns_path) { + base::ScopedFD original_mount_namespace_fd( + HANDLE_EINTR(open(kCurrentMountNamespacePath, O_RDONLY))); + if (!original_mount_namespace_fd.is_valid()) { + PLOG(ERROR) << "Failed to open original mount namespace FD at " + << kCurrentMountNamespacePath; + return nullptr; + } + + base::ScopedFD mount_namespace_fd( + HANDLE_EINTR(open(ns_path.value().c_str(), O_RDONLY))); + if (!mount_namespace_fd.is_valid()) { + PLOG(ERROR) << "Failed to open mount namespace FD at " << ns_path.value(); + return nullptr; + } + + if (setns(mount_namespace_fd.get(), CLONE_NEWNS) != 0) { + PLOG(ERROR) << "Failed to enter mount namespace at " << ns_path.value(); + return nullptr; + } + + return std::make_unique<ScopedMountNamespace>( + std::move(original_mount_namespace_fd)); +} + +} // namespace brillo diff --git a/brillo/scoped_mount_namespace.h b/brillo/scoped_mount_namespace.h new file mode 100644 index 0000000..e029d78 --- /dev/null +++ b/brillo/scoped_mount_namespace.h @@ -0,0 +1,44 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_SCOPED_MOUNT_NAMESPACE_H_ +#define LIBBRILLO_BRILLO_SCOPED_MOUNT_NAMESPACE_H_ + +#include <memory> + +#include <base/macros.h> +#include <base/files/file_path.h> +#include <base/files/scoped_file.h> + +#include <brillo/brillo_export.h> + +namespace brillo { + +// A class that restores a mount namespace when it goes out of scope. This can +// be done by entering another process' mount namespace by using +// CreateForPid(), or by supplying a mount namespace FD directly. +class BRILLO_EXPORT ScopedMountNamespace { + public: + // Enters the process identified by |pid|'s mount namespace and returns a + // unique_ptr that restores the original mount namespace when it goes out of + // scope. + static std::unique_ptr<ScopedMountNamespace> CreateForPid(pid_t pid); + + // Enters the mount namespace identified by |path| and returns a unique_ptr + // that restores the original mount namespace when it goes out of scope. + static std::unique_ptr<ScopedMountNamespace> CreateFromPath( + const base::FilePath& ns_path); + + explicit ScopedMountNamespace(base::ScopedFD mount_namespace_fd); + ~ScopedMountNamespace(); + + private: + base::ScopedFD mount_namespace_fd_; + + DISALLOW_COPY_AND_ASSIGN(ScopedMountNamespace); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_SCOPED_MOUNT_NAMESPACE_H_ diff --git a/brillo/scoped_umask.cc b/brillo/scoped_umask.cc new file mode 100644 index 0000000..ac6b208 --- /dev/null +++ b/brillo/scoped_umask.cc @@ -0,0 +1,19 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/scoped_umask.h" + +#include <sys/stat.h> + +namespace brillo { + +ScopedUmask::ScopedUmask(mode_t new_umask) { + saved_umask_ = umask(new_umask); +} + +ScopedUmask::~ScopedUmask() { + umask(saved_umask_); +} + +} // namespace brillo diff --git a/brillo/scoped_umask.h b/brillo/scoped_umask.h new file mode 100644 index 0000000..5369e83 --- /dev/null +++ b/brillo/scoped_umask.h @@ -0,0 +1,52 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_SCOPED_UMASK_H_ +#define LIBBRILLO_BRILLO_SCOPED_UMASK_H_ + +#include <sys/types.h> + +#include <base/macros.h> +#include <brillo/brillo_export.h> + +namespace brillo { + +// ScopedUmask is a helper class for temporarily setting the umask before a +// set of operations. umask(2) is never expected to fail. +class BRILLO_EXPORT ScopedUmask { + public: + explicit ScopedUmask(mode_t new_umask); + ~ScopedUmask(); + + private: + mode_t saved_umask_; + + // Avoid reusing ScopedUmask for multiple masks. DISALLOW_COPY_AND_ASSIGN + // deletes the copy constructor and operator=, but there are other situations + // where reassigning a new ScopedUmask to an existing ScopedUmask object + // is problematic: + // + // /* starting umask: default_value + // auto a = std::make_unique<ScopedUmask>(first_value); + // ... code here ... + // a.reset(ScopedUmask(new_value)); + // + // Here, the order of destruction of the old object and the construction of + // the new object is inverted. The recommended usage would be: + // + // { + // ScopedUmask a(old_value); + // ... code here ... + // } + // + // { + // ScopedUmask a(new_value); + // ... code here ... + // } + DISALLOW_COPY_AND_ASSIGN(ScopedUmask); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_SCOPED_UMASK_H_ diff --git a/brillo/scoped_umask_test.cc b/brillo/scoped_umask_test.cc new file mode 100644 index 0000000..d1caa3c --- /dev/null +++ b/brillo/scoped_umask_test.cc @@ -0,0 +1,57 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/scoped_umask.h" + +#include <fcntl.h> + +#include <base/files/file_path.h> +#include <base/files/file_util.h> +#include <base/files/scoped_file.h> +#include <base/files/scoped_temp_dir.h> +#include <gtest/gtest.h> + +namespace brillo { +namespace { + +constexpr int kPermissions600 = + base::FILE_PERMISSION_READ_BY_USER | base::FILE_PERMISSION_WRITE_BY_USER; +constexpr int kPermissions700 = base::FILE_PERMISSION_USER_MASK; +constexpr mode_t kMask700 = ~(0700); +constexpr mode_t kMask600 = ~(0600); + +void CheckFilePermissions(const base::FilePath& path, + int expected_permissions) { + int mode = 0; + // Try to create a file with broader permissions than the mask may provide. + base::ScopedFD fd( + HANDLE_EINTR(open(path.value().c_str(), O_WRONLY | O_CREAT, 0777))); + EXPECT_TRUE(fd.is_valid()); + EXPECT_TRUE(base::GetPosixFilePermissions(path, &mode)); + EXPECT_EQ(mode, expected_permissions); +} + +} // namespace + +TEST(ScopedUmask, CheckUmaskScope) { + base::ScopedTempDir tmpdir; + CHECK(tmpdir.CreateUniqueTempDir()); + + brillo::ScopedUmask outer_scoped_umask_(kMask700); + CheckFilePermissions(tmpdir.GetPath().AppendASCII("file1.txt"), + kPermissions700); + { + // A new scoped umask should result in different permissions for files + // created in this scope. + brillo::ScopedUmask inner_scoped_umask_(kMask600); + CheckFilePermissions(tmpdir.GetPath().AppendASCII("file2.txt"), + kPermissions600); + } + // Since inner_scoped_umask_ has been deconstructed, permissions on all new + // files should now use outer_scoped_umask_. + CheckFilePermissions(tmpdir.GetPath().AppendASCII("file3.txt"), + kPermissions700); +} + +} // namespace brillo diff --git a/brillo/secure_allocator.h b/brillo/secure_allocator.h new file mode 100644 index 0000000..de0b348 --- /dev/null +++ b/brillo/secure_allocator.h @@ -0,0 +1,241 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_SECURE_ALLOCATOR_H_ +#define LIBBRILLO_BRILLO_SECURE_ALLOCATOR_H_ + +#include <errno.h> +#include <sys/mman.h> +#include <unistd.h> + +#include <limits> +#include <memory> + +#include <base/callback_helpers.h> +#include <base/logging.h> +#include <brillo/brillo_export.h> + +namespace brillo { +// SecureAllocator is a stateless derivation of std::allocator that clears +// the contents of the object on deallocation. Additionally, to prevent the +// memory from being leaked, we use the following defensive mechanisms: +// +// 1. Use page-aligned memory so that it can be locked (therefore, use mmap() +// instead of malloc()). Note that mlock()s are not inherited over fork(), +// +// 2. Always allocate memory in multiples of pages: this adds a memory overhead +// of ~1 page for each object. Moreover, the extra memory is not available +// for the allocated object to expand into: the container expects that the +// memory allocated to it matches the size set in reserve(). +// TODO(sarthakkukreti): Figure out if it is possible to propagate the real +// capacity to the container without an intrusive change to the STL. +// [Example: allow __recommend() override in allocators for containers.] +// +// 3. Mark the memory segments as undumpable, unmergeable. +// +// 4. Use MADV_WIPEONFORK: +// this results in a new anonymous vma instead of copying over the contents +// of the secure object after a fork(). By default [MADV_DOFORK], the vma is +// marked as copy-on-write, and the first process which writes to the secure +// object after fork get a new copy. This may break the security guarantees +// setup above. Another alternative is to use MADV_DONTFORK which results in +// the memory mapping not getting copied over to child process at all: this +// may result in cases where if the child process gets segmentation faults +// on attempts to access virtual addresses in the secure object's address +// range, +// +// With MADV_WIPEONFORK, the child processes can access the secure object +// memory safely, but the contents of the secure object appear as zero to +// the child process. Note that threads share the virtual address space and +// secure objects would be transparent across threads. +// TODO(sarthakkukreti): Figure out patterns to pass secure data over fork(). +template <typename T> +class BRILLO_PRIVATE SecureAllocator : public std::allocator<T> { + public: + using typename std::allocator<T>::pointer; + using typename std::allocator<T>::size_type; + using typename std::allocator<T>::value_type; + + // Constructors that wrap over std::allocator. + // Make sure that the allocator's static members are only allocated once. + SecureAllocator() noexcept : std::allocator<T>() {} + SecureAllocator(const SecureAllocator& other) noexcept + : std::allocator<T>(other) {} + + template <class U> + SecureAllocator(const SecureAllocator<U>& other) noexcept + : std::allocator<T>(other) {} + + template <typename U> struct rebind { + typedef SecureAllocator<U> other; + }; + + // Return cached max_size. Deprecated in C++17, removed in C++20. + size_type max_size() const { return max_size_; } + + // Allocation: allocate ceil(size/pagesize) for holding the data. + pointer allocate(size_type n, pointer hint = nullptr) { + pointer buffer = nullptr; + // Check if n can be theoretically allocated. + CHECK_LT(n, max_size()); + + // std::allocator is expected to throw a std::bad_alloc on failing to + // allocate the memory correctly. Instead of returning a nullptr, which + // confuses the standard template library, use CHECK(false) to crash on + // the failure path. + base::ScopedClosureRunner fail_on_allocation_error(base::Bind([]() { + PLOG(ERROR) << "Failed to allocate secure memory"; + CHECK(false); + })); + + // Check if n = 0: there's nothing to allocate; + if (n == 0) + return nullptr; + + // Calculate the page-aligned buffer size. + size_type buffer_size = CalculatePageAlignedBufferSize(n); + + // Memory locking granularity is per-page: mmap ceil(size/page size) pages. + buffer = reinterpret_cast<pointer>( + mmap(nullptr, buffer_size, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); + if (buffer == MAP_FAILED) + return nullptr; + + // Lock buffer into physical memory. + if (mlock(buffer, buffer_size)) { + CHECK_NE(errno, ENOMEM) << "It is likely that SecureAllocator have " + "exceeded the RLIMIT_MEMLOCK limit"; + return nullptr; + } + + // Mark memory as non dumpable in a core dump. + if (madvise(buffer, buffer_size, MADV_DONTDUMP)) + return nullptr; + + // Mark memory as non mergeable with another page, even if the contents + // are the same. + if (madvise(buffer, buffer_size, MADV_UNMERGEABLE)) { + // MADV_UNMERGEABLE is only available if the kernel has been configured + // with CONFIG_KSM set. If the CONFIG_KSM flag has not been set, then + // pages are not mergeable so this madvise option is not necessary. + // + // In the case where CONFIG_KSM is not set, EINVAL is the error set. + // Since this error value is expected in some cases, we don't return a + // nullptr. + if (errno != EINVAL) + return nullptr; + } + + // Make this mapping available to child processes but don't copy data from + // the secure object's pages during fork. With MADV_DONTFORK, the + // vma is not mapped in the child process which leads to segmentation + // faults if the child process tries to access this address. For example, + // if the parent process creates a SecureObject, forks() and the child + // process tries to call the destructor at the virtual address. + if (madvise(buffer, buffer_size, MADV_WIPEONFORK)) + return nullptr; + + ignore_result(fail_on_allocation_error.Release()); + + // Allocation was successful. + return buffer; + } + + // Destroy object before deallocation. Deprecated in C++17, removed in C++20. + // After destroying the object, clear the contents of where the object was + // stored. + template <class U> + void destroy(U* p) { + // Return if the pointer is invalid. + if (!p) + return; + std::allocator<U>::destroy(p); + clear_contents(p, sizeof(U)); + } + + virtual void deallocate(pointer p, size_type n) { + // Check if n can be theoretically deallocated. + CHECK_LT(n, max_size()); + + // Check if n = 0 or p is a nullptr: there's nothing to deallocate; + if (n == 0 || !p) + return; + + // Calculate the page-aligned buffer size. + size_type buffer_size = CalculatePageAlignedBufferSize(n); + + clear_contents(p, buffer_size); + munlock(p, buffer_size); + munmap(p, buffer_size); + } + + protected: +// Force memset to not be optimized out. +// Original source commit: 31b02653c2560f8331934e879263beda44c6cc76 +// Repo: https://android.googlesource.com/platform/external/minijail +#if defined(__clang__) +#define __attribute_no_opt __attribute__((optnone)) +#else +#define __attribute_no_opt __attribute__((__optimize__(0))) +#endif + + // Zero-out all bytes in the allocated buffer. + virtual void __attribute_no_opt clear_contents(pointer v, size_type n) { + if (!v) + return; + memset(v, 0, n); + } + +#undef __attribute_no_opt + + private: + // Calculate the page-aligned buffer size. + size_t CalculatePageAlignedBufferSize(size_type n) { + size_type real_size = n * sizeof(value_type); + size_type page_aligned_remainder = real_size % page_size_; + size_type padding = + page_aligned_remainder != 0 ? page_size_ - page_aligned_remainder : 0; + return real_size + padding; + } + + static size_t CalculatePageSize() { + long ret = sysconf(_SC_PAGESIZE); // NOLINT [runtime/int] + + // Initialize page size. + CHECK_GT(ret, 0L); + return ret; + } + + // Since the allocator reuses page size and max size consistently, + // cache these values initially and reuse. + static size_t GetMaxSizeForType(size_t page_size) { + // Initialize max size that can be theoretically allocated. + // Calculate the max size that is page-aligned. + size_t max_theoretical_size = std::numeric_limits<size_type>::max(); + size_t max_page_aligned_size = + max_theoretical_size - (max_theoretical_size % page_size); + + return max_page_aligned_size / sizeof(value_type); + } + + // Page size on system. + static const size_type page_size_; + // Max theoretical count for type on system. + static const size_type max_size_; +}; + +// Inline definitions are only allowed for static const members with integral +// constexpr initializers, define static members of SecureAllocator types here. +template <typename T> +const typename SecureAllocator<T>::size_type SecureAllocator<T>::page_size_ = + SecureAllocator<T>::CalculatePageSize(); + +template <typename T> +const typename SecureAllocator<T>::size_type SecureAllocator<T>::max_size_ = + SecureAllocator<T>::GetMaxSizeForType(SecureAllocator<T>::page_size_); + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_SECURE_ALLOCATOR_H_ diff --git a/brillo/secure_blob.cc b/brillo/secure_blob.cc index f4b797f..70950f6 100644 --- a/brillo/secure_blob.cc +++ b/brillo/secure_blob.cc @@ -11,6 +11,23 @@ namespace brillo { +namespace { + +bool ConvertHexToBytes(char c, uint8_t* v) { + if (c >= '0' && c <='9') + *v = c - '0'; + else if (c >= 'a' && c <= 'f') + *v = c - 'a' + 10; + else if (c >= 'A' && c <= 'F') + *v = c - 'A' + 10; + else + return false; + + return true; +} + +} // namespace + std::string BlobToString(const Blob& blob) { return std::string(blob.begin(), blob.end()); } @@ -109,4 +126,38 @@ int SecureMemcmp(const void* s1, const void* s2, size_t n) { return result != 0; } +// base::HexEncode and base::HexStringToBytes use strings, which may leak +// contents. These functions are alternatives that keep all contents +// within secured memory. +SecureBlob SecureBlobToSecureHex(const SecureBlob& blob) { + std::string kHexChars("0123456789ABCDEF"); + SecureBlob hex(blob.size() * 2, 0); + const char* blob_char_data = blob.char_data(); + + // Each input byte creates two output hex characters. + for (size_t i = 0; i < blob.size(); ++i) { + hex[(i * 2)] = kHexChars[(blob_char_data[i] >> 4) & 0xf]; + hex[(i * 2) + 1] = kHexChars[blob_char_data[i] & 0xf]; + } + return hex; +} + +SecureBlob SecureHexToSecureBlob(const SecureBlob& hex) { + SecureBlob blob(hex.size()/2, 0); + + if (hex.size() == 0 || hex.size() % 2) + return SecureBlob(); + + for (size_t i = 0; i < hex.size(); i++) { + uint8_t v; + // Check for invalid characters. + if (!ConvertHexToBytes(hex[i], &v)) + return SecureBlob(); + + blob[i/2] = (blob[i/2] << 4) | (v & 0xf); + } + + return blob; +} + } // namespace brillo diff --git a/brillo/secure_blob.h b/brillo/secure_blob.h index 7b6d03c..7705c1a 100644 --- a/brillo/secure_blob.h +++ b/brillo/secure_blob.h @@ -14,7 +14,10 @@ namespace brillo { +// TODO(sarthakkukreti): remove temp. SecureVector once we break SecureBlob's +// dependence on std::vector<uint8_t> using Blob = std::vector<uint8_t>; +using SecureVector = std::vector<uint8_t>; // Conversion of Blob to/from std::string, where the string holds raw byte // contents. @@ -69,6 +72,14 @@ BRILLO_EXPORT BRILLO_DISABLE_ASAN void* SecureMemset(void* v, int c, size_t n); // [n] and not on the relationship of the match between [s1] and [s2]. BRILLO_EXPORT int SecureMemcmp(const void* s1, const void* s2, size_t n); +// Conversion of SecureBlob data to/from SecureBlob hex. This is useful +// for sensitive data like encryption keys, that should, in the ideal case never +// be exposed as strings in the first place. In case the existing data or hex +// string is already exposed as a std::string, it is preferable to use the +// BlobToString variant. +BRILLO_EXPORT SecureBlob SecureBlobToSecureHex(const SecureBlob& blob); +BRILLO_EXPORT SecureBlob SecureHexToSecureBlob(const SecureBlob& hex); + } // namespace brillo #endif // LIBBRILLO_BRILLO_SECURE_BLOB_H_ diff --git a/brillo/secure_blob_unittest.cc b/brillo/secure_blob_test.cc index ff95d0f..7242e86 100644 --- a/brillo/secure_blob_unittest.cc +++ b/brillo/secure_blob_test.cc @@ -5,6 +5,7 @@ // Unit tests for SecureBlob. #include "brillo/asan.h" +#include "brillo/secure_allocator.h" #include "brillo/secure_blob.h" #include <algorithm> @@ -227,4 +228,87 @@ TEST_F(SecureBlobTest, HexStringToSecureBlob) { EXPECT_EQ(blob[15], 0x0f); } +// Override clear_contents() to check whether memory has been cleared. +template <typename T> +class TestSecureAllocator : public SecureAllocator<T> { + public: + using typename SecureAllocator<T>::pointer; + using typename SecureAllocator<T>::size_type; + using typename SecureAllocator<T>::value_type; + + int GetErasedCount() { return erased_count; } + + protected: + void clear_contents(pointer p, size_type n) override { + SecureAllocator<T>::clear_contents(p, n); + unsigned char *v = reinterpret_cast<unsigned char*>(p); + for (int i = 0; i < n; i++) { + EXPECT_EQ(v[i], 0); + erased_count++; + } + } + + private: + int erased_count = 0; +}; + +TEST(SecureAllocator, ErasureOnDeallocation) { + // Make sure that the contents are cleared on deallocation. + TestSecureAllocator<char> e; + + char *test_string_addr = e.allocate(15); + snprintf(test_string_addr, sizeof(test_string_addr), "Test String"); + + // Deallocate memory; the mock class should check for cleared data. + e.deallocate(test_string_addr, 15); + // The deallocation should have traversed the complete page. + EXPECT_EQ(e.GetErasedCount(), 4096); +} + +TEST(SecureAllocator, MultiPageCorrectness) { + // Make sure that the contents are cleared on deallocation. + TestSecureAllocator<uint64_t> e; + + // Allocate 4100*8 bytes. + uint64_t *test_array = e.allocate(4100); + + // Check if the space was correctly allocated for long long. + for (int i = 0; i < 4100; i++) + test_array[i] = 0xF0F0F0F0F0F0F0F0; + + // Deallocate memory; the mock class should check for cleared data. + e.deallocate(test_array, 4100); + // 36864 bytes is the next highest size that is a multiple of the page size. + EXPECT_EQ(e.GetErasedCount(), 36864); +} + +// DeathTests fork a new process and check how it proceeds. Take advantage +// of this and check if the value of SecureString is passed on to +// forked children. +#if GTEST_IS_THREADSAFE +// Check if the contents of the container are zeroed out. +void CheckPropagationOnFork(const brillo::SecureBlob& forked_blob, + const Blob& reference) { + LOG(INFO) << forked_blob.to_string(); + for (int i = 0; i < forked_blob.size(); i++) { + CHECK_NE(reference[i], forked_blob[i]); + CHECK_EQ(forked_blob[i], 0); + } + exit(0); +} + +TEST(SecureAllocatorDeathTest, ErasureOnFork) { + Blob reference = BlobFromString("Test String"); + SecureBlob erasable_blob(reference.begin(), reference.end()); + + EXPECT_EXIT(CheckPropagationOnFork(erasable_blob, reference), + ::testing::ExitedWithCode(0), ""); + + // In the original process, check the SecureBlob to see if it has not + // changed. + for (int i = 0; i < erasable_blob.size(); i++) + EXPECT_EQ(erasable_blob[i], reference[i]); +} +#endif // GTEST_IS_THREADSAFE + } // namespace brillo diff --git a/brillo/streams/fake_stream.cc b/brillo/streams/fake_stream.cc index 498b9d4..9d7a044 100644 --- a/brillo/streams/fake_stream.cc +++ b/brillo/streams/fake_stream.cc @@ -5,6 +5,7 @@ #include <brillo/streams/fake_stream.h> #include <algorithm> +#include <utility> #include <base/bind.h> #include <brillo/message_loops/message_loop.h> @@ -185,7 +186,7 @@ bool FakeStream::IsReadBufferEmpty() const { bool FakeStream::PopReadPacket() { if (incoming_queue_.empty()) return false; - const InputDataPacket& packet = incoming_queue_.front(); + InputDataPacket& packet = incoming_queue_.front(); input_ptr_ = 0; input_buffer_ = std::move(packet.data); delay_input_until_ = clock_->Now() + packet.delay_before; @@ -250,7 +251,7 @@ bool FakeStream::IsWriteBufferFull() const { bool FakeStream::PopWritePacket() { if (outgoing_queue_.empty()) return false; - const OutputDataPacket& packet = outgoing_queue_.front(); + OutputDataPacket& packet = outgoing_queue_.front(); expected_output_data_ = std::move(packet.data); delay_output_until_ = clock_->Now() + packet.delay_before; max_output_buffer_size_ = packet.expected_size; diff --git a/brillo/streams/fake_stream_unittest.cc b/brillo/streams/fake_stream_test.cc index 2404514..2e83e3b 100644 --- a/brillo/streams/fake_stream_unittest.cc +++ b/brillo/streams/fake_stream_test.cc @@ -4,11 +4,12 @@ #include <brillo/streams/fake_stream.h> +#include <memory> #include <vector> +#include <base/bind.h> #include <base/callback.h> #include <base/test/simple_test_clock.h> -#include <brillo/bind_lambda.h> #include <brillo/message_loops/mock_message_loop.h> #include <gmock/gmock.h> #include <gtest/gtest.h> diff --git a/brillo/streams/file_stream.cc b/brillo/streams/file_stream.cc index 7b28a5a..70b25dd 100644 --- a/brillo/streams/file_stream.cc +++ b/brillo/streams/file_stream.cc @@ -4,12 +4,15 @@ #include <brillo/streams/file_stream.h> -#include <algorithm> #include <fcntl.h> #include <sys/stat.h> #include <unistd.h> +#include <algorithm> +#include <utility> + #include <base/bind.h> +#include <base/files/file_descriptor_watcher_posix.h> #include <base/files/file_util.h> #include <base/posix/eintr_wrapper.h> #include <brillo/errors/error_codes.h> @@ -84,15 +87,11 @@ class FileDescriptor : public FileStream::FileDescriptorInterface { ErrorPtr* error) override { if (stream_utils::IsReadAccessMode(mode)) { CHECK(read_data_callback_.is_null()); - MessageLoop::current()->CancelTask(read_watcher_); - read_watcher_ = MessageLoop::current()->WatchFileDescriptor( - FROM_HERE, + read_watcher_ = base::FileDescriptorWatcher::WatchReadable( fd_, - MessageLoop::WatchMode::kWatchRead, - false, // persistent - base::Bind(&FileDescriptor::OnFileCanReadWithoutBlocking, - base::Unretained(this))); - if (read_watcher_ == MessageLoop::kTaskIdNull) { + base::BindRepeating(&FileDescriptor::OnReadable, + base::Unretained(this))); + if (!read_watcher_) { Error::AddTo(error, FROM_HERE, errors::stream::kDomain, errors::stream::kInvalidParameter, "File descriptor doesn't support watching for reading."); @@ -102,15 +101,11 @@ class FileDescriptor : public FileStream::FileDescriptorInterface { } if (stream_utils::IsWriteAccessMode(mode)) { CHECK(write_data_callback_.is_null()); - MessageLoop::current()->CancelTask(write_watcher_); - write_watcher_ = MessageLoop::current()->WatchFileDescriptor( - FROM_HERE, + write_watcher_ = base::FileDescriptorWatcher::WatchWritable( fd_, - MessageLoop::WatchMode::kWatchWrite, - false, // persistent - base::Bind(&FileDescriptor::OnFileCanWriteWithoutBlocking, - base::Unretained(this))); - if (write_watcher_ == MessageLoop::kTaskIdNull) { + base::BindRepeating(&FileDescriptor::OnWritable, + base::Unretained(this))); + if (!write_watcher_) { Error::AddTo(error, FROM_HERE, errors::stream::kDomain, errors::stream::kInvalidParameter, "File descriptor doesn't support watching for writing."); @@ -155,31 +150,26 @@ class FileDescriptor : public FileStream::FileDescriptorInterface { void CancelPendingAsyncOperations() override { read_data_callback_.Reset(); - if (read_watcher_ != MessageLoop::kTaskIdNull) { - MessageLoop::current()->CancelTask(read_watcher_); - read_watcher_ = MessageLoop::kTaskIdNull; - } - + read_watcher_ = nullptr; write_data_callback_.Reset(); - if (write_watcher_ != MessageLoop::kTaskIdNull) { - MessageLoop::current()->CancelTask(write_watcher_); - write_watcher_ = MessageLoop::kTaskIdNull; - } + write_watcher_ = nullptr; } // Called from the brillo::MessageLoop when the file descriptor is available // for reading. - void OnFileCanReadWithoutBlocking() { + void OnReadable() { CHECK(!read_data_callback_.is_null()); - DataCallback cb = read_data_callback_; - read_data_callback_.Reset(); + + read_watcher_ = nullptr; + DataCallback cb = std::move(read_data_callback_); cb.Run(Stream::AccessMode::READ); } - void OnFileCanWriteWithoutBlocking() { + void OnWritable() { CHECK(!write_data_callback_.is_null()); - DataCallback cb = write_data_callback_; - write_data_callback_.Reset(); + + write_watcher_ = nullptr; + DataCallback cb = std::move(write_data_callback_); cb.Run(Stream::AccessMode::WRITE); } @@ -198,9 +188,9 @@ class FileDescriptor : public FileStream::FileDescriptorInterface { DataCallback read_data_callback_; DataCallback write_data_callback_; - // MessageLoop tasks monitoring read/write operations on the file descriptor. - MessageLoop::TaskId read_watcher_{MessageLoop::kTaskIdNull}; - MessageLoop::TaskId write_watcher_{MessageLoop::kTaskIdNull}; + // Monitoring read/write operations on the file descriptor. + std::unique_ptr<base::FileDescriptorWatcher::Controller> read_watcher_; + std::unique_ptr<base::FileDescriptorWatcher::Controller> write_watcher_; DISALLOW_COPY_AND_ASSIGN(FileDescriptor); }; diff --git a/brillo/streams/file_stream.h b/brillo/streams/file_stream.h index 1cf39b5..bf54617 100644 --- a/brillo/streams/file_stream.h +++ b/brillo/streams/file_stream.h @@ -5,6 +5,8 @@ #ifndef LIBBRILLO_BRILLO_STREAMS_FILE_STREAM_H_ #define LIBBRILLO_BRILLO_STREAMS_FILE_STREAM_H_ +#include <memory> + #include <base/files/file_path.h> #include <base/macros.h> #include <brillo/brillo_export.h> diff --git a/brillo/streams/file_stream_unittest.cc b/brillo/streams/file_stream_test.cc index 210725e..36bad07 100644 --- a/brillo/streams/file_stream_unittest.cc +++ b/brillo/streams/file_stream_test.cc @@ -4,18 +4,20 @@ #include <brillo/streams/file_stream.h> +#include <sys/stat.h> + #include <limits> #include <numeric> #include <string> -#include <sys/stat.h> +#include <utility> #include <vector> +#include <base/bind.h> #include <base/files/file_util.h> #include <base/files/scoped_temp_dir.h> #include <base/message_loop/message_loop.h> #include <base/rand_util.h> #include <base/run_loop.h> -#include <brillo/bind_lambda.h> #include <brillo/errors/error_codes.h> #include <brillo/message_loops/base_message_loop.h> #include <brillo/message_loops/message_loop_utils.h> @@ -23,6 +25,7 @@ #include <gmock/gmock.h> #include <gtest/gtest.h> +using testing::DoAll; using testing::InSequence; using testing::Return; using testing::ReturnArg; @@ -130,20 +133,23 @@ void SetToTrue(bool* target, const Error* /* error */) { // A mock file descriptor wrapper to test low-level file API used by FileStream. class MockFileDescriptor : public FileStream::FileDescriptorInterface { public: - MOCK_CONST_METHOD0(IsOpen, bool()); - MOCK_METHOD2(Read, ssize_t(void*, size_t)); - MOCK_METHOD2(Write, ssize_t(const void*, size_t)); - MOCK_METHOD2(Seek, off64_t(off64_t, int)); - MOCK_CONST_METHOD0(GetFileMode, mode_t()); - MOCK_CONST_METHOD0(GetSize, uint64_t()); - MOCK_CONST_METHOD1(Truncate, int(off64_t)); - MOCK_METHOD0(Flush, int()); - MOCK_METHOD0(Close, int()); - MOCK_METHOD3(WaitForData, - bool(Stream::AccessMode, const DataCallback&, ErrorPtr*)); - MOCK_METHOD3(WaitForDataBlocking, - int(Stream::AccessMode, base::TimeDelta, Stream::AccessMode*)); - MOCK_METHOD0(CancelPendingAsyncOperations, void()); + MOCK_METHOD(bool, IsOpen, (), (const, override)); + MOCK_METHOD(ssize_t, Read, (void*, size_t), (override)); + MOCK_METHOD(ssize_t, Write, (const void*, size_t), (override)); + MOCK_METHOD(off64_t, Seek, (off64_t, int), (override)); + MOCK_METHOD(mode_t, GetFileMode, (), (const, override)); + MOCK_METHOD(uint64_t, GetSize, (), (const, override)); + MOCK_METHOD(int, Truncate, (off64_t), (const, override)); + MOCK_METHOD(int, Close, (), (override)); + MOCK_METHOD(bool, + WaitForData, + (Stream::AccessMode, const DataCallback&, ErrorPtr*), + (override)); + MOCK_METHOD(int, + WaitForDataBlocking, + (Stream::AccessMode, base::TimeDelta, Stream::AccessMode*), + (override)); + MOCK_METHOD(void, CancelPendingAsyncOperations, (), (override)); }; class FileStreamTest : public testing::Test { diff --git a/brillo/streams/input_stream_set.cc b/brillo/streams/input_stream_set.cc index 986efac..847bf05 100644 --- a/brillo/streams/input_stream_set.cc +++ b/brillo/streams/input_stream_set.cc @@ -4,6 +4,8 @@ #include <brillo/streams/input_stream_set.h> +#include <utility> + #include <base/bind.h> #include <brillo/message_loops/message_loop.h> #include <brillo/streams/stream_errors.h> @@ -170,7 +172,7 @@ bool InputStreamSet::WaitForData( return stream->WaitForData(mode, callback, error); } - MessageLoop::current()->PostTask(FROM_HERE, base::Bind(callback, mode)); + MessageLoop::current()->PostTask(FROM_HERE, base::BindOnce(callback, mode)); return true; } diff --git a/brillo/streams/input_stream_set_unittest.cc b/brillo/streams/input_stream_set_test.cc index 3268d96..9a29248 100644 --- a/brillo/streams/input_stream_set_unittest.cc +++ b/brillo/streams/input_stream_set_test.cc @@ -4,13 +4,14 @@ #include <brillo/streams/input_stream_set.h> +#include <memory> + #include <brillo/errors/error_codes.h> #include <brillo/streams/mock_stream.h> #include <brillo/streams/stream_errors.h> #include <gmock/gmock.h> #include <gtest/gtest.h> -using testing::An; using testing::DoAll; using testing::InSequence; using testing::Return; diff --git a/brillo/streams/memory_containers.h b/brillo/streams/memory_containers.h index d3cb205..22488d8 100644 --- a/brillo/streams/memory_containers.h +++ b/brillo/streams/memory_containers.h @@ -6,6 +6,7 @@ #define LIBBRILLO_BRILLO_STREAMS_MEMORY_CONTAINERS_H_ #include <string> +#include <utility> #include <vector> #include <brillo/brillo_export.h> diff --git a/brillo/streams/memory_containers_unittest.cc b/brillo/streams/memory_containers_test.cc index 2f0bf38..8b56ade 100644 --- a/brillo/streams/memory_containers_unittest.cc +++ b/brillo/streams/memory_containers_test.cc @@ -26,14 +26,20 @@ class MockContiguousBuffer : public data_container::ContiguousBufferBase { public: MockContiguousBuffer() = default; - MOCK_METHOD2(Resize, bool(size_t, ErrorPtr*)); - MOCK_CONST_METHOD0(GetSize, size_t()); - MOCK_CONST_METHOD0(IsReadOnly, bool()); - - MOCK_CONST_METHOD2(GetReadOnlyBuffer, const void*(size_t, ErrorPtr*)); - MOCK_METHOD2(GetBuffer, void*(size_t, ErrorPtr*)); - - MOCK_CONST_METHOD3(CopyMemoryBlock, void(void*, const void*, size_t)); + MOCK_METHOD(bool, Resize, (size_t, ErrorPtr*), (override)); + MOCK_METHOD(size_t, GetSize, (), (const, override)); + MOCK_METHOD(bool, IsReadOnly, (), (const, override)); + + MOCK_METHOD(const void*, + GetReadOnlyBuffer, + (size_t, ErrorPtr*), + (const, override)); + MOCK_METHOD(void*, GetBuffer, (size_t, ErrorPtr*), (override)); + + MOCK_METHOD(void, + CopyMemoryBlock, + (void*, const void*, size_t), + (const, override)); private: DISALLOW_COPY_AND_ASSIGN(MockContiguousBuffer); diff --git a/brillo/streams/memory_stream.cc b/brillo/streams/memory_stream.cc index 54f127a..f4f9cca 100644 --- a/brillo/streams/memory_stream.cc +++ b/brillo/streams/memory_stream.cc @@ -185,7 +185,7 @@ bool MemoryStream::CheckContainer(ErrorPtr* error) const { bool MemoryStream::WaitForData(AccessMode mode, const base::Callback<void(AccessMode)>& callback, ErrorPtr* /* error */) { - MessageLoop::current()->PostTask(FROM_HERE, base::Bind(callback, mode)); + MessageLoop::current()->PostTask(FROM_HERE, base::BindOnce(callback, mode)); return true; } diff --git a/brillo/streams/memory_stream.h b/brillo/streams/memory_stream.h index b4927a8..e748f47 100644 --- a/brillo/streams/memory_stream.h +++ b/brillo/streams/memory_stream.h @@ -5,7 +5,9 @@ #ifndef LIBBRILLO_BRILLO_STREAMS_MEMORY_STREAM_H_ #define LIBBRILLO_BRILLO_STREAMS_MEMORY_STREAM_H_ +#include <memory> #include <string> +#include <utility> #include <vector> #include <base/macros.h> diff --git a/brillo/streams/memory_stream_unittest.cc b/brillo/streams/memory_stream_test.cc index 75278f7..28a88fa 100644 --- a/brillo/streams/memory_stream_unittest.cc +++ b/brillo/streams/memory_stream_test.cc @@ -8,6 +8,7 @@ #include <limits> #include <numeric> #include <string> +#include <utility> #include <vector> #include <brillo/streams/stream_errors.h> @@ -32,11 +33,17 @@ class MockMemoryContainer : public data_container::DataContainerInterface { public: MockMemoryContainer() = default; - MOCK_METHOD5(Read, bool(void*, size_t, size_t, size_t*, ErrorPtr*)); - MOCK_METHOD5(Write, bool(const void*, size_t, size_t, size_t*, ErrorPtr*)); - MOCK_METHOD2(Resize, bool(size_t, ErrorPtr*)); - MOCK_CONST_METHOD0(GetSize, size_t()); - MOCK_CONST_METHOD0(IsReadOnly, bool()); + MOCK_METHOD(bool, + Read, + (void*, size_t, size_t, size_t*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + Write, + (const void*, size_t, size_t, size_t*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, Resize, (size_t, ErrorPtr*), (override)); + MOCK_METHOD(size_t, GetSize, (), (const, override)); + MOCK_METHOD(bool, IsReadOnly, (), (const, override)); private: DISALLOW_COPY_AND_ASSIGN(MockMemoryContainer); diff --git a/brillo/streams/mock_stream.h b/brillo/streams/mock_stream.h index 934912a..45f83ed 100644 --- a/brillo/streams/mock_stream.h +++ b/brillo/streams/mock_stream.h @@ -16,55 +16,82 @@ class MockStream : public Stream { public: MockStream() = default; - MOCK_CONST_METHOD0(IsOpen, bool()); - MOCK_CONST_METHOD0(CanRead, bool()); - MOCK_CONST_METHOD0(CanWrite, bool()); - MOCK_CONST_METHOD0(CanSeek, bool()); - MOCK_CONST_METHOD0(CanGetSize, bool()); + MOCK_METHOD(bool, IsOpen, (), (const, override)); + MOCK_METHOD(bool, CanRead, (), (const, override)); + MOCK_METHOD(bool, CanWrite, (), (const, override)); + MOCK_METHOD(bool, CanSeek, (), (const, override)); + MOCK_METHOD(bool, CanGetSize, (), (const, override)); - MOCK_CONST_METHOD0(GetSize, uint64_t()); - MOCK_METHOD2(SetSizeBlocking, bool(uint64_t, ErrorPtr*)); - MOCK_CONST_METHOD0(GetRemainingSize, uint64_t()); + MOCK_METHOD(uint64_t, GetSize, (), (const, override)); + MOCK_METHOD(bool, SetSizeBlocking, (uint64_t, ErrorPtr*), (override)); + MOCK_METHOD(uint64_t, GetRemainingSize, (), (const, override)); - MOCK_CONST_METHOD0(GetPosition, uint64_t()); - MOCK_METHOD4(Seek, bool(int64_t, Whence, uint64_t*, ErrorPtr*)); + MOCK_METHOD(uint64_t, GetPosition, (), (const, override)); + MOCK_METHOD(bool, Seek, (int64_t, Whence, uint64_t*, ErrorPtr*), (override)); - MOCK_METHOD5(ReadAsync, bool(void*, - size_t, - const base::Callback<void(size_t)>&, - const ErrorCallback&, - ErrorPtr*)); - MOCK_METHOD5(ReadAllAsync, bool(void*, - size_t, - const base::Closure&, - const ErrorCallback&, - ErrorPtr*)); - MOCK_METHOD5(ReadNonBlocking, bool(void*, size_t, size_t*, bool*, ErrorPtr*)); - MOCK_METHOD4(ReadBlocking, bool(void*, size_t, size_t*, ErrorPtr*)); - MOCK_METHOD3(ReadAllBlocking, bool(void*, size_t, ErrorPtr*)); + MOCK_METHOD(bool, + ReadAsync, + (void*, + size_t, + const base::Callback<void(size_t)>&, + const ErrorCallback&, + ErrorPtr*), + (override)); + MOCK_METHOD( + bool, + ReadAllAsync, + (void*, size_t, const base::Closure&, const ErrorCallback&, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + ReadNonBlocking, + (void*, size_t, size_t*, bool*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + ReadBlocking, + (void*, size_t, size_t*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, ReadAllBlocking, (void*, size_t, ErrorPtr*), (override)); - MOCK_METHOD5(WriteAsync, bool(const void*, - size_t, - const base::Callback<void(size_t)>&, - const ErrorCallback&, - ErrorPtr*)); - MOCK_METHOD5(WriteAllAsync, bool(const void*, - size_t, - const base::Closure&, - const ErrorCallback&, - ErrorPtr*)); - MOCK_METHOD4(WriteNonBlocking, bool(const void*, size_t, size_t*, ErrorPtr*)); - MOCK_METHOD4(WriteBlocking, bool(const void*, size_t, size_t*, ErrorPtr*)); - MOCK_METHOD3(WriteAllBlocking, bool(const void*, size_t, ErrorPtr*)); + MOCK_METHOD(bool, + WriteAsync, + (const void*, + size_t, + const base::Callback<void(size_t)>&, + const ErrorCallback&, + ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WriteAllAsync, + (const void*, + size_t, + const base::Closure&, + const ErrorCallback&, + ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WriteNonBlocking, + (const void*, size_t, size_t*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WriteBlocking, + (const void*, size_t, size_t*, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WriteAllBlocking, + (const void*, size_t, ErrorPtr*), + (override)); - MOCK_METHOD1(FlushBlocking, bool(ErrorPtr*)); - MOCK_METHOD1(CloseBlocking, bool(ErrorPtr*)); + MOCK_METHOD(bool, FlushBlocking, (ErrorPtr*), (override)); + MOCK_METHOD(bool, CloseBlocking, (ErrorPtr*), (override)); - MOCK_METHOD3(WaitForData, bool(AccessMode, - const base::Callback<void(AccessMode)>&, - ErrorPtr*)); - MOCK_METHOD4(WaitForDataBlocking, - bool(AccessMode, base::TimeDelta, AccessMode*, ErrorPtr*)); + MOCK_METHOD(bool, + WaitForData, + (AccessMode, const base::Callback<void(AccessMode)>&, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WaitForDataBlocking, + (AccessMode, base::TimeDelta, AccessMode*, ErrorPtr*), + (override)); private: DISALLOW_COPY_AND_ASSIGN(MockStream); diff --git a/brillo/streams/openssl_stream_bio.cc b/brillo/streams/openssl_stream_bio.cc index a63d9c0..478b112 100644 --- a/brillo/streams/openssl_stream_bio.cc +++ b/brillo/streams/openssl_stream_bio.cc @@ -13,9 +13,32 @@ namespace brillo { namespace { +// TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped. +#if OPENSSL_VERSION_NUMBER < 0x10100000L +static void BIO_set_data(BIO* a, void* ptr) { + a->ptr = ptr; +} + +static void* BIO_get_data(BIO* a) { + return a->ptr; +} + +static void BIO_set_init(BIO* a, int init) { + a->init = init; +} + +static int BIO_get_init(BIO* a) { + return a->init; +} + +static void BIO_set_shutdown(BIO* a, int shut) { + a->shutdown = shut; +} +#endif + // Internal functions for implementing OpenSSL BIO on brillo::Stream. int stream_write(BIO* bio, const char* buf, int size) { - brillo::Stream* stream = static_cast<brillo::Stream*>(bio->ptr); + brillo::Stream* stream = static_cast<brillo::Stream*>(BIO_get_data(bio)); size_t written = 0; BIO_clear_retry_flags(bio); if (!stream->WriteNonBlocking(buf, size, &written, nullptr)) @@ -30,7 +53,7 @@ int stream_write(BIO* bio, const char* buf, int size) { } int stream_read(BIO* bio, char* buf, int size) { - brillo::Stream* stream = static_cast<brillo::Stream*>(bio->ptr); + brillo::Stream* stream = static_cast<brillo::Stream*>(BIO_get_data(bio)); size_t read = 0; BIO_clear_retry_flags(bio); bool eos = false; @@ -49,16 +72,16 @@ int stream_read(BIO* bio, char* buf, int size) { // NOLINTNEXTLINE(runtime/int) long stream_ctrl(BIO* bio, int cmd, long /* num */, void* /* ptr */) { if (cmd == BIO_CTRL_FLUSH) { - brillo::Stream* stream = static_cast<brillo::Stream*>(bio->ptr); + brillo::Stream* stream = static_cast<brillo::Stream*>(BIO_get_data(bio)); return stream->FlushBlocking(nullptr) ? 1 : 0; } return 0; } int stream_new(BIO* bio) { - bio->shutdown = 0; // By default do not close underlying stream on shutdown. - bio->init = 0; - bio->num = -1; // not used. + // By default do not close underlying stream on shutdown. + BIO_set_shutdown(bio, 0); + BIO_set_init(bio, 0); return 1; } @@ -66,13 +89,17 @@ int stream_free(BIO* bio) { if (!bio) return 0; - if (bio->init) { - bio->ptr = nullptr; - bio->init = 0; + if (BIO_get_init(bio)) { + BIO_set_data(bio, nullptr); + BIO_set_init(bio, 0); } return 1; } +#if OPENSSL_VERSION_NUMBER < 0x10100000L +// TODO(crbug.com/984789): Remove #ifdef once support for OpenSSL <1.1 is +// dropped. + // BIO_METHOD structure describing the BIO built on top of brillo::Stream. BIO_METHOD stream_method = { 0x7F | BIO_TYPE_SOURCE_SINK, // type: 0x7F is an arbitrary unused type ID. @@ -87,13 +114,37 @@ BIO_METHOD stream_method = { nullptr, // callback function, not used }; +BIO_METHOD* stream_get_method() { + return &stream_method; +} + +#else + +BIO_METHOD* stream_get_method() { + static BIO_METHOD* stream_method; + + if (!stream_method) { + stream_method = BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, + "stream"); + BIO_meth_set_write(stream_method, stream_write); + BIO_meth_set_read(stream_method, stream_read); + BIO_meth_set_ctrl(stream_method, stream_ctrl); + BIO_meth_set_create(stream_method, stream_new); + BIO_meth_set_destroy(stream_method, stream_free); + } + + return stream_method; +} + +#endif + } // anonymous namespace BIO* BIO_new_stream(brillo::Stream* stream) { - BIO* bio = BIO_new(&stream_method); + BIO* bio = BIO_new(stream_get_method()); if (bio) { - bio->ptr = stream; - bio->init = 1; + BIO_set_data(bio, stream); + BIO_set_init(bio, 1); } return bio; } diff --git a/brillo/streams/openssl_stream_bio_unittests.cc b/brillo/streams/openssl_stream_bio_test.cc index a80710d..a80710d 100644 --- a/brillo/streams/openssl_stream_bio_unittests.cc +++ b/brillo/streams/openssl_stream_bio_test.cc diff --git a/brillo/streams/stream.cc b/brillo/streams/stream.cc index 6a40c00..80f2df4 100644 --- a/brillo/streams/stream.cc +++ b/brillo/streams/stream.cc @@ -213,8 +213,9 @@ bool Stream::ReadAsyncImpl( if (force_async_callback) { MessageLoop::current()->PostTask( FROM_HERE, - base::Bind(&Stream::OnReadAsyncDone, weak_ptr_factory_.GetWeakPtr(), - success_callback, read, eos)); + base::BindOnce(&Stream::OnReadAsyncDone, + weak_ptr_factory_.GetWeakPtr(), + success_callback, read, eos)); } else { is_async_read_pending_ = false; success_callback.Run(read, eos); @@ -277,8 +278,9 @@ bool Stream::WriteAsyncImpl( if (force_async_callback) { MessageLoop::current()->PostTask( FROM_HERE, - base::Bind(&Stream::OnWriteAsyncDone, weak_ptr_factory_.GetWeakPtr(), - success_callback, written)); + base::BindOnce(&Stream::OnWriteAsyncDone, + weak_ptr_factory_.GetWeakPtr(), + success_callback, written)); } else { is_async_write_pending_ = false; success_callback.Run(written); diff --git a/brillo/streams/stream_unittest.cc b/brillo/streams/stream_test.cc index c341cde..8cb99a9 100644 --- a/brillo/streams/stream_unittest.cc +++ b/brillo/streams/stream_test.cc @@ -6,11 +6,11 @@ #include <limits> +#include <base/bind.h> #include <base/callback.h> #include <gmock/gmock.h> #include <gtest/gtest.h> -#include <brillo/bind_lambda.h> #include <brillo/message_loops/fake_message_loop.h> #include <brillo/streams/stream_errors.h> @@ -42,39 +42,48 @@ class MockStreamImpl : public Stream { public: MockStreamImpl() = default; - MOCK_CONST_METHOD0(IsOpen, bool()); - MOCK_CONST_METHOD0(CanRead, bool()); - MOCK_CONST_METHOD0(CanWrite, bool()); - MOCK_CONST_METHOD0(CanSeek, bool()); - MOCK_CONST_METHOD0(CanGetSize, bool()); + MOCK_METHOD(bool, IsOpen, (), (const, override)); + MOCK_METHOD(bool, CanRead, (), (const, override)); + MOCK_METHOD(bool, CanWrite, (), (const, override)); + MOCK_METHOD(bool, CanSeek, (), (const, override)); + MOCK_METHOD(bool, CanGetSize, (), (const, override)); - MOCK_CONST_METHOD0(GetSize, uint64_t()); - MOCK_METHOD2(SetSizeBlocking, bool(uint64_t, ErrorPtr*)); - MOCK_CONST_METHOD0(GetRemainingSize, uint64_t()); + MOCK_METHOD(uint64_t, GetSize, (), (const, override)); + MOCK_METHOD(bool, SetSizeBlocking, (uint64_t, ErrorPtr*), (override)); + MOCK_METHOD(uint64_t, GetRemainingSize, (), (const, override)); - MOCK_CONST_METHOD0(GetPosition, uint64_t()); - MOCK_METHOD4(Seek, bool(int64_t, Whence, uint64_t*, ErrorPtr*)); + MOCK_METHOD(uint64_t, GetPosition, (), (const, override)); + MOCK_METHOD(bool, Seek, (int64_t, Whence, uint64_t*, ErrorPtr*), (override)); // Omitted: ReadAsync // Omitted: ReadAllAsync - MOCK_METHOD5(ReadNonBlocking, bool(void*, size_t, size_t*, bool*, ErrorPtr*)); + MOCK_METHOD(bool, + ReadNonBlocking, + (void*, size_t, size_t*, bool*, ErrorPtr*), + (override)); // Omitted: ReadBlocking // Omitted: ReadAllBlocking // Omitted: WriteAsync // Omitted: WriteAllAsync - MOCK_METHOD4(WriteNonBlocking, bool(const void*, size_t, size_t*, ErrorPtr*)); + MOCK_METHOD(bool, + WriteNonBlocking, + (const void*, size_t, size_t*, ErrorPtr*), + (override)); // Omitted: WriteBlocking // Omitted: WriteAllBlocking - MOCK_METHOD1(FlushBlocking, bool(ErrorPtr*)); - MOCK_METHOD1(CloseBlocking, bool(ErrorPtr*)); + MOCK_METHOD(bool, FlushBlocking, (ErrorPtr*), (override)); + MOCK_METHOD(bool, CloseBlocking, (ErrorPtr*), (override)); - MOCK_METHOD3(WaitForData, bool(AccessMode, - const base::Callback<void(AccessMode)>&, - ErrorPtr*)); - MOCK_METHOD4(WaitForDataBlocking, - bool(AccessMode, base::TimeDelta, AccessMode*, ErrorPtr*)); + MOCK_METHOD(bool, + WaitForData, + (AccessMode, const base::Callback<void(AccessMode)>&, ErrorPtr*), + (override)); + MOCK_METHOD(bool, + WaitForDataBlocking, + (AccessMode, base::TimeDelta, AccessMode*, ErrorPtr*), + (override)); private: DISALLOW_COPY_AND_ASSIGN(MockStreamImpl); @@ -333,7 +342,10 @@ TEST(Stream, ReadBlocking) { TEST(Stream, ReadAllBlocking) { class MockReadBlocking : public MockStreamImpl { public: - MOCK_METHOD4(ReadBlocking, bool(void*, size_t, size_t*, ErrorPtr*)); + MOCK_METHOD(bool, + ReadBlocking, + (void*, size_t, size_t*, ErrorPtr*), + (override)); } stream_mock; char buf[1024]; @@ -471,7 +483,10 @@ TEST(Stream, WriteBlocking) { TEST(Stream, WriteAllBlocking) { class MockWritelocking : public MockStreamImpl { public: - MOCK_METHOD4(WriteBlocking, bool(const void*, size_t, size_t*, ErrorPtr*)); + MOCK_METHOD(bool, + WriteBlocking, + (const void*, size_t, size_t*, ErrorPtr*), + (override)); } stream_mock; char buf[1024]; diff --git a/brillo/streams/stream_utils.cc b/brillo/streams/stream_utils.cc index 3f7a14a..5029e3a 100644 --- a/brillo/streams/stream_utils.cc +++ b/brillo/streams/stream_utils.cc @@ -4,7 +4,11 @@ #include <brillo/streams/stream_utils.h> +#include <algorithm> #include <limits> +#include <memory> +#include <utility> +#include <vector> #include <base/bind.h> #include <brillo/message_loops/message_loop.h> @@ -209,7 +213,7 @@ void CopyData(StreamPtr in_stream, state->success_callback = success_callback; state->error_callback = error_callback; brillo::MessageLoop::current()->PostTask(FROM_HERE, - base::Bind(&PerformRead, state)); + base::BindOnce(&PerformRead, state)); } } // namespace stream_utils diff --git a/brillo/streams/stream_utils_unittest.cc b/brillo/streams/stream_utils_test.cc index f27d233..50fb67e 100644 --- a/brillo/streams/stream_utils_unittest.cc +++ b/brillo/streams/stream_utils_test.cc @@ -5,6 +5,9 @@ #include <brillo/streams/stream_utils.h> #include <limits> +#include <memory> +#include <string> +#include <utility> #include <base/bind.h> #include <brillo/message_loops/fake_message_loop.h> @@ -14,9 +17,7 @@ #include <gmock/gmock.h> #include <gtest/gtest.h> -using testing::DoAll; using testing::InSequence; -using testing::Return; using testing::StrictMock; using testing::_; @@ -24,7 +25,7 @@ ACTION_TEMPLATE(InvokeAsyncCallback, HAS_1_TEMPLATE_PARAMS(int, k), AND_1_VALUE_PARAMS(size)) { brillo::MessageLoop::current()->PostTask( - FROM_HERE, base::Bind(std::get<k>(args), size)); + FROM_HERE, base::BindOnce(std::get<k>(args), size)); return true; } @@ -41,7 +42,8 @@ ACTION_TEMPLATE(InvokeAsyncErrorCallback, brillo::ErrorPtr error; brillo::Error::AddTo(&error, FROM_HERE, "test", code, "message"); brillo::MessageLoop::current()->PostTask( - FROM_HERE, base::Bind(std::get<k>(args), base::Owned(error.release()))); + FROM_HERE, base::BindOnce(std::get<k>(args), + base::Owned(error.release()))); return true; } diff --git a/brillo/streams/tls_stream.cc b/brillo/streams/tls_stream.cc index fde4193..4b8a227 100644 --- a/brillo/streams/tls_stream.cc +++ b/brillo/streams/tls_stream.cc @@ -7,6 +7,7 @@ #include <algorithm> #include <limits> #include <string> +#include <utility> #include <vector> #include <openssl/err.h> @@ -67,6 +68,11 @@ const char kCACertificatePath[] = namespace brillo { +// TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped. +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#define TLS_client_method() TLSv1_2_client_method() +#endif + // Helper implementation of TLS stream used to hide most of OpenSSL inner // workings from the users of brillo::TlsStream. class TlsStream::TlsStreamImpl { @@ -341,7 +347,7 @@ bool TlsStream::TlsStreamImpl::Init(StreamPtr socket, const base::Closure& success_callback, const Stream::ErrorCallback& error_callback, ErrorPtr* error) { - ctx_.reset(SSL_CTX_new(TLSv1_2_client_method())); + ctx_.reset(SSL_CTX_new(TLS_client_method())); if (!ctx_) return ReportError(error, FROM_HERE, "Cannot create SSL_CTX"); @@ -387,10 +393,10 @@ bool TlsStream::TlsStreamImpl::Init(StreamPtr socket, if (MessageLoop::ThreadHasCurrent()) { MessageLoop::current()->PostTask( FROM_HERE, - base::Bind(&TlsStreamImpl::DoHandshake, - weak_ptr_factory_.GetWeakPtr(), - success_callback, - error_callback)); + base::BindOnce(&TlsStreamImpl::DoHandshake, + weak_ptr_factory_.GetWeakPtr(), + success_callback, + error_callback)); } else { DoHandshake(success_callback, error_callback); } diff --git a/brillo/strings/string_utils_unittest.cc b/brillo/strings/string_utils_test.cc index c554e74..c554e74 100644 --- a/brillo/strings/string_utils_unittest.cc +++ b/brillo/strings/string_utils_test.cc diff --git a/brillo/syslog_logging_unittest.cc b/brillo/syslog_logging_test.cc index e852e50..e852e50 100644 --- a/brillo/syslog_logging_unittest.cc +++ b/brillo/syslog_logging_test.cc diff --git a/brillo/timezone/EST_test.tzif b/brillo/timezone/EST_test.tzif Binary files differnew file mode 100644 index 0000000..ae34663 --- /dev/null +++ b/brillo/timezone/EST_test.tzif diff --git a/brillo/timezone/Indian_Christmas_test.tzif b/brillo/timezone/Indian_Christmas_test.tzif Binary files differnew file mode 100644 index 0000000..066c1e9 --- /dev/null +++ b/brillo/timezone/Indian_Christmas_test.tzif diff --git a/brillo/timezone/Pacific_Fiji_test.tzif b/brillo/timezone/Pacific_Fiji_test.tzif Binary files differnew file mode 100644 index 0000000..76ae63e --- /dev/null +++ b/brillo/timezone/Pacific_Fiji_test.tzif diff --git a/brillo/timezone/tzif_parser.cc b/brillo/timezone/tzif_parser.cc new file mode 100644 index 0000000..5756196 --- /dev/null +++ b/brillo/timezone/tzif_parser.cc @@ -0,0 +1,164 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "brillo/timezone/tzif_parser.h" + +#include <arpa/inet.h> +#include <stdint.h> +#include <string.h> +#include <utility> +#include <vector> + +#include <base/files/file.h> +#include <base/files/file_path.h> +#include <base/logging.h> +#include <base/stl_util.h> +#include <base/strings/string_util.h> + +namespace { + +struct tzif_header { + char magic[4]; + char version; + char reserved[15]; + int32_t ttisgmtcnt; + int32_t ttisstdcnt; + int32_t leapcnt; + int32_t timecnt; + int32_t typecnt; + int32_t charcnt; +}; + +bool ReadInt(base::File* file, int32_t* out_int) { + DCHECK(out_int); + int32_t buf; + int read = file->ReadAtCurrentPos(reinterpret_cast<char*>(&buf), sizeof(buf)); + if (read != sizeof(buf)) { + return false; + } + // Values are stored in network byte order (highest-order byte first). + // We probably need to convert them to match the endianness of our system. + *out_int = ntohl(buf); + return true; +} + +bool ParseTzifHeader(base::File* tzfile, struct tzif_header* header) { + DCHECK(header); + int read = tzfile->ReadAtCurrentPos(header->magic, sizeof(header->magic)); + if (read != sizeof(header->magic)) { + return false; + } + if (memcmp(header->magic, "TZif", 4) != 0) { + return false; + } + + read = tzfile->ReadAtCurrentPos(&header->version, sizeof(header->version)); + if (read != sizeof(header->version)) { + return false; + } + if (header->version != '\0' && header->version != '2' && + header->version != '3') { + return false; + } + + read = tzfile->ReadAtCurrentPos(header->reserved, sizeof(header->reserved)); + if (read != sizeof(header->reserved)) { + return false; + } + for (size_t i = 0; i < sizeof(header->reserved); i++) { + if (header->reserved[i] != 0) { + return false; + } + } + + if (!ReadInt(tzfile, &header->ttisgmtcnt) || header->ttisgmtcnt < 0) { + return false; + } + if (!ReadInt(tzfile, &header->ttisstdcnt) || header->ttisstdcnt < 0) { + return false; + } + if (!ReadInt(tzfile, &header->leapcnt) || header->leapcnt < 0) { + return false; + } + if (!ReadInt(tzfile, &header->timecnt) || header->timecnt < 0) { + return false; + } + if (!ReadInt(tzfile, &header->typecnt) || header->typecnt < 0) { + return false; + } + if (!ReadInt(tzfile, &header->charcnt) || header->charcnt < 0) { + return false; + } + return true; +} + +} // namespace + +namespace brillo { + +namespace timezone { + +base::Optional<std::string> GetPosixTimezone(const base::FilePath& tzif_path) { + base::FilePath to_parse; + if (tzif_path.IsAbsolute()) { + to_parse = tzif_path; + } else { + to_parse = base::FilePath("/usr/share/zoneinfo").Append(tzif_path); + } + base::File tzfile(to_parse, base::File::FLAG_OPEN | base::File::FLAG_READ); + struct tzif_header first_header; + if (!tzfile.IsValid() || !ParseTzifHeader(&tzfile, &first_header)) { + return base::nullopt; + } + + if (first_header.version == '\0') { + // Generating a POSIX-style TZ string from a TZif version 1 file is hard; + // TZ strings need relative dates (1st Sunday in March, 1st Sunday in Nov, + // etc.), but the version 1 files only give absolute dates for each + // year up to 2037. Fortunately version 1 files are no longer created by + // the newer versions of the timezone-data package. + // + // Because of this, we're not going to try and handle this, and instead just + // return an error. + return base::nullopt; + } + + // TZif versions 2 and 3 embed a POSIX-style TZ string after their + // contents. We read that out and return it. + + // Skip past the first body section and all of the second section. See + // 'man tzfile' for an explanation of these offset values. + int64_t first_body_size = + 4 * first_header.timecnt + 1 * first_header.timecnt + + (4 + 1 + 1) * first_header.typecnt + 1 * first_header.charcnt + + (4 + 4) * first_header.leapcnt + 1 * first_header.ttisstdcnt + + 1 * first_header.ttisgmtcnt; + tzfile.Seek(base::File::FROM_CURRENT, first_body_size); + + struct tzif_header second_header; + if (!ParseTzifHeader(&tzfile, &second_header)) { + return base::nullopt; + } + + int64_t second_body_size = + 8 * second_header.timecnt + 1 * second_header.timecnt + + (4 + 1 + 1) * second_header.typecnt + 1 * second_header.charcnt + + (8 + 4) * second_header.leapcnt + 1 * second_header.ttisstdcnt + + 1 * second_header.ttisgmtcnt; + int64_t offset = tzfile.Seek(base::File::FROM_CURRENT, second_body_size); + + std::string time_string(tzfile.GetLength() - offset, '\0'); + if (tzfile.ReadAtCurrentPos(base::data(time_string), time_string.size()) != + time_string.size()) { + return base::nullopt; + } + + // According to the spec, the embedded string is enclosed by '\n' characters. + base::TrimWhitespaceASCII(time_string, base::TRIM_ALL, &time_string); + return std::move(time_string); +} + +} // namespace timezone + +} // namespace brillo diff --git a/brillo/timezone/tzif_parser.h b/brillo/timezone/tzif_parser.h new file mode 100644 index 0000000..058079c --- /dev/null +++ b/brillo/timezone/tzif_parser.h @@ -0,0 +1,29 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_TIMEZONE_TZIF_PARSER_H_ +#define LIBBRILLO_BRILLO_TIMEZONE_TZIF_PARSER_H_ + +#include <string> + +#include <base/files/file_path.h> +#include <base/optional.h> +#include <brillo/brillo_export.h> + +namespace brillo { + +namespace timezone { + +// GetPosixTimezone takes a path to a tzfile, and returns the POSIX timezone in +// a string. See 'man tzfile' for more info on the format. If |tzif_path| is a +// relative path, it will be appended to /usr/share/zoneinfo/, otherwise +// |tzif_path| as an absolute path will be used directly. +base::Optional<std::string> BRILLO_EXPORT GetPosixTimezone( + const base::FilePath& tzif_path); + +} // namespace timezone + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_TIMEZONE_TZIF_PARSER_H_ diff --git a/brillo/timezone/tzif_parser_test.cc b/brillo/timezone/tzif_parser_test.cc new file mode 100644 index 0000000..305da4d --- /dev/null +++ b/brillo/timezone/tzif_parser_test.cc @@ -0,0 +1,46 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <stdlib.h> + +#include <base/files/file_path.h> +#include <gtest/gtest.h> + +#include "brillo/timezone/tzif_parser.h" + +namespace brillo { + +namespace timezone { + +class TzifParserTest : public ::testing::Test { + public: + TzifParserTest() { + source_dir_ = + base::FilePath(getenv("SRC")).Append("brillo").Append("timezone"); + } + + protected: + base::FilePath source_dir_; +}; + +TEST_F(TzifParserTest, EST) { + auto posix_result = GetPosixTimezone(source_dir_.Append("EST_test.tzif")); + EXPECT_EQ(posix_result, "EST5"); +} + +TEST_F(TzifParserTest, TzifVersionTwo) { + auto posix_result = + GetPosixTimezone(source_dir_.Append("Indian_Christmas_test.tzif")); + EXPECT_EQ(posix_result, "<+07>-7"); +} + +TEST_F(TzifParserTest, TzifVersionThree) { + auto posix_result = + GetPosixTimezone(source_dir_.Append("Pacific_Fiji_test.tzif")); + EXPECT_EQ(posix_result, "<+12>-12<+13>,M11.1.0,M1.2.2/123"); +} + +} // namespace timezone + +} // namespace brillo diff --git a/brillo/type_list.h b/brillo/type_list.h new file mode 100644 index 0000000..b14ef1e --- /dev/null +++ b/brillo/type_list.h @@ -0,0 +1,69 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_TYPE_LIST_H_ +#define LIBBRILLO_BRILLO_TYPE_LIST_H_ + +#include <type_traits> + +namespace brillo { + +template <typename... Ts> +struct TypeList {}; + +namespace type_list { + +template <typename... Ts> +struct is_one_of { + static constexpr bool value = false; +}; + +template <typename T, typename Head, typename... Tail> +struct is_one_of<T, TypeList<Head, Tail...>> { + static constexpr bool value = + std::is_same<T, Head>::value || is_one_of<T, TypeList<Tail...>>::value; +}; + +} // namespace type_list + +// Enables a template if the type T is in the typelist Types. Since std::same is +// used to determine equivalence of types, cv-qualifiers (const and volatile) +// *are* important. Note that typedefs and type aliases do not define new types. +// +// Example: +// using ValidTypes = TypeList<int32_t, float>; +// +// template <typename T, typename = EnableIfIsOneOf<T, ValidTypes>> +// void f(){} +// +// using integer = int32_t; +// ... +// f<int32_t>(); // Fine. +// f<float>(); // Fine. +// f<integer>(); // Fine. +// f<const int32_t>(); // Error; no matching function for call to 'f'. +// f<uint32_t>(); // Error; no matching function for call to 'f'. +template <typename T, typename Types> +using EnableIfIsOneOf = + std::enable_if_t<type_list::is_one_of<T, Types>::value>; + +// Enables a template if the type T is in the typelist Types and T is an +// arithmetic type (some sort of int or floating-point number). +template <typename T, typename Types> +using EnableIfIsOneOfArithmetic = + std::enable_if_t<std::is_arithmetic<T>::value && + type_list::is_one_of<T, Types>::value, + int>; + +// Enables a template if the type T is in the typelist Types and T is not an +// arithmetic type (is void, nullptr_t, or a non-fundamental type). +template <typename T, typename Types> +using EnableIfIsOneOfNonArithmetic = + std::enable_if_t<!std::is_arithmetic<T>::value && + type_list::is_one_of<T, Types>::value, + int>; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_TYPE_LIST_H_ diff --git a/brillo/type_name_undecorate.cc b/brillo/type_name_undecorate.cc index b588170..b24a746 100644 --- a/brillo/type_name_undecorate.cc +++ b/brillo/type_name_undecorate.cc @@ -5,6 +5,8 @@ #include <brillo/type_name_undecorate.h> #include <cstring> +#include <map> +#include <string> #ifdef __GNUG__ #include <cstdlib> diff --git a/brillo/type_name_undecorate_unittest.cc b/brillo/type_name_undecorate_test.cc index 604c0fb..a41c6cd 100644 --- a/brillo/type_name_undecorate_unittest.cc +++ b/brillo/type_name_undecorate_test.cc @@ -4,6 +4,8 @@ #include <brillo/type_name_undecorate.h> +#include <map> + #include <brillo/variant_dictionary.h> #include <gtest/gtest.h> diff --git a/brillo/udev/OWNERS b/brillo/udev/OWNERS new file mode 100644 index 0000000..f426deb --- /dev/null +++ b/brillo/udev/OWNERS @@ -0,0 +1,3 @@ +amistry@chromium.org +ejcaruso@chromium.org +wbbradley@chromium.org diff --git a/brillo/udev/mock_udev.h b/brillo/udev/mock_udev.h new file mode 100644 index 0000000..8494bab --- /dev/null +++ b/brillo/udev/mock_udev.h @@ -0,0 +1,48 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_H_ +#define LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_H_ + +#include <memory> + +#include <brillo/brillo_export.h> +#include <brillo/udev/udev.h> +#include <brillo/udev/udev_device.h> +#include <brillo/udev/udev_enumerate.h> +#include <brillo/udev/udev_monitor.h> +#include <gmock/gmock.h> + +namespace brillo { + +class BRILLO_EXPORT MockUdev : public Udev { + public: + MockUdev() : Udev(nullptr) {} + ~MockUdev() override = default; + + MOCK_METHOD(std::unique_ptr<UdevDevice>, + CreateDeviceFromSysPath, + (const char*), + (override)); + MOCK_METHOD(std::unique_ptr<UdevDevice>, + CreateDeviceFromDeviceNumber, + (char, dev_t), + (override)); + MOCK_METHOD(std::unique_ptr<UdevDevice>, + CreateDeviceFromSubsystemSysName, + (const char*, const char*), + (override)); + MOCK_METHOD(std::unique_ptr<UdevEnumerate>, CreateEnumerate, (), (override)); + MOCK_METHOD(std::unique_ptr<UdevMonitor>, + CreateMonitorFromNetlink, + (const char*), + (override)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockUdev); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_H_ diff --git a/brillo/udev/mock_udev_device.h b/brillo/udev/mock_udev_device.h new file mode 100644 index 0000000..6e812d1 --- /dev/null +++ b/brillo/udev/mock_udev_device.h @@ -0,0 +1,68 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_DEVICE_H_ +#define LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_DEVICE_H_ + +#include <memory> + +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_device.h> +#include <gmock/gmock.h> + +namespace brillo { + +class BRILLO_EXPORT MockUdevDevice : public UdevDevice { + public: + MockUdevDevice() = default; + ~MockUdevDevice() override = default; + + MOCK_METHOD(std::unique_ptr<UdevDevice>, GetParent, (), (const, override)); + MOCK_METHOD(std::unique_ptr<UdevDevice>, + GetParentWithSubsystemDeviceType, + (const char*, const char*), + (const, override)); + MOCK_METHOD(bool, IsInitialized, (), (const, override)); + MOCK_METHOD(uint64_t, GetMicrosecondsSinceInitialized, (), (const, override)); + MOCK_METHOD(uint64_t, GetSequenceNumber, (), (const, override)); + MOCK_METHOD(const char*, GetDevicePath, (), (const, override)); + MOCK_METHOD(const char*, GetDeviceNode, (), (const, override)); + MOCK_METHOD(dev_t, GetDeviceNumber, (), (const, override)); + MOCK_METHOD(const char*, GetDeviceType, (), (const, override)); + MOCK_METHOD(const char*, GetDriver, (), (const, override)); + MOCK_METHOD(const char*, GetSubsystem, (), (const, override)); + MOCK_METHOD(const char*, GetSysPath, (), (const, override)); + MOCK_METHOD(const char*, GetSysName, (), (const, override)); + MOCK_METHOD(const char*, GetSysNumber, (), (const, override)); + MOCK_METHOD(const char*, GetAction, (), (const, override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetDeviceLinksListEntry, + (), + (const, override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetPropertiesListEntry, + (), + (const, override)); + MOCK_METHOD(const char*, GetPropertyValue, (const char*), (const, override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetTagsListEntry, + (), + (const, override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetSysAttributeListEntry, + (), + (const, override)); + MOCK_METHOD(const char*, + GetSysAttributeValue, + (const char*), + (const, override)); + MOCK_METHOD(std::unique_ptr<UdevDevice>, Clone, (), (override)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockUdevDevice); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_DEVICE_H_ diff --git a/brillo/udev/mock_udev_enumerate.h b/brillo/udev/mock_udev_enumerate.h new file mode 100644 index 0000000..faf94fc --- /dev/null +++ b/brillo/udev/mock_udev_enumerate.h @@ -0,0 +1,49 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_ENUMERATE_H_ +#define LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_ENUMERATE_H_ + +#include <memory> + +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_enumerate.h> +#include <gmock/gmock.h> + +namespace brillo { + +class BRILLO_EXPORT MockUdevEnumerate : public UdevEnumerate { + public: + MockUdevEnumerate() = default; + ~MockUdevEnumerate() override = default; + + MOCK_METHOD(bool, AddMatchSubsystem, (const char*), (override)); + MOCK_METHOD(bool, AddNoMatchSubsystem, (const char*), (override)); + MOCK_METHOD(bool, + AddMatchSysAttribute, + (const char*, const char*), + (override)); + MOCK_METHOD(bool, + AddNoMatchSysAttribute, + (const char*, const char*), + (override)); + MOCK_METHOD(bool, AddMatchProperty, (const char*, const char*), (override)); + MOCK_METHOD(bool, AddMatchSysName, (const char*), (override)); + MOCK_METHOD(bool, AddMatchTag, (const char*), (override)); + MOCK_METHOD(bool, AddMatchIsInitialized, (), (override)); + MOCK_METHOD(bool, AddSysPath, (const char*), (override)); + MOCK_METHOD(bool, ScanDevices, (), (override)); + MOCK_METHOD(bool, ScanSubsystems, (), (override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetListEntry, + (), + (const, override)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockUdevEnumerate); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_ENUMERATE_H_ diff --git a/brillo/udev/mock_udev_list_entry.h b/brillo/udev/mock_udev_list_entry.h new file mode 100644 index 0000000..255b6e2 --- /dev/null +++ b/brillo/udev/mock_udev_list_entry.h @@ -0,0 +1,35 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_LIST_ENTRY_H_ +#define LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_LIST_ENTRY_H_ + +#include <memory> + +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_list_entry.h> +#include <gmock/gmock.h> + +namespace brillo { + +class BRILLO_EXPORT MockUdevListEntry : public UdevListEntry { + public: + MockUdevListEntry() = default; + ~MockUdevListEntry() override = default; + + MOCK_METHOD(std::unique_ptr<UdevListEntry>, GetNext, (), (const, override)); + MOCK_METHOD(std::unique_ptr<UdevListEntry>, + GetByName, + (const char*), + (const, override)); + MOCK_METHOD(const char*, GetName, (), (const, override)); + MOCK_METHOD(const char*, GetValue, (), (const, override)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockUdevListEntry); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_LIST_ENTRY_H_ diff --git a/brillo/udev/mock_udev_monitor.h b/brillo/udev/mock_udev_monitor.h new file mode 100644 index 0000000..5854327 --- /dev/null +++ b/brillo/udev/mock_udev_monitor.h @@ -0,0 +1,38 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_MONITOR_H_ +#define LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_MONITOR_H_ + +#include <memory> + +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_monitor.h> +#include <gmock/gmock.h> + +namespace brillo { + +class BRILLO_EXPORT MockUdevMonitor : public UdevMonitor { + public: + MockUdevMonitor() = default; + ~MockUdevMonitor() override = default; + + MOCK_METHOD(bool, EnableReceiving, (), (override)); + MOCK_METHOD(int, GetFileDescriptor, (), (const, override)); + MOCK_METHOD(std::unique_ptr<UdevDevice>, ReceiveDevice, (), (override)); + MOCK_METHOD(bool, + FilterAddMatchSubsystemDeviceType, + (const char*, const char*), + (override)); + MOCK_METHOD(bool, FilterAddMatchTag, (const char*), (override)); + MOCK_METHOD(bool, FilterUpdate, (), (override)); + MOCK_METHOD(bool, FilterRemove, (), (override)); + + private: + DISALLOW_COPY_AND_ASSIGN(MockUdevMonitor); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_MOCK_UDEV_MONITOR_H_ diff --git a/brillo/udev/udev.cc b/brillo/udev/udev.cc new file mode 100644 index 0000000..78f9d72 --- /dev/null +++ b/brillo/udev/udev.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/udev/udev.h> + +#include <libudev.h> + +#include <base/format_macros.h> +#include <base/logging.h> +#include <base/strings/stringprintf.h> +#include <brillo/udev/udev_device.h> +#include <brillo/udev/udev_enumerate.h> +#include <brillo/udev/udev_monitor.h> + +using base::StringPrintf; + +namespace brillo { + +Udev::Udev(struct udev* udev) : udev_(udev) {} + +Udev::~Udev() { + if (udev_) { + udev_unref(udev_); + udev_ = nullptr; + } +} + +// static +std::unique_ptr<Udev> Udev::Create() { + struct udev* udev = udev_new(); + if (!udev) + return nullptr; + + return std::unique_ptr<Udev>(new Udev(udev)); +} + +// static +std::unique_ptr<UdevDevice> Udev::CreateDevice(udev_device* device) { + auto device_to_return = std::make_unique<UdevDevice>(device); + + // UdevDevice increases the reference count of the udev_device struct by one. + // Thus, decrease the reference count of the udev_device struct by one before + // returning UdevDevice. + udev_device_unref(device); + + return device_to_return; +} + +std::unique_ptr<UdevDevice> Udev::CreateDeviceFromSysPath( + const char* sys_path) { + udev_device* device = udev_device_new_from_syspath(udev_, sys_path); + if (device) + return CreateDevice(device); + + VLOG(2) << StringPrintf( + "udev_device_new_from_syspath" + "(%p, \"%s\") returned nullptr.", + udev_, sys_path); + return nullptr; +} + +std::unique_ptr<UdevDevice> Udev::CreateDeviceFromDeviceNumber( + char type, dev_t device_number) { + udev_device* device = udev_device_new_from_devnum(udev_, type, device_number); + if (device) + return CreateDevice(device); + + VLOG(2) << StringPrintf( + "udev_device_new_from_devnum" + "(%p, %d, %" PRIu64 ") returned nullptr.", + udev_, type, device_number); + return nullptr; +} + +std::unique_ptr<UdevDevice> Udev::CreateDeviceFromSubsystemSysName( + const char* subsystem, const char* sys_name) { + udev_device* device = + udev_device_new_from_subsystem_sysname(udev_, subsystem, sys_name); + if (device) + return CreateDevice(device); + + VLOG(2) << StringPrintf( + "udev_device_new_from_subsystem_sysname" + "(%p, \"%s\", \"%s\") returned nullptr.", + udev_, subsystem, sys_name); + return nullptr; +} + +std::unique_ptr<UdevEnumerate> Udev::CreateEnumerate() { + udev_enumerate* enumerate = udev_enumerate_new(udev_); + if (enumerate) { + auto enumerate_to_return = std::make_unique<UdevEnumerate>(enumerate); + + // UdevEnumerate increases the reference count of the udev_enumerate struct + // by one. Thus, decrease the reference count of the udev_enumerate struct + // by one before returning UdevEnumerate. + udev_enumerate_unref(enumerate); + + return enumerate_to_return; + } + + VLOG(2) << StringPrintf("udev_enumerate_new(%p) returned nullptr.", udev_); + return nullptr; +} + +std::unique_ptr<UdevMonitor> Udev::CreateMonitorFromNetlink(const char* name) { + udev_monitor* monitor = udev_monitor_new_from_netlink(udev_, name); + if (monitor) { + auto monitor_to_return = std::make_unique<UdevMonitor>(monitor); + + // UdevMonitor increases the reference count of the udev_monitor struct by + // one. Thus, decrease the reference count of the udev_monitor struct by one + // before returning UdevMonitor. + udev_monitor_unref(monitor); + + return monitor_to_return; + } + + VLOG(2) << StringPrintf( + "udev_monitor_new_from_netlink" + "(%p, \"%s\") returned nullptr.", + udev_, name); + return nullptr; +} + +} // namespace brillo diff --git a/brillo/udev/udev.h b/brillo/udev/udev.h new file mode 100644 index 0000000..b2c6c60 --- /dev/null +++ b/brillo/udev/udev.h @@ -0,0 +1,69 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_UDEV_H_ +#define LIBBRILLO_BRILLO_UDEV_UDEV_H_ + +#include <sys/types.h> + +#include <memory> + +#include <base/macros.h> +#include <brillo/brillo_export.h> + +struct udev; +struct udev_device; + +namespace brillo { + +class UdevDevice; +class UdevEnumerate; +class UdevMonitor; + +// A udev library context, which wraps a udev C struct from libudev and related +// library functions into a C++ object. +class BRILLO_EXPORT Udev { + public: + // Creates and initializes a Udev object. Returns nullptr on failure. + static std::unique_ptr<Udev> Create(); + virtual ~Udev(); + + // Wraps udev_device_new_from_syspath(). + virtual std::unique_ptr<UdevDevice> CreateDeviceFromSysPath( + const char* sys_path); + + // Wraps udev_device_new_from_devnum(). + virtual std::unique_ptr<UdevDevice> CreateDeviceFromDeviceNumber( + char type, dev_t device_number); + + // Wraps udev_device_new_from_subsystem_sysname(). + virtual std::unique_ptr<UdevDevice> CreateDeviceFromSubsystemSysName( + const char* subsystem, const char* sys_name); + + // Wraps udev_enumerate_new(). + virtual std::unique_ptr<UdevEnumerate> CreateEnumerate(); + + // Wraps udev_monitor_new_from_netlink(). + virtual std::unique_ptr<UdevMonitor> CreateMonitorFromNetlink( + const char* name); + + private: + friend class MockUdev; + + // Creates a Udev by taking ownership of the |udev|. + explicit Udev(struct udev* udev); + + // Creates a UdevDevice object that wraps a given udev_device struct pointed + // by |device|. The ownership of |device| is transferred to returned + // UdevDevice object. + static std::unique_ptr<UdevDevice> CreateDevice(udev_device* device); + + struct udev* udev_; + + DISALLOW_COPY_AND_ASSIGN(Udev); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_UDEV_H_ diff --git a/brillo/udev/udev_device.cc b/brillo/udev/udev_device.cc new file mode 100644 index 0000000..2251699 --- /dev/null +++ b/brillo/udev/udev_device.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/udev/udev_device.h> + +#include <libudev.h> + +#include <base/logging.h> + +namespace brillo { + +UdevDevice::UdevDevice() : device_(nullptr) {} + +UdevDevice::UdevDevice(udev_device* device) : device_(device) { + CHECK(device_); + + udev_device_ref(device_); +} + +UdevDevice::~UdevDevice() { + if (device_) { + udev_device_unref(device_); + device_ = nullptr; + } +} + +std::unique_ptr<UdevDevice> UdevDevice::GetParent() const { + // udev_device_get_parent does not increase the reference count of the + // returned udev_device struct. + udev_device* parent_device = udev_device_get_parent(device_); + return parent_device ? std::make_unique<UdevDevice>(parent_device) : nullptr; +} + +std::unique_ptr<UdevDevice> UdevDevice::GetParentWithSubsystemDeviceType( + const char* subsystem, const char* device_type) const { + // udev_device_get_parent_with_subsystem_devtype does not increase the + // reference count of the returned udev_device struct. + udev_device* parent_device = udev_device_get_parent_with_subsystem_devtype( + device_, subsystem, device_type); + return parent_device ? std::make_unique<UdevDevice>(parent_device) : nullptr; +} + +bool UdevDevice::IsInitialized() const { + return udev_device_get_is_initialized(device_); +} + +uint64_t UdevDevice::GetMicrosecondsSinceInitialized() const { + return udev_device_get_usec_since_initialized(device_); +} + +uint64_t UdevDevice::GetSequenceNumber() const { + return udev_device_get_seqnum(device_); +} + +const char* UdevDevice::GetDevicePath() const { + return udev_device_get_devpath(device_); +} + +const char* UdevDevice::GetDeviceNode() const { + return udev_device_get_devnode(device_); +} + +dev_t UdevDevice::GetDeviceNumber() const { + return udev_device_get_devnum(device_); +} + +const char* UdevDevice::GetDeviceType() const { + return udev_device_get_devtype(device_); +} + +const char* UdevDevice::GetDriver() const { + return udev_device_get_driver(device_); +} + +const char* UdevDevice::GetSubsystem() const { + return udev_device_get_subsystem(device_); +} + +const char* UdevDevice::GetSysPath() const { + return udev_device_get_syspath(device_); +} + +const char* UdevDevice::GetSysName() const { + return udev_device_get_sysname(device_); +} + +const char* UdevDevice::GetSysNumber() const { + return udev_device_get_sysnum(device_); +} + +const char* UdevDevice::GetAction() const { + return udev_device_get_action(device_); +} + +std::unique_ptr<UdevListEntry> UdevDevice::GetDeviceLinksListEntry() const { + udev_list_entry* list_entry = udev_device_get_devlinks_list_entry(device_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +std::unique_ptr<UdevListEntry> UdevDevice::GetPropertiesListEntry() const { + udev_list_entry* list_entry = udev_device_get_properties_list_entry(device_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +const char* UdevDevice::GetPropertyValue(const char* key) const { + return udev_device_get_property_value(device_, key); +} + +std::unique_ptr<UdevListEntry> UdevDevice::GetTagsListEntry() const { + udev_list_entry* list_entry = udev_device_get_tags_list_entry(device_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +std::unique_ptr<UdevListEntry> UdevDevice::GetSysAttributeListEntry() const { + udev_list_entry* list_entry = udev_device_get_sysattr_list_entry(device_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +const char* UdevDevice::GetSysAttributeValue(const char* attribute) const { + return udev_device_get_sysattr_value(device_, attribute); +} + +std::unique_ptr<UdevDevice> UdevDevice::Clone() { + return std::make_unique<UdevDevice>(device_); +} + +} // namespace brillo diff --git a/brillo/udev/udev_device.h b/brillo/udev/udev_device.h new file mode 100644 index 0000000..2704a22 --- /dev/null +++ b/brillo/udev/udev_device.h @@ -0,0 +1,117 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_UDEV_DEVICE_H_ +#define LIBBRILLO_BRILLO_UDEV_UDEV_DEVICE_H_ + +#include <stdint.h> +#include <sys/types.h> + +#include <memory> + +#include <base/macros.h> +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_list_entry.h> + +struct udev_device; + +namespace brillo { + +// A udev device, which wraps a udev_device C struct from libudev and related +// library functions into a C++ object. +class BRILLO_EXPORT UdevDevice { + public: + // Constructs a UdevDevice object by taking a raw pointer to a udev_device + // struct as |device|. The ownership of |device| is not transferred, but its + // reference count is increased by one during the lifetime of this object. + explicit UdevDevice(udev_device* device); + + // Destructs this UdevDevice object and decreases the reference count of the + // underlying udev_device struct by one. + virtual ~UdevDevice(); + + // Wraps udev_device_get_parent(). + virtual std::unique_ptr<UdevDevice> GetParent() const; + + // Wraps udev_device_get_parent_with_subsystem_devtype(). + virtual std::unique_ptr<UdevDevice> GetParentWithSubsystemDeviceType( + const char* subsystem, const char* device_type) const; + + // Wraps udev_device_get_is_initialized(). + virtual bool IsInitialized() const; + + // Wraps udev_device_get_usec_since_initialized(). + virtual uint64_t GetMicrosecondsSinceInitialized() const; + + // Wraps udev_device_get_seqnum(). + virtual uint64_t GetSequenceNumber() const; + + // Wraps udev_device_get_devpath(). + virtual const char* GetDevicePath() const; + + // Wraps udev_device_get_devnode(). + virtual const char* GetDeviceNode() const; + + // Wraps udev_device_get_devnum(). + virtual dev_t GetDeviceNumber() const; + + // Wraps udev_device_get_devtype(). + virtual const char* GetDeviceType() const; + + // Wraps udev_device_get_driver(). + virtual const char* GetDriver() const; + + // Wraps udev_device_get_subsystem(). + virtual const char* GetSubsystem() const; + + // Wraps udev_device_get_syspath(). + virtual const char* GetSysPath() const; + + // Wraps udev_device_get_sysname(). + virtual const char* GetSysName() const; + + // Wraps udev_device_get_sysnum(). + virtual const char* GetSysNumber() const; + + // Wraps udev_device_get_action(). + virtual const char* GetAction() const; + + // Wraps udev_device_get_devlinks_list_entry(). + virtual std::unique_ptr<UdevListEntry> GetDeviceLinksListEntry() const; + + // Wraps udev_device_get_properties_list_entry(). + virtual std::unique_ptr<UdevListEntry> GetPropertiesListEntry() const; + + // Wraps udev_device_get_property_value(). + virtual const char* GetPropertyValue(const char* key) const; + + // Wraps udev_device_get_tags_list_entry(). + virtual std::unique_ptr<UdevListEntry> GetTagsListEntry() const; + + // Wraps udev_device_get_sysattr_list_entry(). + virtual std::unique_ptr<UdevListEntry> GetSysAttributeListEntry() const; + + // Wraps udev_device_get_sysattr_value(). + virtual const char* GetSysAttributeValue(const char* attribute) const; + + // Creates a copy of this UdevDevice pointing to the same underlying + // struct udev_device* (increasing its libudev reference count by 1). + virtual std::unique_ptr<UdevDevice> Clone(); + + private: + // Allows MockUdevDevice to invoke the private default constructor below. + friend class MockUdevDevice; + + // Constructs a UdevDevice object without referencing a udev_device struct, + // which is only allowed to be called by MockUdevDevice. + UdevDevice(); + + udev_device* device_; + + DISALLOW_COPY_AND_ASSIGN(UdevDevice); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_UDEV_DEVICE_H_ diff --git a/brillo/udev/udev_enumerate.cc b/brillo/udev/udev_enumerate.cc new file mode 100644 index 0000000..0ac59b9 --- /dev/null +++ b/brillo/udev/udev_enumerate.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/udev/udev_enumerate.h> + +#include <libudev.h> + +#include <base/logging.h> +#include <base/strings/stringprintf.h> +#include <brillo/udev/udev_device.h> + +using base::StringPrintf; + +namespace brillo { + +UdevEnumerate::UdevEnumerate() : enumerate_(nullptr) {} + +UdevEnumerate::UdevEnumerate(udev_enumerate* enumerate) + : enumerate_(enumerate) { + CHECK(enumerate_); + + udev_enumerate_ref(enumerate_); +} + +UdevEnumerate::~UdevEnumerate() { + if (enumerate_) { + udev_enumerate_unref(enumerate_); + enumerate_ = nullptr; + } +} + +bool UdevEnumerate::AddMatchSubsystem(const char* subsystem) { + int result = udev_enumerate_add_match_subsystem(enumerate_, subsystem); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_subsystem (%p, \"%s\") returned %d.", + enumerate_, subsystem, result); + return false; +} + +bool UdevEnumerate::AddNoMatchSubsystem(const char* subsystem) { + int result = udev_enumerate_add_nomatch_subsystem(enumerate_, subsystem); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_nomatch_subsystem (%p, \"%s\") returned %d.", + enumerate_, subsystem, result); + return false; +} + +bool UdevEnumerate::AddMatchSysAttribute(const char* attribute, + const char* value) { + int result = udev_enumerate_add_match_sysattr(enumerate_, attribute, value); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_sysattr (%p, \"%s\", \"%s\") returned %d.", + enumerate_, attribute, value, result); + return false; +} + +bool UdevEnumerate::AddNoMatchSysAttribute(const char* attribute, + const char* value) { + int result = udev_enumerate_add_nomatch_sysattr(enumerate_, attribute, value); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_nomatch_sysattr (%p, \"%s\", \"%s\") returned %d.", + enumerate_, attribute, value, result); + return false; +} + +bool UdevEnumerate::AddMatchProperty(const char* property, const char* value) { + int result = udev_enumerate_add_match_property(enumerate_, property, value); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_property (%p, \"%s\", \"%s\") returned %d.", + enumerate_, property, value, result); + return false; +} + +bool UdevEnumerate::AddMatchSysName(const char* sys_name) { + int result = udev_enumerate_add_match_sysname(enumerate_, sys_name); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_sysname (%p, \"%s\") returned %d.", enumerate_, + sys_name, result); + return false; +} + +bool UdevEnumerate::AddMatchTag(const char* tag) { + int result = udev_enumerate_add_match_tag(enumerate_, tag); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_tag (%p, \"%s\") returned %d.", enumerate_, tag, + result); + return false; +} + +bool UdevEnumerate::AddMatchIsInitialized() { + int result = udev_enumerate_add_match_is_initialized(enumerate_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_enumerate_add_match_is_initialized (%p) returned %d.", enumerate_, + result); + return false; +} + +bool UdevEnumerate::AddSysPath(const char* sys_path) { + int result = udev_enumerate_add_syspath(enumerate_, sys_path); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_enumerate_add_syspath(%p, \"%s\") returned %d.", + enumerate_, sys_path, result); + return false; +} + +bool UdevEnumerate::ScanDevices() { + int result = udev_enumerate_scan_devices(enumerate_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_enumerate_scan_devices(%p) returned %d.", + enumerate_, result); + return false; +} + +bool UdevEnumerate::ScanSubsystems() { + int result = udev_enumerate_scan_subsystems(enumerate_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_enumerate_scan_subsystems(%p) returned %d.", + enumerate_, result); + return false; +} + +std::unique_ptr<UdevListEntry> UdevEnumerate::GetListEntry() const { + udev_list_entry* list_entry = udev_enumerate_get_list_entry(enumerate_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +} // namespace brillo diff --git a/brillo/udev/udev_enumerate.h b/brillo/udev/udev_enumerate.h new file mode 100644 index 0000000..50a6183 --- /dev/null +++ b/brillo/udev/udev_enumerate.h @@ -0,0 +1,83 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_UDEV_ENUMERATE_H_ +#define LIBBRILLO_BRILLO_UDEV_UDEV_ENUMERATE_H_ + +#include <memory> + +#include <base/macros.h> +#include <brillo/brillo_export.h> +#include <brillo/udev/udev_list_entry.h> + +struct udev_enumerate; + +namespace brillo { + +// A udev enumerate class, which wraps a udev_enumerate C struct from libudev +// and related library functions into a C++ object. +class BRILLO_EXPORT UdevEnumerate { + public: + // Constructs a UdevEnumerate object by taking a raw pointer to a + // udev_enumerate struct as |enumerate|. The ownership of |enumerate| is not + // transferred, but its reference count is increased by one during the + // lifetime of this object. + explicit UdevEnumerate(udev_enumerate* enumerate); + + // Destructs this UdevEnumerate object and decreases the reference count of + // the underlying udev_enumerate struct by one. + virtual ~UdevEnumerate(); + + // Wraps udev_enumerate_add_match_subsystem(). Returns true on success. + virtual bool AddMatchSubsystem(const char* subsystem); + + // Wraps udev_enumerate_add_nomatch_subsystem(). Returns true on success. + virtual bool AddNoMatchSubsystem(const char* subsystem); + + // Wraps udev_enumerate_add_match_sysattr(). Returns true on success. + virtual bool AddMatchSysAttribute(const char* attribute, const char* value); + + // Wraps udev_enumerate_add_nomatch_sysattr(). Returns true on success. + virtual bool AddNoMatchSysAttribute(const char* attribute, const char* value); + + // Wraps udev_enumerate_add_match_property(). Returns true on success. + virtual bool AddMatchProperty(const char* property, const char* value); + + // Wraps udev_enumerate_add_match_sysname(). Returns true on success. + virtual bool AddMatchSysName(const char* sys_name); + + // Wraps udev_enumerate_add_match_tag(). Returns true on success. + virtual bool AddMatchTag(const char* tag); + + // Wraps udev_enumerate_add_match_is_initialized(). Returns true on success. + virtual bool AddMatchIsInitialized(); + + // Wraps udev_enumerate_add_syspath(). Returns true on success. + virtual bool AddSysPath(const char* sys_path); + + // Wraps udev_enumerate_scan_devices(). Returns true on success. + virtual bool ScanDevices(); + + // Wraps udev_enumerate_scan_subsystems(). Returns true on success. + virtual bool ScanSubsystems(); + + // Wraps udev_enumerate_get_list_entry(). + virtual std::unique_ptr<UdevListEntry> GetListEntry() const; + + private: + // Allows MockUdevEnumerate to invoke the private default constructor below. + friend class MockUdevEnumerate; + + // Constructs a UdevEnumerate object without referencing a udev_enumerate + // struct, which is only allowed to be called by MockUdevEnumerate. + UdevEnumerate(); + + udev_enumerate* enumerate_; + + DISALLOW_COPY_AND_ASSIGN(UdevEnumerate); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_UDEV_ENUMERATE_H_ diff --git a/brillo/udev/udev_list_entry.cc b/brillo/udev/udev_list_entry.cc new file mode 100644 index 0000000..739c435 --- /dev/null +++ b/brillo/udev/udev_list_entry.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/udev/udev_list_entry.h> + +#include <libudev.h> + +#include <base/logging.h> + +namespace brillo { + +UdevListEntry::UdevListEntry() : list_entry_(nullptr) {} + +UdevListEntry::UdevListEntry(udev_list_entry* list_entry) + : list_entry_(list_entry) { + CHECK(list_entry_); +} + +std::unique_ptr<UdevListEntry> UdevListEntry::GetNext() const { + udev_list_entry* list_entry = udev_list_entry_get_next(list_entry_); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +std::unique_ptr<UdevListEntry> UdevListEntry::GetByName( + const char* name) const { + udev_list_entry* list_entry = udev_list_entry_get_by_name(list_entry_, name); + return list_entry ? std::make_unique<UdevListEntry>(list_entry) : nullptr; +} + +const char* UdevListEntry::GetName() const { + return udev_list_entry_get_name(list_entry_); +} + +const char* UdevListEntry::GetValue() const { + return udev_list_entry_get_value(list_entry_); +} + +} // namespace brillo diff --git a/brillo/udev/udev_list_entry.h b/brillo/udev/udev_list_entry.h new file mode 100644 index 0000000..ee61d18 --- /dev/null +++ b/brillo/udev/udev_list_entry.h @@ -0,0 +1,55 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_UDEV_LIST_ENTRY_H_ +#define LIBBRILLO_BRILLO_UDEV_UDEV_LIST_ENTRY_H_ + +#include <memory> + +#include <base/macros.h> +#include <brillo/brillo_export.h> + +struct udev_list_entry; + +namespace brillo { + +// A udev list entry, which wraps a udev_list_entry C struct from libudev and +// related library functions into a C++ object. +class BRILLO_EXPORT UdevListEntry { + public: + // Constructs a UdevListEntry object by taking a raw pointer to a + // udev_list_entry struct as |list_entry|. The ownership of |list_entry| is + // not transferred, and thus it should outlive this object. + explicit UdevListEntry(udev_list_entry* list_entry); + + virtual ~UdevListEntry() = default; + + // Wraps udev_list_entry_get_next(). + virtual std::unique_ptr<UdevListEntry> GetNext() const; + + // Wraps udev_list_entry_get_by_name(). + virtual std::unique_ptr<UdevListEntry> GetByName(const char* name) const; + + // Wraps udev_list_entry_get_name(). + virtual const char* GetName() const; + + // Wraps udev_list_entry_get_value(). + virtual const char* GetValue() const; + + private: + // Allows MockUdevListEntry to invoke the private default constructor below. + friend class MockUdevListEntry; + + // Constructs a UdevListEntry object without referencing a udev_list_entry + // struct, which is only allowed to be called by MockUdevListEntry. + UdevListEntry(); + + udev_list_entry* const list_entry_; + + DISALLOW_COPY_AND_ASSIGN(UdevListEntry); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_UDEV_LIST_ENTRY_H_ diff --git a/brillo/udev/udev_monitor.cc b/brillo/udev/udev_monitor.cc new file mode 100644 index 0000000..c4b63e5 --- /dev/null +++ b/brillo/udev/udev_monitor.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <brillo/udev/udev_monitor.h> + +#include <libudev.h> + +#include <base/logging.h> +#include <base/strings/stringprintf.h> +#include <brillo/udev/udev_device.h> + +using base::StringPrintf; + +namespace brillo { + +UdevMonitor::UdevMonitor() : monitor_(nullptr) {} + +UdevMonitor::UdevMonitor(udev_monitor* monitor) : monitor_(monitor) { + CHECK(monitor_); + + udev_monitor_ref(monitor_); +} + +UdevMonitor::~UdevMonitor() { + if (monitor_) { + udev_monitor_unref(monitor_); + monitor_ = nullptr; + } +} + +bool UdevMonitor::EnableReceiving() { + int result = udev_monitor_enable_receiving(monitor_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_monitor_enable_receiving(%p) returned %d.", + monitor_, result); + return false; +} + +int UdevMonitor::GetFileDescriptor() const { + int file_descriptor = udev_monitor_get_fd(monitor_); + if (file_descriptor >= 0) + return file_descriptor; + + VLOG(2) << StringPrintf("udev_monitor_get_fd(%p) returned %d.", monitor_, + file_descriptor); + return kInvalidFileDescriptor; +} + +std::unique_ptr<UdevDevice> UdevMonitor::ReceiveDevice() { + udev_device* received_device = udev_monitor_receive_device(monitor_); + if (received_device) { + auto device = std::make_unique<UdevDevice>(received_device); + // udev_monitor_receive_device increases the reference count of the returned + // udev_device struct, while UdevDevice also holds a reference count of the + // udev_device struct. Thus, decrease the reference count of the udev_device + // struct. + udev_device_unref(received_device); + return device; + } + + VLOG(2) << StringPrintf("udev_monitor_receive_device(%p) returned nullptr.", + monitor_); + return nullptr; +} + +bool UdevMonitor::FilterAddMatchSubsystemDeviceType(const char* subsystem, + const char* device_type) { + int result = udev_monitor_filter_add_match_subsystem_devtype( + monitor_, subsystem, device_type); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_monitor_filter_add_match_subsystem_devtype (%p, \"%s\", \"%s\") " + "returned %d.", + monitor_, subsystem, device_type, result); + return false; +} + +bool UdevMonitor::FilterAddMatchTag(const char* tag) { + int result = udev_monitor_filter_add_match_tag(monitor_, tag); + if (result == 0) + return true; + + VLOG(2) << StringPrintf( + "udev_monitor_filter_add_tag (%p, \"%s\") returned %d.", monitor_, tag, + result); + return false; +} + +bool UdevMonitor::FilterUpdate() { + int result = udev_monitor_filter_update(monitor_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_monitor_filter_update(%p) returned %d.", + monitor_, result); + return false; +} + +bool UdevMonitor::FilterRemove() { + int result = udev_monitor_filter_remove(monitor_); + if (result == 0) + return true; + + VLOG(2) << StringPrintf("udev_monitor_filter_remove(%p) returned %d.", + monitor_, result); + return false; +} + +} // namespace brillo diff --git a/brillo/udev/udev_monitor.h b/brillo/udev/udev_monitor.h new file mode 100644 index 0000000..b9136f0 --- /dev/null +++ b/brillo/udev/udev_monitor.h @@ -0,0 +1,72 @@ +// Copyright (c) 2013 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef LIBBRILLO_BRILLO_UDEV_UDEV_MONITOR_H_ +#define LIBBRILLO_BRILLO_UDEV_UDEV_MONITOR_H_ + +#include <memory> + +#include <base/macros.h> +#include <brillo/brillo_export.h> + +struct udev_monitor; + +namespace brillo { + +class UdevDevice; + +// A udev monitor, which wraps a udev_monitor C struct from libudev and related +// library functions into a C++ object. +class BRILLO_EXPORT UdevMonitor { + public: + static const int kInvalidFileDescriptor = -1; + + // Constructs a UdevMonitor object by taking a raw pointer to a udev_monitor + // struct as |monitor|. The ownership of |monitor| is not transferred, but its + // reference count is increased by one during the lifetime of this object. + explicit UdevMonitor(udev_monitor* monitor); + + // Destructs this UdevMonitor object and decreases the reference count of the + // underlying udev_monitor struct by one. + virtual ~UdevMonitor(); + + // Wraps udev_monitor_enable_receiving(). Returns true on success. + virtual bool EnableReceiving(); + + // Wraps udev_monitor_get_fd(). + virtual int GetFileDescriptor() const; + + // Wraps udev_monitor_receive_device(). + virtual std::unique_ptr<UdevDevice> ReceiveDevice(); + + // Wraps udev_monitor_filter_add_match_subsystem_devtype(). Returns true on + // success. + virtual bool FilterAddMatchSubsystemDeviceType(const char* subsystem, + const char* device_type); + + // Wraps udev_monitor_filter_add_match_tag(). Returns true on success. + virtual bool FilterAddMatchTag(const char* tag); + + // Wraps udev_monitor_filter_update(). Returns true on success. + virtual bool FilterUpdate(); + + // Wraps udev_monitor_filter_remove(). Returns true on success. + virtual bool FilterRemove(); + + private: + // Allows MockUdevMonitor to invoke the private default constructor below. + friend class MockUdevMonitor; + + // Constructs a UdevMonitor object without referencing a udev_monitor struct, + // which is only allowed to be called by MockUdevMonitor. + UdevMonitor(); + + udev_monitor* monitor_; + + DISALLOW_COPY_AND_ASSIGN(UdevMonitor); +}; + +} // namespace brillo + +#endif // LIBBRILLO_BRILLO_UDEV_UDEV_MONITOR_H_ diff --git a/brillo/url_utils_unittest.cc b/brillo/url_utils_test.cc index a2603cb..a2603cb 100644 --- a/brillo/url_utils_unittest.cc +++ b/brillo/url_utils_test.cc diff --git a/brillo/userdb_utils.cc b/brillo/userdb_utils.cc index 55c964c..1308fb7 100644 --- a/brillo/userdb_utils.cc +++ b/brillo/userdb_utils.cc @@ -4,6 +4,7 @@ #include "brillo/userdb_utils.h" +#include <errno.h> #include <grp.h> #include <pwd.h> #include <sys/types.h> @@ -12,6 +13,7 @@ #include <vector> #include <base/logging.h> +#include <base/posix/safe_strerror.h> namespace brillo { namespace userdb { @@ -23,8 +25,16 @@ bool GetUserInfo(const std::string& user, uid_t* uid, gid_t* gid) { passwd pwd_buf; passwd* pwd = nullptr; std::vector<char> buf(buf_len); - if (getpwnam_r(user.c_str(), &pwd_buf, buf.data(), buf_len, &pwd) || !pwd) { - PLOG(ERROR) << "Unable to find user " << user; + + int err_num; + do { + err_num = getpwnam_r(user.c_str(), &pwd_buf, buf.data(), buf_len, &pwd); + } while (err_num == EINTR); + + if (!pwd) { + LOG(ERROR) << "Unable to find user " << user << ": " + << (err_num ? base::safe_strerror(err_num) + : "No matching record"); return false; } @@ -42,8 +52,16 @@ bool GetGroupInfo(const std::string& group, gid_t* gid) { struct group grp_buf; struct group* grp = nullptr; std::vector<char> buf(buf_len); - if (getgrnam_r(group.c_str(), &grp_buf, buf.data(), buf_len, &grp) || !grp) { - PLOG(ERROR) << "Unable to find group " << group; + + int err_num; + do { + err_num = getgrnam_r(group.c_str(), &grp_buf, buf.data(), buf_len, &grp); + } while (err_num == EINTR); + + if (!grp) { + LOG(ERROR) << "Unable to find group " << group << ": " + << (err_num ? base::safe_strerror(err_num) + : "No matching record"); return false; } diff --git a/brillo/value_conversion.h b/brillo/value_conversion.h index b520a77..b0d61a1 100644 --- a/brillo/value_conversion.h +++ b/brillo/value_conversion.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef BRILLO_VALUE_CONVERSION_H_ -#define BRILLO_VALUE_CONVERSION_H_ +#ifndef LIBBRILLO_BRILLO_VALUE_CONVERSION_H_ +#define LIBBRILLO_BRILLO_VALUE_CONVERSION_H_ // This file provides a set of helper functions to convert between base::Value // and native types. Apart from handling standard types such as 'int' and @@ -24,6 +24,7 @@ #include <map> #include <memory> #include <string> +#include <utility> #include <vector> #include <base/values.h> @@ -73,7 +74,7 @@ bool FromValue(const base::Value& in_value, std::vector<T, Alloc>* out_value) { return false; out_value->clear(); out_value->reserve(list->GetSize()); - for (const auto& item : *list) { + for (const base::Value& item : *list) { T value{}; if (!FromValue(item, &value)) return false; @@ -134,4 +135,4 @@ std::unique_ptr<base::Value> ToValue( } // namespace brillo -#endif // BRILLO_VALUE_CONVERSION_H_ +#endif // LIBBRILLO_BRILLO_VALUE_CONVERSION_H_ diff --git a/brillo/value_conversion_unittest.cc b/brillo/value_conversion_test.cc index aa1be2a..fec4052 100644 --- a/brillo/value_conversion_unittest.cc +++ b/brillo/value_conversion_test.cc @@ -170,7 +170,7 @@ TEST(ValueConversionTest, FromValueVectorOfString) { TEST(ValueConversionTest, FromValueVectorOfVectors) { std::vector<std::vector<int>> actual; EXPECT_TRUE(FromValue(*ParseValue("[[1,2], [], [3]]"), &actual)); - EXPECT_EQ((std::vector<std::vector<int>>{{1,2}, {}, {3}}), actual); + EXPECT_EQ((std::vector<std::vector<int>>{{1, 2}, {}, {3}}), actual); EXPECT_TRUE(FromValue(*ParseValue("[]"), &actual)); EXPECT_TRUE(actual.empty()); diff --git a/brillo/variant_dictionary_unittest.cc b/brillo/variant_dictionary_test.cc index 73ead2c..73ead2c 100644 --- a/brillo/variant_dictionary_unittest.cc +++ b/brillo/variant_dictionary_test.cc diff --git a/gen_coverage_html.sh b/gen_coverage_html.sh deleted file mode 100755 index 9faf1a9..0000000 --- a/gen_coverage_html.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -# Copyright (c) 2009 The Chromium OS Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -set -ex - -scons debug=1 -c -scons debug=1 -lcov -d . --zerocounters -./unittests -lcov --base-directory . --directory . --capture --output-file app.info - -# some versions of genhtml support the --no-function-coverage argument, -# which we want. The problem w/ function coverage is that every template -# instantiation of a method counts as a different method, so if we -# instantiate a method twice, once for testing and once for prod, the method -# is tested, but it shows only 50% function coverage b/c it thinks we didn't -# test the prod version. - -genhtml --no-function-coverage -o html ./app.info || genhtml -o html ./app.info diff --git a/install_attributes/libinstallattributes.h b/install_attributes/libinstallattributes.h index b947156..2bcbf0f 100644 --- a/install_attributes/libinstallattributes.h +++ b/install_attributes/libinstallattributes.h @@ -53,7 +53,7 @@ class BRILLO_EXPORT InstallAttributesReader { // successful, too. bool initialized_ = false; -private: + private: // Try to load the verified install attributes from disk. This is expected to // fail when install attributes haven't yet been finalized (OOBE) or verified // (early in the boot sequence). @@ -63,4 +63,4 @@ private: std::string empty_string_; }; -#endif // LIBBRILLO_LIBINSTALLATTRIBUTES_H_ +#endif // LIBBRILLO_INSTALL_ATTRIBUTES_LIBINSTALLATTRIBUTES_H_ diff --git a/install_attributes/mock_install_attributes_reader.h b/install_attributes/mock_install_attributes_reader.h index 5ccee02..0d2adcd 100644 --- a/install_attributes/mock_install_attributes_reader.h +++ b/install_attributes/mock_install_attributes_reader.h @@ -5,6 +5,8 @@ #ifndef LIBBRILLO_INSTALL_ATTRIBUTES_MOCK_INSTALL_ATTRIBUTES_READER_H_ #define LIBBRILLO_INSTALL_ATTRIBUTES_MOCK_INSTALL_ATTRIBUTES_READER_H_ +#include <string> + #include "libinstallattributes.h" #include "bindings/install_attributes.pb.h" diff --git a/install_attributes/tests/libinstallattributes_unittest.cc b/install_attributes/tests/libinstallattributes_test.cc index 45ff827..686e565 100644 --- a/install_attributes/tests/libinstallattributes_unittest.cc +++ b/install_attributes/tests/libinstallattributes_test.cc @@ -76,8 +76,3 @@ TEST(InstallAttributesTest, NoProgressionFromEmptyToManaged) { ASSERT_TRUE(reader.IsLocked()); ASSERT_EQ(std::string(), reader.GetAttribute("enterprise.mode")); } - -int main(int argc, char* argv[]) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/libbrillo-395517.gypi b/libbrillo-395517.gypi deleted file mode 100644 index a846c70..0000000 --- a/libbrillo-395517.gypi +++ /dev/null @@ -1,8 +0,0 @@ -{ - 'variables': { - 'libbase_ver': 395517, - }, - 'includes': [ - '../libbrillo/libbrillo.gypi', - ], -} diff --git a/libbrillo-glib.pc.in b/libbrillo-glib.pc.in deleted file mode 100644 index cfd9fc8..0000000 --- a/libbrillo-glib.pc.in +++ /dev/null @@ -1,8 +0,0 @@ -bslot=@BSLOT@ - -Name: libbrillo-glib -Description: brillo glib wrapper library -Version: ${bslot} -Requires.private: @PRIVATE_PC@ -Libs: -lbrillo-glib-${bslot} - diff --git a/libbrillo-test.pc.in b/libbrillo-test.pc.in deleted file mode 100644 index 4fece7c..0000000 --- a/libbrillo-test.pc.in +++ /dev/null @@ -1,8 +0,0 @@ -bslot=@BSLOT@ - -Name: libbrillo-test -Description: brillo test library -Version: ${bslot} -# Because libbrillo-test is static, we have to depend directly on everything. -Requires: @PRIVATE_PC@ -Libs: -lbrillo-test-${bslot} diff --git a/libbrillo.gyp b/libbrillo.gyp deleted file mode 100644 index 5a2bbe4..0000000 --- a/libbrillo.gyp +++ /dev/null @@ -1,7 +0,0 @@ -{ - 'includes': [ - 'libbrillo-395517.gypi', - 'libinstallattributes.gypi', - 'libpolicy.gypi', - ] -} diff --git a/libbrillo.gypi b/libbrillo.gypi deleted file mode 100644 index 05d95e6..0000000 --- a/libbrillo.gypi +++ /dev/null @@ -1,465 +0,0 @@ -{ - 'target_defaults': { - 'variables': { - 'deps': [ - 'libchrome-<(libbase_ver)' - ], - }, - 'include_dirs': [ - '../libbrillo', - ], - 'defines': [ - 'USE_DBUS=<(USE_dbus)', - 'USE_RTTI_FOR_TYPE_TAGS', - ], - }, - 'targets': [ - { - 'target_name': 'libbrillo-<(libbase_ver)', - 'type': 'none', - 'dependencies': [ - 'libbrillo-core-<(libbase_ver)', - 'libbrillo-cryptohome-<(libbase_ver)', - 'libbrillo-http-<(libbase_ver)', - 'libbrillo-minijail-<(libbase_ver)', - 'libbrillo-streams-<(libbase_ver)', - 'libinstallattributes-<(libbase_ver)', - 'libpolicy-<(libbase_ver)', - ], - 'direct_dependent_settings': { - 'include_dirs': [ - '../libbrillo', - ], - }, - 'includes': ['../common-mk/deps.gypi'], - }, - { - 'target_name': 'libbrillo-core-<(libbase_ver)', - 'type': 'shared_library', - 'variables': { - 'exported_deps': [ - ], - 'conditions': [ - ['USE_dbus == 1', { - 'exported_deps': [ - 'dbus-1', - ], - }], - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'libraries': ['-lmodp_b64'], - 'sources': [ - 'brillo/asynchronous_signal_handler.cc', - 'brillo/backoff_entry.cc', - 'brillo/daemons/daemon.cc', - 'brillo/data_encoding.cc', - 'brillo/errors/error.cc', - 'brillo/errors/error_codes.cc', - 'brillo/file_utils.cc', - 'brillo/flag_helper.cc', - 'brillo/imageloader/manifest.cc', - 'brillo/key_value_store.cc', - 'brillo/message_loops/base_message_loop.cc', - 'brillo/message_loops/message_loop.cc', - 'brillo/message_loops/message_loop_utils.cc', - 'brillo/mime_utils.cc', - 'brillo/osrelease_reader.cc', - 'brillo/process.cc', - 'brillo/process_reaper.cc', - 'brillo/process_information.cc', - 'brillo/secure_blob.cc', - 'brillo/strings/string_utils.cc', - 'brillo/syslog_logging.cc', - 'brillo/type_name_undecorate.cc', - 'brillo/url_utils.cc', - 'brillo/userdb_utils.cc', - 'brillo/value_conversion.cc', - ], - 'conditions': [ - ['USE_dbus == 1', { - 'sources': [ - 'brillo/any.cc', - 'brillo/daemons/dbus_daemon.cc', - 'brillo/dbus/async_event_sequencer.cc', - 'brillo/dbus/data_serialization.cc', - 'brillo/dbus/dbus_connection.cc', - 'brillo/dbus/dbus_method_invoker.cc', - 'brillo/dbus/dbus_method_response.cc', - 'brillo/dbus/dbus_object.cc', - 'brillo/dbus/dbus_service_watcher.cc', - 'brillo/dbus/dbus_signal.cc', - 'brillo/dbus/exported_object_manager.cc', - 'brillo/dbus/exported_property_set.cc', - 'brillo/dbus/utils.cc', - ], - }], - ], - }, - { - 'target_name': 'libbrillo-http-<(libbase_ver)', - 'type': 'shared_library', - 'dependencies': [ - 'libbrillo-core-<(libbase_ver)', - 'libbrillo-streams-<(libbase_ver)', - ], - 'variables': { - 'exported_deps': [ - 'libcurl', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'sources': [ - 'brillo/http/curl_api.cc', - 'brillo/http/http_connection_curl.cc', - 'brillo/http/http_form_data.cc', - 'brillo/http/http_request.cc', - 'brillo/http/http_transport.cc', - 'brillo/http/http_transport_curl.cc', - 'brillo/http/http_utils.cc', - ], - 'conditions': [ - ['USE_dbus == 1', { - 'sources': [ - 'brillo/http/http_proxy.cc', - ], - }], - ], - }, - { - 'target_name': 'libbrillo-streams-<(libbase_ver)', - 'type': 'shared_library', - 'dependencies': [ - 'libbrillo-core-<(libbase_ver)', - ], - 'variables': { - 'exported_deps': [ - 'openssl', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'sources': [ - 'brillo/streams/file_stream.cc', - 'brillo/streams/input_stream_set.cc', - 'brillo/streams/memory_containers.cc', - 'brillo/streams/memory_stream.cc', - 'brillo/streams/openssl_stream_bio.cc', - 'brillo/streams/stream.cc', - 'brillo/streams/stream_errors.cc', - 'brillo/streams/stream_utils.cc', - 'brillo/streams/tls_stream.cc', - ], - }, - { - 'target_name': 'libbrillo-test-<(libbase_ver)', - 'type': 'static_library', - 'standalone_static_library': 1, - 'dependencies': [ - 'libbrillo-http-<(libbase_ver)', - ], - 'sources': [ - 'brillo/http/http_connection_fake.cc', - 'brillo/http/http_transport_fake.cc', - 'brillo/message_loops/fake_message_loop.cc', - 'brillo/streams/fake_stream.cc', - 'brillo/unittest_utils.cc', - ], - 'includes': ['../common-mk/deps.gypi'], - }, - { - 'target_name': 'libbrillo-cryptohome-<(libbase_ver)', - 'type': 'shared_library', - 'variables': { - 'exported_deps': [ - 'openssl', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'sources': [ - 'brillo/cryptohome.cc', - ], - }, - { - 'target_name': 'libbrillo-minijail-<(libbase_ver)', - 'type': 'shared_library', - 'variables': { - 'exported_deps': [ - 'libminijail', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'cflags': [ - '-fvisibility=default', - ], - 'sources': [ - 'brillo/minijail/minijail.cc', - ], - }, - { - 'target_name': 'libinstallattributes-<(libbase_ver)', - 'type': 'shared_library', - 'dependencies': [ - 'libinstallattributes-includes', - '../common-mk/external_dependencies.gyp:install_attributes-proto', - ], - 'variables': { - 'exported_deps': [ - 'protobuf-lite', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'sources': [ - 'install_attributes/libinstallattributes.cc', - ], - }, - { - 'target_name': 'libpolicy-<(libbase_ver)', - 'type': 'shared_library', - 'dependencies': [ - 'libinstallattributes-<(libbase_ver)', - 'libpolicy-includes', - '../common-mk/external_dependencies.gyp:policy-protos', - ], - 'variables': { - 'exported_deps': [ - 'openssl', - 'protobuf-lite', - ], - 'deps': ['<@(exported_deps)'], - }, - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'ldflags': [ - '-Wl,--version-script,<(platform2_root)/libbrillo/libpolicy.ver', - ], - 'sources': [ - 'policy/device_policy.cc', - 'policy/device_policy_impl.cc', - 'policy/policy_util.cc', - 'policy/resilient_policy_util.cc', - 'policy/libpolicy.cc', - ], - }, - { - 'target_name': 'libbrillo-glib-<(libbase_ver)', - 'type': 'shared_library', - 'dependencies': [ - 'libbrillo-<(libbase_ver)', - ], - 'variables': { - 'exported_deps': [ - 'glib-2.0', - 'gobject-2.0', - ], - 'conditions': [ - ['USE_dbus == 1', { - 'exported_deps': [ - 'dbus-1', - 'dbus-glib-1', - ], - }], - ], - 'deps': ['<@(exported_deps)'], - }, - 'cflags': [ - # glib uses the deprecated "register" attribute in some header files. - '-Wno-deprecated-register', - ], - 'all_dependent_settings': { - 'variables': { - 'deps': [ - '<@(exported_deps)', - ], - }, - }, - 'includes': ['../common-mk/deps.gypi'], - 'conditions': [ - ['USE_dbus == 1', { - 'sources': [ - 'brillo/glib/abstract_dbus_service.cc', - 'brillo/glib/dbus.cc', - ], - }], - ], - }, - ], - 'conditions': [ - ['USE_test == 1', { - 'targets': [ - { - 'target_name': 'libbrillo-<(libbase_ver)_unittests', - 'type': 'executable', - 'dependencies': [ - 'libbrillo-<(libbase_ver)', - 'libbrillo-glib-<(libbase_ver)', - 'libbrillo-test-<(libbase_ver)', - ], - 'variables': { - 'deps': [ - 'libchrome-test-<(libbase_ver)', - ], - 'proto_in_dir': 'brillo/dbus', - 'proto_out_dir': 'include/brillo/dbus', - }, - 'includes': [ - '../common-mk/common_test.gypi', - '../common-mk/protoc.gypi', - ], - 'cflags': [ - '-Wno-format-zero-length', - ], - 'conditions': [ - ['debug == 1', { - 'cflags': [ - '-fprofile-arcs', - '-ftest-coverage', - '-fno-inline', - ], - 'libraries': [ - '-lgcov', - ], - }], - ], - 'sources': [ - 'brillo/asynchronous_signal_handler_unittest.cc', - 'brillo/backoff_entry_unittest.cc', - 'brillo/data_encoding_unittest.cc', - 'brillo/enum_flags_unittest.cc', - 'brillo/errors/error_codes_unittest.cc', - 'brillo/errors/error_unittest.cc', - 'brillo/file_utils_unittest.cc', - 'brillo/flag_helper_unittest.cc', - 'brillo/glib/object_unittest.cc', - 'brillo/http/http_connection_curl_unittest.cc', - 'brillo/http/http_form_data_unittest.cc', - 'brillo/http/http_request_unittest.cc', - 'brillo/http/http_transport_curl_unittest.cc', - 'brillo/http/http_utils_unittest.cc', - 'brillo/imageloader/manifest_unittest.cc', - 'brillo/key_value_store_unittest.cc', - 'brillo/map_utils_unittest.cc', - 'brillo/message_loops/base_message_loop_unittest.cc', - 'brillo/message_loops/fake_message_loop_unittest.cc', - 'brillo/message_loops/message_loop_unittest.cc', - 'brillo/mime_utils_unittest.cc', - 'brillo/osrelease_reader_unittest.cc', - 'brillo/process_reaper_unittest.cc', - 'brillo/process_unittest.cc', - 'brillo/secure_blob_unittest.cc', - 'brillo/streams/fake_stream_unittest.cc', - 'brillo/streams/file_stream_unittest.cc', - 'brillo/streams/input_stream_set_unittest.cc', - 'brillo/streams/memory_containers_unittest.cc', - 'brillo/streams/memory_stream_unittest.cc', - 'brillo/streams/openssl_stream_bio_unittests.cc', - 'brillo/streams/stream_unittest.cc', - 'brillo/streams/stream_utils_unittest.cc', - 'brillo/strings/string_utils_unittest.cc', - 'brillo/unittest_utils.cc', - 'brillo/url_utils_unittest.cc', - 'brillo/value_conversion_unittest.cc', - 'testrunner.cc', - ], - 'conditions': [ - ['USE_dbus == 1', { - 'sources': [ - 'brillo/any_unittest.cc', - 'brillo/any_internal_impl_unittest.cc', - 'brillo/dbus/async_event_sequencer_unittest.cc', - 'brillo/dbus/data_serialization_unittest.cc', - 'brillo/dbus/dbus_method_invoker_unittest.cc', - 'brillo/dbus/dbus_object_unittest.cc', - 'brillo/dbus/dbus_param_reader_unittest.cc', - 'brillo/dbus/dbus_param_writer_unittest.cc', - 'brillo/dbus/dbus_signal_handler_unittest.cc', - 'brillo/dbus/exported_object_manager_unittest.cc', - 'brillo/dbus/exported_property_set_unittest.cc', - 'brillo/http/http_proxy_unittest.cc', - 'brillo/type_name_undecorate_unittest.cc', - 'brillo/variant_dictionary_unittest.cc', - '<(proto_in_dir)/test.proto', - ], - }], - ], - }, - { - 'target_name': 'libinstallattributes-<(libbase_ver)_unittests', - 'type': 'executable', - 'dependencies': [ - '../common-mk/external_dependencies.gyp:install_attributes-proto', - 'libinstallattributes-<(libbase_ver)', - ], - 'includes': ['../common-mk/common_test.gypi'], - 'sources': [ - 'install_attributes/tests/libinstallattributes_unittest.cc', - ] - }, - { - 'target_name': 'libpolicy-<(libbase_ver)_unittests', - 'type': 'executable', - 'dependencies': [ - '../common-mk/external_dependencies.gyp:install_attributes-proto', - '../common-mk/external_dependencies.gyp:policy-protos', - 'libinstallattributes-<(libbase_ver)', - 'libpolicy-<(libbase_ver)', - ], - 'includes': ['../common-mk/common_test.gypi'], - 'sources': [ - 'install_attributes/mock_install_attributes_reader.cc', - 'policy/tests/device_policy_impl_unittest.cc', - 'policy/tests/libpolicy_unittest.cc', - 'policy/tests/policy_util_unittest.cc', - 'policy/tests/resilient_policy_util_unittest.cc', - ] - }, - ], - }], - ], -} diff --git a/libbrillo.pc.in b/libbrillo.pc.in deleted file mode 100644 index a3a9e07..0000000 --- a/libbrillo.pc.in +++ /dev/null @@ -1,8 +0,0 @@ -bslot=@BSLOT@ - -Name: libbrillo -Description: brillo base library -Version: ${bslot} -Requires.private: @PRIVATE_PC@ -Cflags: -DUSE_RTTI_FOR_TYPE_TAGS -Libs: -lbrillo-${bslot} diff --git a/libinstallattributes.gypi b/libinstallattributes.gypi deleted file mode 100644 index e0c0014..0000000 --- a/libinstallattributes.gypi +++ /dev/null @@ -1,16 +0,0 @@ -{ - 'targets': [ - { - 'target_name': 'libinstallattributes-includes', - 'type': 'none', - 'copies': [ - { - 'destination': '<(SHARED_INTERMEDIATE_DIR)/include/install_attributes', - 'files': [ - 'install_attributes/libinstallattributes.h', - ], - }, - ], - }, - ], -} diff --git a/libpolicy.gypi b/libpolicy.gypi deleted file mode 100644 index b3a3d49..0000000 --- a/libpolicy.gypi +++ /dev/null @@ -1,22 +0,0 @@ -{ - 'targets': [ - { - 'target_name': 'libpolicy-includes', - 'type': 'none', - 'copies': [ - { - 'destination': '<(SHARED_INTERMEDIATE_DIR)/include/policy', - 'files': [ - 'policy/device_policy.h', - 'policy/device_policy_impl.h', - 'policy/libpolicy.h', - 'policy/mock_libpolicy.h', - 'policy/mock_device_policy.h', - 'policy/policy_util.h', - 'policy/resilient_policy_util.h', - ], - }, - ], - }, - ], -} diff --git a/platform2_preinstall.sh b/platform2_preinstall.sh deleted file mode 100755 index 448a31a..0000000 --- a/platform2_preinstall.sh +++ /dev/null @@ -1,52 +0,0 @@ -#!/bin/bash - -# Copyright (c) 2013 The Chromium OS Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -set -e - -OUT=$1 -shift -for v; do - # Extract all the libbrillo sublibs from 'dependencies' section of - # 'libbrillo-<(libbase_ver)' target in libbrillo.gypi and convert them - # into an array of "-lbrillo-<sublib>-<v>" flags. - sublibs=($(sed -n " - /'target_name': 'libbrillo-<(libbase_ver)'/,/target_name/ { - /dependencies/,/],/ { - /libbrillo/ { - s:[',]::g - s:<(libbase_ver):${v}:g - s:libbrillo:-lbrillo: - p - } - } - }" libbrillo.gypi)) - - echo "GROUP ( AS_NEEDED ( ${sublibs[@]} ) )" > "${OUT}"/lib/libbrillo-${v}.so - - deps=$(<"${OUT}"/gen/libbrillo-${v}-deps.txt) - pc="${OUT}"/lib/libbrillo-${v}.pc - - sed \ - -e "s/@BSLOT@/${v}/g" \ - -e "s/@PRIVATE_PC@/${deps}/g" \ - "libbrillo.pc.in" > "${pc}" - - deps_test=$(<"${OUT}"/gen/libbrillo-test-${v}-deps.txt) - deps_test+=" libbrillo-${v}" - sed \ - -e "s/@BSLOT@/${v}/g" \ - -e "s/@PRIVATE_PC@/${deps_test}/g" \ - "libbrillo-test.pc.in" > "${OUT}/lib/libbrillo-test-${v}.pc" - - - deps_glib=$(<"${OUT}"/gen/libbrillo-glib-${v}-deps.txt) - pc_glib="${OUT}"/lib/libbrillo-glib-${v}.pc - - sed \ - -e "s/@BSLOT@/${v}/g" \ - -e "s/@PRIVATE_PC@/${deps_glib}/g" \ - "libbrillo-glib.pc.in" > "${pc_glib}" -done diff --git a/policy/OWNERS b/policy/OWNERS new file mode 100644 index 0000000..74cea9e --- /dev/null +++ b/policy/OWNERS @@ -0,0 +1,7 @@ +emaxx@chromium.org +pmarko@chromium.org +poromov@chromium.org +rsorokin@chromium.org + +# TEAM: managed-devices@google.com +# COMPONENT: Enterprise>CloudPolicy diff --git a/policy/device_policy.h b/policy/device_policy.h index 5913d8c..29e1ed3 100644 --- a/policy/device_policy.h +++ b/policy/device_policy.h @@ -69,6 +69,10 @@ class DevicePolicy { // Returns true unless there is a policy on disk and loading it fails. virtual bool LoadPolicy() = 0; + // Returns true if OOBE has been completed and if the device has been enrolled + // as an enterprise or enterpriseAD device. + virtual bool IsEnterpriseEnrolled() const = 0; + // Writes the value of the DevicePolicyRefreshRate policy in |rate|. Returns // true on success. virtual bool GetPolicyRefreshRate(int* rate) const = 0; @@ -224,6 +228,18 @@ class DevicePolicy { virtual bool GetDeviceUpdateStagingSchedule( std::vector<DayPercentagePair>* staging_schedule_out) const = 0; + // Writes the value of the DeviceQuickFixBuildToken to + // |device_quick_fix_build_token|. + // Returns true if it has been written, or false if the policy was not set. + virtual bool GetDeviceQuickFixBuildToken( + std::string* device_quick_fix_build_token) const = 0; + + // Writes the value of the Directory API ID to |directory_api_id_out|. + // Returns true on success, false if the ID is not available (eg if the device + // is not enrolled). + virtual bool GetDeviceDirectoryApiId( + std::string* directory_api_id_out) const = 0; + private: // Verifies that the policy signature is correct. virtual bool VerifyPolicySignature() = 0; diff --git a/policy/device_policy_impl.cc b/policy/device_policy_impl.cc index 76b82a1..54ea1f9 100644 --- a/policy/device_policy_impl.cc +++ b/policy/device_policy_impl.cc @@ -5,6 +5,7 @@ #include "policy/device_policy_impl.h" #include <algorithm> +#include <map> #include <memory> #include <set> #include <string> @@ -15,6 +16,7 @@ #include <base/logging.h> #include <base/macros.h> #include <base/memory/ptr_util.h> +#include <base/stl_util.h> #include <base/time/time.h> #include <base/values.h> #include <openssl/evp.h> @@ -29,6 +31,12 @@ namespace em = enterprise_management; namespace policy { +// TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped. +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#define EVP_MD_CTX_new EVP_MD_CTX_create +#define EVP_MD_CTX_free EVP_MD_CTX_destroy +#endif + // Maximum value of RollbackAllowedMilestones policy. const int kMaxRollbackAllowedMilestones = 4; @@ -54,36 +62,34 @@ bool ReadPublicKeyFromFile(const base::FilePath& key_file, bool VerifySignature(const std::string& signed_data, const std::string& signature, const std::string& public_key) { - EVP_MD_CTX ctx; - EVP_MD_CTX_init(&ctx); + std::unique_ptr<EVP_MD_CTX, void (*)(EVP_MD_CTX *)> ctx(EVP_MD_CTX_new(), + EVP_MD_CTX_free); + if (!ctx) + return false; const EVP_MD* digest = EVP_sha1(); char* key = const_cast<char*>(public_key.data()); BIO* bio = BIO_new_mem_buf(key, public_key.length()); - if (!bio) { - EVP_MD_CTX_cleanup(&ctx); + if (!bio) return false; - } EVP_PKEY* public_key_ssl = d2i_PUBKEY_bio(bio, nullptr); if (!public_key_ssl) { BIO_free_all(bio); - EVP_MD_CTX_cleanup(&ctx); return false; } const unsigned char* sig = reinterpret_cast<const unsigned char*>(signature.data()); - int rv = EVP_VerifyInit_ex(&ctx, digest, nullptr); + int rv = EVP_VerifyInit_ex(ctx.get(), digest, nullptr); if (rv == 1) { - EVP_VerifyUpdate(&ctx, signed_data.data(), signed_data.length()); - rv = EVP_VerifyFinal(&ctx, sig, signature.length(), public_key_ssl); + EVP_VerifyUpdate(ctx.get(), signed_data.data(), signed_data.length()); + rv = EVP_VerifyFinal(ctx.get(), sig, signature.length(), public_key_ssl); } EVP_PKEY_free(public_key_ssl); BIO_free_all(bio); - EVP_MD_CTX_cleanup(&ctx); return rv == 1; } @@ -95,7 +101,7 @@ std::string DecodeConnectionType(int type) { "ethernet", "wifi", "wimax", "bluetooth", "cellular", }; - if (type < 0 || type >= static_cast<int>(arraysize(kConnectionTypes))) + if (type < 0 || type >= static_cast<int>(base::size(kConnectionTypes))) return std::string(); return kConnectionTypes[type]; @@ -196,6 +202,17 @@ bool DevicePolicyImpl::LoadPolicy() { return policy_loaded; } +bool DevicePolicyImpl::IsEnterpriseEnrolled() const { + DCHECK(install_attributes_reader_); + if (!install_attributes_reader_->IsLocked()) + return false; + + const std::string& device_mode = install_attributes_reader_->GetAttribute( + InstallAttributesReader::kAttrMode); + return device_mode == InstallAttributesReader::kDeviceModeEnterprise || + device_mode == InstallAttributesReader::kDeviceModeEnterpriseAD; +} + bool DevicePolicyImpl::GetPolicyRefreshRate(int* rate) const { if (!device_policy_.has_device_policy_refresh_rate()) return false; @@ -331,6 +348,9 @@ bool DevicePolicyImpl::GetReleaseChannelDelegated( } bool DevicePolicyImpl::GetUpdateDisabled(bool* update_disabled) const { + if (!IsEnterpriseEnrolled()) + return false; + if (!device_policy_.has_auto_update_settings()) return false; @@ -345,6 +365,9 @@ bool DevicePolicyImpl::GetUpdateDisabled(bool* update_disabled) const { bool DevicePolicyImpl::GetTargetVersionPrefix( std::string* target_version_prefix) const { + if (!IsEnterpriseEnrolled()) + return false; + if (!device_policy_.has_auto_update_settings()) return false; @@ -374,14 +397,7 @@ bool DevicePolicyImpl::GetRollbackToTargetVersion( bool DevicePolicyImpl::GetRollbackAllowedMilestones( int* rollback_allowed_milestones) const { // This policy can be only set for devices which are enterprise enrolled. - if (!install_attributes_reader_->IsLocked()) - return false; - if (install_attributes_reader_->GetAttribute( - InstallAttributesReader::kAttrMode) != - InstallAttributesReader::kDeviceModeEnterprise && - install_attributes_reader_->GetAttribute( - InstallAttributesReader::kAttrMode) != - InstallAttributesReader::kDeviceModeEnterpriseAD) + if (!IsEnterpriseEnrolled()) return false; if (device_policy_.has_auto_update_settings()) { @@ -398,8 +414,9 @@ bool DevicePolicyImpl::GetRollbackAllowedMilestones( } } // Policy is not present, use default for enterprise devices. - VLOG(1) << "RollbackAllowedMilestones policy is not set, using default 0."; - *rollback_allowed_milestones = 0; + VLOG(1) << "RollbackAllowedMilestones policy is not set, using default " + << kMaxRollbackAllowedMilestones << "."; + *rollback_allowed_milestones = kMaxRollbackAllowedMilestones; return true; } @@ -419,6 +436,9 @@ bool DevicePolicyImpl::GetScatterFactorInSeconds( bool DevicePolicyImpl::GetAllowedConnectionTypesForUpdate( std::set<std::string>* connection_types) const { + if (!IsEnterpriseEnrolled()) + return false; + if (!device_policy_.has_auto_update_settings()) return false; @@ -541,9 +561,9 @@ bool DevicePolicyImpl::GetDeviceUpdateStagingSchedule( if (!list_val) return false; - for (base::Value* const& pair_value : *list_val) { - base::DictionaryValue* day_percentage_pair; - if (!pair_value->GetAsDictionary(&day_percentage_pair)) + for (const auto& pair_value : *list_val) { + const base::DictionaryValue* day_percentage_pair; + if (!pair_value.GetAsDictionary(&day_percentage_pair)) return false; int days, percentage; if (!day_percentage_pair->GetInteger("days", &days) || @@ -616,6 +636,8 @@ bool DevicePolicyImpl::GetSecondFactorAuthenticationMode(int* mode_out) const { bool DevicePolicyImpl::GetDisallowedTimeIntervals( std::vector<WeeklyTimeInterval>* intervals_out) const { intervals_out->clear(); + if (!IsEnterpriseEnrolled()) + return false; if (!device_policy_.has_auto_update_settings()) { return false; @@ -633,14 +655,14 @@ bool DevicePolicyImpl::GetDisallowedTimeIntervals( if (!list_val) return false; - for (base::Value* const& interval_value : *list_val) { - base::DictionaryValue* interval_dict; - if (!interval_value->GetAsDictionary(&interval_dict)) { + for (const auto& interval_value : *list_val) { + const base::DictionaryValue* interval_dict; + if (!interval_value.GetAsDictionary(&interval_dict)) { LOG(ERROR) << "Invalid JSON string given. Interval is not a dict."; return false; } - base::DictionaryValue* start; - base::DictionaryValue* end; + const base::DictionaryValue* start; + const base::DictionaryValue* end; if (!interval_dict->GetDictionary("start", &start) || !interval_dict->GetDictionary("end", &end)) { LOG(ERROR) << "Interval is missing start/end."; @@ -659,6 +681,29 @@ bool DevicePolicyImpl::GetDisallowedTimeIntervals( return true; } +bool DevicePolicyImpl::GetDeviceQuickFixBuildToken( + std::string* device_quick_fix_build_token) const { + if (!IsEnterpriseEnrolled() || !device_policy_.has_auto_update_settings()) + return false; + + const em::AutoUpdateSettingsProto& proto = + device_policy_.auto_update_settings(); + if (!proto.has_device_quick_fix_build_token()) + return false; + + *device_quick_fix_build_token = proto.device_quick_fix_build_token(); + return true; +} + +bool DevicePolicyImpl::GetDeviceDirectoryApiId( + std::string* directory_api_id_out) const { + if (!policy_data_.has_directory_api_id()) + return false; + + *directory_api_id_out = policy_data_.directory_api_id(); + return true; +} + bool DevicePolicyImpl::VerifyPolicyFile(const base::FilePath& policy_path) { if (!verify_root_ownership_) { return true; diff --git a/policy/device_policy_impl.h b/policy/device_policy_impl.h index 6891312..47426df 100644 --- a/policy/device_policy_impl.h +++ b/policy/device_policy_impl.h @@ -40,6 +40,7 @@ class DevicePolicyImpl : public DevicePolicy { // DevicePolicy overrides: bool LoadPolicy() override; + bool IsEnterpriseEnrolled() const override; bool GetPolicyRefreshRate(int* rate) const override; bool GetUserWhitelist( std::vector<std::string>* user_whitelist) const override; @@ -83,6 +84,10 @@ class DevicePolicyImpl : public DevicePolicy { std::vector<WeeklyTimeInterval>* intervals_out) const override; bool GetDeviceUpdateStagingSchedule( std::vector<DayPercentagePair> *staging_schedule_out) const override; + bool GetDeviceQuickFixBuildToken( + std::string* device_quick_fix_build_token) const override; + bool GetDeviceDirectoryApiId( + std::string* device_directory_api_out) const override; // Methods that can be used only for testing. void set_policy_data_for_testing( diff --git a/policy/libpolicy.cc b/policy/libpolicy.cc index a0b7640..e972814 100644 --- a/policy/libpolicy.cc +++ b/policy/libpolicy.cc @@ -5,6 +5,7 @@ #include "policy/libpolicy.h" #include <memory> +#include <utility> #include <base/logging.h> diff --git a/policy/mock_device_policy.h b/policy/mock_device_policy.h index 90470e2..8bf4b07 100644 --- a/policy/mock_device_policy.h +++ b/policy/mock_device_policy.h @@ -52,62 +52,73 @@ class MockDevicePolicy : public DevicePolicy { } ~MockDevicePolicy() override = default; - MOCK_METHOD0(LoadPolicy, bool(void)); + MOCK_METHOD(bool, LoadPolicy, (), (override)); + MOCK_METHOD(bool, IsEnterpriseEnrolled, (), (const, override)); - MOCK_CONST_METHOD1(GetPolicyRefreshRate, - bool(int*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetUserWhitelist, bool(std::vector<std::string>*)); - MOCK_CONST_METHOD1(GetGuestModeEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetCameraEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetShowUserNames, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetDataRoamingEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetAllowNewUsers, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetMetricsEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetReportVersionInfo, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetReportActivityTimes, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetReportBootMode, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetEphemeralUsersEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetReleaseChannel, bool(std::string*)); - MOCK_CONST_METHOD1(GetReleaseChannelDelegated, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetUpdateDisabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetTargetVersionPrefix, bool(std::string*)); - MOCK_CONST_METHOD1(GetRollbackToTargetVersion, bool(int*)); - MOCK_CONST_METHOD1(GetRollbackAllowedMilestones, bool(int*)); - MOCK_CONST_METHOD1(GetScatterFactorInSeconds, - bool(int64_t*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetAllowedConnectionTypesForUpdate, - bool(std::set<std::string>*)); - MOCK_CONST_METHOD1(GetOpenNetworkConfiguration, bool(std::string*)); - MOCK_CONST_METHOD1(GetOwner, bool(std::string*)); - MOCK_CONST_METHOD1(GetHttpDownloadsEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetAuP2PEnabled, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetAllowKioskAppControlChromeVersion, - bool(bool*)); // NOLINT(readability/function) - MOCK_CONST_METHOD1(GetUsbDetachableWhitelist, - bool(std::vector<DevicePolicy::UsbDeviceId>*)); - MOCK_CONST_METHOD1(GetAutoLaunchedKioskAppId, bool(std::string*)); - MOCK_CONST_METHOD0(IsEnterpriseManaged, bool()); - MOCK_CONST_METHOD1(GetSecondFactorAuthenticationMode, bool(int*)); - MOCK_CONST_METHOD1(GetDisallowedTimeIntervals, - bool(std::vector<WeeklyTimeInterval>*)); - MOCK_CONST_METHOD1(GetDeviceUpdateStagingSchedule, - bool(std::vector<DayPercentagePair>*)); - MOCK_METHOD0(VerifyPolicyFiles, bool(void)); - MOCK_METHOD0(VerifyPolicySignature, bool(void)); + MOCK_METHOD(bool, GetPolicyRefreshRate, (int*), (const, override)); + MOCK_METHOD(bool, + GetUserWhitelist, + (std::vector<std::string>*), + (const, override)); + MOCK_METHOD(bool, GetGuestModeEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetCameraEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetShowUserNames, (bool*), (const, override)); + MOCK_METHOD(bool, GetDataRoamingEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetAllowNewUsers, (bool*), (const, override)); + MOCK_METHOD(bool, GetMetricsEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetReportVersionInfo, (bool*), (const, override)); + MOCK_METHOD(bool, GetReportActivityTimes, (bool*), (const, override)); + MOCK_METHOD(bool, GetReportBootMode, (bool*), (const, override)); + MOCK_METHOD(bool, GetEphemeralUsersEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetReleaseChannel, (std::string*), (const, override)); + MOCK_METHOD(bool, GetReleaseChannelDelegated, (bool*), (const, override)); + MOCK_METHOD(bool, GetUpdateDisabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetTargetVersionPrefix, (std::string*), (const, override)); + MOCK_METHOD(bool, GetRollbackToTargetVersion, (int*), (const, override)); + MOCK_METHOD(bool, GetRollbackAllowedMilestones, (int*), (const, override)); + MOCK_METHOD(bool, GetScatterFactorInSeconds, (int64_t*), (const, override)); + MOCK_METHOD(bool, + GetAllowedConnectionTypesForUpdate, + (std::set<std::string>*), + (const, override)); + MOCK_METHOD(bool, + GetOpenNetworkConfiguration, + (std::string*), + (const, override)); + MOCK_METHOD(bool, GetOwner, (std::string*), (const, override)); + MOCK_METHOD(bool, GetHttpDownloadsEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, GetAuP2PEnabled, (bool*), (const, override)); + MOCK_METHOD(bool, + GetAllowKioskAppControlChromeVersion, + (bool*), + (const, override)); + MOCK_METHOD(bool, + GetUsbDetachableWhitelist, + (std::vector<DevicePolicy::UsbDeviceId>*), + (const, override)); + MOCK_METHOD(bool, + GetAutoLaunchedKioskAppId, + (std::string*), + (const, override)); + MOCK_METHOD(bool, IsEnterpriseManaged, (), (const, override)); + MOCK_METHOD(bool, + GetSecondFactorAuthenticationMode, + (int*), + (const, override)); + MOCK_METHOD(bool, + GetDisallowedTimeIntervals, + (std::vector<WeeklyTimeInterval>*), + (const, override)); + MOCK_METHOD(bool, + GetDeviceUpdateStagingSchedule, + (std::vector<DayPercentagePair>*), + (const, override)); + MOCK_METHOD(bool, + GetDeviceQuickFixBuildToken, + (std::string*), + (const, override)); + MOCK_METHOD(bool, GetDeviceDirectoryApiId, (std::string*), (const, override)); + MOCK_METHOD(bool, VerifyPolicySignature, (), (override)); }; } // namespace policy diff --git a/policy/mock_libpolicy.h b/policy/mock_libpolicy.h index a0f6920..a04af7b 100644 --- a/policy/mock_libpolicy.h +++ b/policy/mock_libpolicy.h @@ -20,10 +20,10 @@ class MockPolicyProvider : public PolicyProvider { MockPolicyProvider() = default; ~MockPolicyProvider() override = default; - MOCK_METHOD0(Reload, bool(void)); - MOCK_CONST_METHOD0(device_policy_is_loaded, bool(void)); - MOCK_CONST_METHOD0(GetDevicePolicy, const DevicePolicy&(void)); - MOCK_CONST_METHOD0(IsConsumerDevice, bool(void)); + MOCK_METHOD(bool, Reload, (), (override)); + MOCK_METHOD(bool, device_policy_is_loaded, (), (const, override)); + MOCK_METHOD(const DevicePolicy&, GetDevicePolicy, (), (const, override)); + MOCK_METHOD(bool, IsConsumerDevice, (), (const, override)); private: DISALLOW_COPY_AND_ASSIGN(MockPolicyProvider); diff --git a/policy/tests/device_policy_impl_unittest.cc b/policy/tests/device_policy_impl_test.cc index 37c3916..2e68eb7 100644 --- a/policy/tests/device_policy_impl_unittest.cc +++ b/policy/tests/device_policy_impl_test.cc @@ -22,8 +22,8 @@ class DevicePolicyImplTest : public testing::Test, public DevicePolicyImpl { const em::ChromeDeviceSettingsProto& proto) { device_policy_.set_policy_for_testing(proto); device_policy_.set_install_attributes_for_testing( - std::make_unique<MockInstallAttributesReader>( - device_mode, true /* initialized */)); + std::make_unique<MockInstallAttributesReader>(device_mode, + true /* initialized */)); } DevicePolicyImpl device_policy_; @@ -108,7 +108,7 @@ TEST_F(DevicePolicyImplTest, GetRollbackAllowedMilestones_NotSet) { int value = -1; ASSERT_TRUE(device_policy_.GetRollbackAllowedMilestones(&value)); - EXPECT_EQ(0, value); + EXPECT_EQ(4, value); } // RollbackAllowedMilestones is set to a valid value. @@ -183,7 +183,7 @@ TEST_F(DevicePolicyImplTest, GetRollbackAllowedMilestones_SetTooSmall) { // Update staging schedule has no values TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_NoValues) { em::ChromeDeviceSettingsProto device_policy_proto; - em::AutoUpdateSettingsProto *auto_update_settings = + em::AutoUpdateSettingsProto* auto_update_settings = device_policy_proto.mutable_auto_update_settings(); auto_update_settings->set_staging_schedule("[]"); InitializePolicy(InstallAttributesReader::kDeviceModeEnterprise, @@ -197,7 +197,7 @@ TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_NoValues) { // Update staging schedule has valid values TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_Valid) { em::ChromeDeviceSettingsProto device_policy_proto; - em::AutoUpdateSettingsProto *auto_update_settings = + em::AutoUpdateSettingsProto* auto_update_settings = device_policy_proto.mutable_auto_update_settings(); auto_update_settings->set_staging_schedule( "[{\"days\": 4, \"percentage\": 40}, {\"days\": 10, \"percentage\": " @@ -214,7 +214,7 @@ TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_Valid) { // Update staging schedule has valid values, set using AD. TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_Valid_AD) { em::ChromeDeviceSettingsProto device_policy_proto; - em::AutoUpdateSettingsProto *auto_update_settings = + em::AutoUpdateSettingsProto* auto_update_settings = device_policy_proto.mutable_auto_update_settings(); auto_update_settings->set_staging_schedule( "[{\"days\": 4, \"percentage\": 40}, {\"days\": 10, \"percentage\": " @@ -233,7 +233,7 @@ TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_Valid_AD) { TEST_F(DevicePolicyImplTest, GetDeviceUpdateStagingSchedule_SetOutsideAllowable) { em::ChromeDeviceSettingsProto device_policy_proto; - em::AutoUpdateSettingsProto *auto_update_settings = + em::AutoUpdateSettingsProto* auto_update_settings = device_policy_proto.mutable_auto_update_settings(); auto_update_settings->set_staging_schedule( "[{\"days\": -1, \"percentage\": -10}, {\"days\": 30, \"percentage\": " @@ -243,8 +243,118 @@ TEST_F(DevicePolicyImplTest, std::vector<DayPercentagePair> staging_schedule; ASSERT_TRUE(device_policy_.GetDeviceUpdateStagingSchedule(&staging_schedule)); - EXPECT_THAT(staging_schedule, ElementsAre(DayPercentagePair{1, 0}, - DayPercentagePair{28, 100})); + EXPECT_THAT(staging_schedule, + ElementsAre(DayPercentagePair{1, 0}, DayPercentagePair{28, 100})); +} + +// Updates should only be disabled for enterprise managed devices. +TEST_F(DevicePolicyImplTest, GetUpdateDisabled_SetConsumer) { + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->set_update_disabled(true); + InitializePolicy(InstallAttributesReader::kDeviceModeConsumer, + device_policy_proto); + + bool value; + ASSERT_FALSE(device_policy_.GetUpdateDisabled(&value)); +} + +// Updates should only be pinned on enterprise managed devices. +TEST_F(DevicePolicyImplTest, GetTargetVersionPrefix_SetConsumer) { + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->set_target_version_prefix("hello"); + InitializePolicy(InstallAttributesReader::kDeviceModeConsumer, + device_policy_proto); + + std::string value = ""; + ASSERT_FALSE(device_policy_.GetTargetVersionPrefix(&value)); +} + +// The allowed connection types should only be changed in enterprise devices. +TEST_F(DevicePolicyImplTest, GetAllowedConnectionTypesForUpdate_SetConsumer) { + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->add_allowed_connection_types( + em::AutoUpdateSettingsProto::CONNECTION_TYPE_ETHERNET); + InitializePolicy(InstallAttributesReader::kDeviceModeConsumer, + device_policy_proto); + + std::set<std::string> value; + ASSERT_FALSE(device_policy_.GetAllowedConnectionTypesForUpdate(&value)); +} + +// Update time restrictions should only be used in enterprise devices. +TEST_F(DevicePolicyImplTest, GetDisallowedTimeIntervals_SetConsumer) { + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->set_disallowed_time_intervals( + "[{\"start\": {\"day_of_week\": \"Monday\", \"hours\": 10, \"minutes\": " + "0}, \"end\": {\"day_of_week\": \"Monday\", \"hours\": 10, \"minutes\": " + "0}}]"); + InitializePolicy(InstallAttributesReader::kDeviceModeConsumer, + device_policy_proto); + + std::vector<WeeklyTimeInterval> value; + ASSERT_FALSE(device_policy_.GetDisallowedTimeIntervals(&value)); +} + +// |DeviceQuickFixBuildToken| is set when device is enterprise enrolled. +TEST_F(DevicePolicyImplTest, GetDeviceQuickFixBuildToken_Set) { + const char kToken[] = "some_token"; + + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->set_device_quick_fix_build_token(kToken); + InitializePolicy(InstallAttributesReader::kDeviceModeEnterprise, + device_policy_proto); + std::string value; + EXPECT_TRUE(device_policy_.GetDeviceQuickFixBuildToken(&value)); + EXPECT_EQ(value, kToken); +} + +// If the device is not enterprise-enrolled, |GetDeviceQuickFixBuildToken| +// does not provide a token even if it is present in local device settings. +TEST_F(DevicePolicyImplTest, GetDeviceQuickFixBuildToken_NotSet) { + const char kToken[] = "some_token"; + + em::ChromeDeviceSettingsProto device_policy_proto; + em::AutoUpdateSettingsProto* auto_update_settings = + device_policy_proto.mutable_auto_update_settings(); + auto_update_settings->set_device_quick_fix_build_token(kToken); + InitializePolicy(InstallAttributesReader::kDeviceModeConsumer, + device_policy_proto); + std::string value; + EXPECT_FALSE(device_policy_.GetDeviceQuickFixBuildToken(&value)); + EXPECT_TRUE(value.empty()); +} + +// Should only write a value and return true if the ID is present. +TEST_F(DevicePolicyImplTest, GetDeviceDirectoryApiId_Set) { + constexpr char kDummyDeviceId[] = "aa-bb-cc-dd"; + + em::PolicyData policy_data; + policy_data.set_directory_api_id(kDummyDeviceId); + + device_policy_.set_policy_data_for_testing(policy_data); + + std::string id; + EXPECT_TRUE(device_policy_.GetDeviceDirectoryApiId(&id)); + EXPECT_EQ(kDummyDeviceId, id); +} + +TEST_F(DevicePolicyImplTest, GetDeviceDirectoryApiId_NotSet) { + em::PolicyData policy_data; + device_policy_.set_policy_data_for_testing(policy_data); + + std::string id; + EXPECT_FALSE(device_policy_.GetDeviceDirectoryApiId(&id)); + EXPECT_TRUE(id.empty()); } } // namespace policy diff --git a/policy/tests/libpolicy_unittest.cc b/policy/tests/libpolicy_test.cc index aaf497c..b8414bb 100644 --- a/policy/tests/libpolicy_unittest.cc +++ b/policy/tests/libpolicy_test.cc @@ -132,7 +132,7 @@ TEST(PolicyTest, DevicePolicyAllSetTest) { int_value = -1; ASSERT_TRUE(policy.GetRollbackToTargetVersion(&int_value)); EXPECT_EQ(enterprise_management::AutoUpdateSettingsProto:: - ROLLBACK_WITH_FULL_POWERWASH, + ROLLBACK_AND_POWERWASH, int_value); int_value = -1; @@ -243,10 +243,10 @@ TEST(PolicyTest, DevicePolicyNoneSetTest) { EXPECT_FALSE(policy.GetUpdateDisabled(&bool_value)); EXPECT_FALSE(policy.GetTargetVersionPrefix(&string_value)); EXPECT_FALSE(policy.GetRollbackToTargetVersion(&int_value)); - // RollbackAllowedMilestones has the default value of 0 for enterprise + // RollbackAllowedMilestones has the default value of 4 for enterprise // devices. ASSERT_TRUE(policy.GetRollbackAllowedMilestones(&int_value)); - EXPECT_EQ(0, int_value); + EXPECT_EQ(4, int_value); EXPECT_FALSE(policy.GetScatterFactorInSeconds(&int64_value)); EXPECT_FALSE(policy.GetOpenNetworkConfiguration(&string_value)); EXPECT_FALSE(policy.GetHttpDownloadsEnabled(&bool_value)); @@ -358,8 +358,3 @@ TEST(PolicyTest, IsConsumerDeviceEnterpriseAd) { } } // namespace policy - -int main(int argc, char* argv[]) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/policy/tests/policy_util_unittest.cc b/policy/tests/policy_util_test.cc index f26622f..f26622f 100644 --- a/policy/tests/policy_util_unittest.cc +++ b/policy/tests/policy_util_test.cc diff --git a/policy/tests/resilient_policy_util_unittest.cc b/policy/tests/resilient_policy_util_test.cc index 0963b08..0963b08 100644 --- a/policy/tests/resilient_policy_util_unittest.cc +++ b/policy/tests/resilient_policy_util_test.cc |