diff options
author | Colin Cross <ccross@android.com> | 2021-08-26 21:29:35 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2021-08-26 21:29:35 +0000 |
commit | 8c1d6562996fd78350510c66e282dcb50caa1655 (patch) | |
tree | 08f815bfcdaed1c2137de2d6966e42a8870dedb0 | |
parent | 8beec73fbbcaaab6785c3f9cc5a9661afcbc7fa7 (diff) | |
parent | 6b563cff6fca9ffdd14a12d33795fef8e2562a26 (diff) | |
download | openscreen-8c1d6562996fd78350510c66e282dcb50caa1655.tar.gz |
Upgrade openscreen to f54d92523c9f2c8c5afb99e05fed70e4b8772b1c am: 6b563cff6fandroid-s-v2-preview-2android-s-v2-preview-1android-s-v2-beta-2android-s-v2-preview-1
Original change: https://android-review.googlesource.com/c/platform/external/openscreen/+/1810936
Change-Id: I4af446c4bf4a92bc6fe1c6da2ac036a530aee809
320 files changed, 9412 insertions, 10684 deletions
@@ -1,2 +1,3 @@ # The location of the build configuration file. buildconfig = "//build/config/BUILDCONFIG.gn" +script_executable = "python3" @@ -16,7 +16,6 @@ group("gn_all") { "cast/sender:channel", "cast/streaming:receiver", "cast/streaming:sender", - "discovery:common", "discovery:dnssd", "discovery:mdns", "discovery:public", @@ -37,10 +36,6 @@ group("gn_all") { "osp/msgs", ] - if (use_mdns_responder) { - deps += [ "osp/impl/discovery/mdns:mdns_demo" ] - } - if (use_chromium_quic) { deps += [ "third_party/chromium_quic", @@ -49,7 +44,7 @@ group("gn_all") { ] } - if (use_chromium_quic && use_mdns_responder) { + if (use_chromium_quic) { deps += [ "osp:osp_demo" ] } } @@ -61,6 +56,10 @@ group("gn_all") { "third_party/protobuf:protoc($host_toolchain)", "third_party/zlib", ] + } else { + if (!is_mac) { + deps += [ "cast/cast_core/api" ] + } } } @@ -91,15 +90,6 @@ source_set("openscreen_unittests_all") { "osp:unittests", "osp/msgs:unittests", ] - - if (use_mdns_responder) { - public_deps += [ - "osp/impl/discovery/mdns:unittests", - - # Currently this target only includes mDNS tests. - "osp/impl/testing:unittests", - ] - } } } @@ -131,3 +121,16 @@ if (!build_with_chromium && is_posix) { ] } } + +if (!build_with_chromium) { + source_set("fuzzer_tests_all") { + testonly = true + deps = [ + "//cast/common:message_framer_fuzzer", + "//cast/streaming:compound_rtcp_parser_fuzzer", + "//cast/streaming:rtp_packet_parser_fuzzer", + "//cast/streaming:sender_report_parser_fuzzer", + "//discovery:mdns_fuzzer", + ] + } +} @@ -1,9 +1,8 @@ -# Primary reviewers mfoltz@chromium.org btolsch@chromium.org - -# Additional reviewers jopbha@chromium.org -miu@chromium.org -rwkeane@chromium.org -yakimakha@chromium.org +rwkeane@google.com +takumif@chromium.org + +# Add for LUCI configuration changes. +cliffordcheng@chromium.org @@ -12,6 +12,8 @@ use_relative_paths = True vars = { 'boringssl_git': 'https://boringssl.googlesource.com', 'chromium_git': 'https://chromium.googlesource.com', + 'quiche_git': 'https://quiche.googlesource.com', + 'aomedia_git': 'https://aomedia.googlesource.com', # NOTE: we should only reference GitHub directly for dependencies toggled # with the "not build_with_chromium" condition. @@ -29,6 +31,10 @@ vars = { # TODO(issuetracker.google.com/155195126): Change this to False and update # buildbot to call tools/download-clang-update-script.py instead. 'checkout_clang_coverage_tools': True, + + # GN CIPD package version. + 'gn_version': 'git_revision:39a87c0b36310bdf06b692c098f199a0d97fc810', + 'clang_format_revision': '99803d74e35962f63a775f29477882afd4d57d94', } deps = { @@ -36,24 +42,49 @@ deps = { # of the commits to the buildtools directory in the Chromium repository. This # should be regularly updated with the tip of the MIRRORED master branch, # found here: - # https://chromium.googlesource.com/chromium/src/buildtools/+/refs/heads/master. + # https://chromium.googlesource.com/chromium/src/buildtools/+/refs/heads/main. 'buildtools': { - 'url': Var('chromium_git')+ '/chromium/src/buildtools' + - '@' + '6302c1175607a436e18947a5abe9df2209e845fc', + 'url': Var('chromium_git') + '/chromium/src/buildtools' + + '@' + 'fba2905150c974240f14aa5334c3e5c93f873032', 'condition': 'not build_with_chromium', }, - + 'buildtools/clang_format/script': { + 'url': Var('chromium_git') + + '/external/github.com/llvm/llvm-project/clang/tools/clang-format.git' + + '@' + Var('clang_format_revision'), + 'condition': 'not build_with_chromium', + }, + 'buildtools/linux64': { + 'packages': [ + { + 'package': 'gn/gn/linux-amd64', + 'version': Var('gn_version'), + } + ], + 'dep_type': 'cipd', + 'condition': 'host_os == "linux" and not build_with_chromium', + }, + 'buildtools/mac': { + 'packages': [ + { + 'package': 'gn/gn/mac-${{arch}}', + 'version': Var('gn_version'), + } + ], + 'dep_type': 'cipd', + 'condition': 'host_os == "mac" and not build_with_chromium', + }, 'third_party/protobuf/src': { 'url': Var('chromium_git') + '/external/github.com/protocolbuffers/protobuf.git' + - '@' + '2514f0bd7da7e2af1bed4c5d1b84f031c4d12c10', # version 3.14 + '@' + '909a0f36a10075c4b4bc70fdee2c7e32dd612a72', # version 3.17.3 'condition': 'not build_with_chromium', }, 'third_party/libprotobuf-mutator/src': { 'url': Var('chromium_git') + '/external/github.com/google/libprotobuf-mutator.git' + - '@' + 'e5869dd9690c3f4dfb842fb90bd07a5a9ee32172', + '@' + '8942a9ba43d8bb196230c321d46d6a137957a719', 'condition': 'not build_with_chromium', }, @@ -78,14 +109,6 @@ deps = { 'condition': 'not build_with_chromium', }, - 'third_party/mDNSResponder/src': { - # NOTE: this fork of mDNSResponder is ancient (9 years old), but since - # we are moving away from mDNSResponder we will not be updating this. - 'url': Var('github') + '/jevinskie/mDNSResponder.git' + - '@' + '2942dde61f920fbbf96ff9a3840567ebbe7cb1b6', - 'condition': 'not build_with_chromium', - }, - # Note about updating BoringSSL: after changing this hash, run the update # script in BoringSSL's util folder for generating build files from the # <openscreen src-dir>/third_party/boringssl directory: @@ -98,7 +121,14 @@ deps = { 'third_party/chromium_quic/src': { 'url': Var('chromium_git') + '/openscreen/quic.git' + - '@' + '444faf6e3ae0dcade48438144f7e8ea2f8b3436d', + '@' + '79eec3fc28f5c4e1d06c6146825e31def6e3b793', + 'condition': 'not build_with_chromium', + }, + + # To roll forward, use quiche_revision from chromium/src/DEPS. + 'third_party/quiche/src': { + 'url': Var('quiche_git') + '/quiche.git' + + '@' + '51f584db29001036c20db3f72f09b00b875ae625', 'condition': 'not build_with_chromium', }, @@ -133,6 +163,13 @@ deps = { 'url': Var('github') + '/tristanpenman/valijson.git' + '@' + 'cf648930313655b19dc07ebae2f9c3fc37966a33', # Tip-of-tree 'condition': 'not build_with_chromium' + }, + + # Keep in sync with third_party/libaom/source/libaom in Chromium DEPS + 'third_party/aomedia/src': { + 'url': Var('aomedia_git') + '/aom.git' + + '@' + 'bb20160fbdd8226e7904541c8da70b91703e62b8', + 'condition': 'not build_with_chromium' } } @@ -197,6 +234,13 @@ include_rules = [ '+testing/util', '+third_party', + # Inter-module dependencies must be through public APIs. + '-discovery', + '+discovery/common', + '+discovery/dnssd/public', + '+discovery/mdns/public', + '+discovery/public', + # Don't include abseil from the root so the path can change via include_dirs # rules when in Chromium. '-third_party/abseil', @@ -9,11 +9,11 @@ third_party { type: GIT value: "https://chromium.googlesource.com/openscreen" } - version: "207f3b2b5814bbbe2530b3d0f8fb4da1665a02ce" + version: "f54d92523c9f2c8c5afb99e05fed70e4b8772b1c" license_type: NOTICE last_upgrade_date { year: 2021 - month: 4 - day: 1 + month: 8 + day: 26 } } diff --git a/PRESUBMIT.py b/PRESUBMIT.py index b4d7da3e..66f91fd7 100755 --- a/PRESUBMIT.py +++ b/PRESUBMIT.py @@ -15,21 +15,27 @@ sys.path.extend(os.path.join(_REPO_PATH, p) for p in _IMPORT_SUBFOLDERS) import licenses from checkdeps import DepsChecker -from cpp_checker import CppChecker -from rules import Rule + +# Opt-in to using Python3 instead of Python2, as part of the ongoing Python2 +# deprecation. For more information, see +# https://issuetracker.google.com/173766869. +USE_PYTHON3 = True # Rather than pass this to all of the checks, we override the global excluded # list with this one. _EXCLUDED_PATHS = ( # Exclude all of third_party/ except for BUILD.gns that we maintain. r'third_party[\\\/].*(?<!BUILD.gn)$', + # Exclude everything under third_party/chromium_quic/{src|build} r'third_party/chromium_quic/(src|build)/.*', + # Output directories (just in case) r'.*\bDebug[\\\/].*', r'.*\bRelease[\\\/].*', r'.*\bxcodebuild[\\\/].*', r'.*\bout[\\\/].*', + # There is no point in processing a patch file. r'.+\.diff$', r'.+\.patch$', @@ -125,7 +131,7 @@ def _CheckNoexceptOnMove(filename, clean_lines, linenum, error): def _CheckChangeLintsClean(input_api, output_api): """Checks that all '.cc' and '.h' files pass cpplint.py.""" cpplint = input_api.cpplint - # Access to a protected member _XX of a client class + # Directive that allows access to a protected member _XX of a client class. # pylint: disable=protected-access cpplint._cpplint_state.ResetErrorCounts() @@ -167,31 +173,35 @@ def _CommonChecks(input_api, output_api): input_api.canned_checks.CheckChangeHasNoCrAndHasOnlyOneEol( input_api, output_api)) - # Gender inclusivity + # Ensure code change is gender inclusive. results.extend( input_api.canned_checks.CheckGenderNeutral(input_api, output_api)) - # TODO(bug) format required + # Ensure code change to do items uses TODO(bug) or TODO(user) format. + # TODO(bug) is generally preferred. results.extend( input_api.canned_checks.CheckChangeTodoHasOwner(input_api, output_api)) - # Linter. + # Ensure code change passes linter cleanly. results.extend(_CheckChangeLintsClean(input_api, output_api)) - # clang-format + # Ensure code change has already had clang-format ran. results.extend( input_api.canned_checks.CheckPatchFormatted(input_api, output_api, bypass_warnings=False)) - # GN formatting + # Ensure code change has had GN formatting ran. results.extend( input_api.canned_checks.CheckGNFormatted(input_api, output_api)) - # buildtools/checkdeps + # Run buildtools/checkdeps on code change. results.extend(_CheckDeps(input_api, output_api)) - # tools/licenses + # Run tools/licenses on code change. + # TODO(https://crbug.com/1215335): licenses check is confused by our + # buildtools checkout that doesn't actually check out the libraries. + licenses.PRUNE_PATHS.add(os.path.join('buildtools', 'third_party')); results.extend(_CheckLicenses(input_api, output_api)) return results @@ -201,8 +211,9 @@ def CheckChangeOnUpload(input_api, output_api): input_api.DEFAULT_FILES_TO_SKIP = _EXCLUDED_PATHS # We always run the OnCommit checks, as well as some additional checks. results = CheckChangeOnCommit(input_api, output_api) - results.extend( - input_api.canned_checks.CheckChangedLUCIConfigs(input_api, output_api)) + # TODO(crbug.com/1220846): Open Screen needs a `main` config_set. + #results.extend( + # input_api.canned_checks.CheckChangedLUCIConfigs(input_api, output_api)) return results @@ -1,16 +1,16 @@ # Open Screen Library -The openscreen library implements the Open Screen Protocol and the Chromecast -protocols (both control and streaming). +The Open Screen Library implements the Open Screen Protocol and the Chromecast +protocols (discovery, application control, and media streaming). -Information about the protocol and its specification can be found [on -GitHub](https://github.com/webscreens/openscreenprotocol). +Information about the Open Screen Protocol and its specification can be found +[on GitHub](https://w3c.github.io/openscreenprotocol/). # Getting the code ## Installing depot_tools -openscreen library dependencies are managed using `gclient`, from the +Library dependencies are managed using `gclient`, from the [depot_tools](https://www.chromium.org/developers/how-tos/depottools) repo. To get gclient, run the following command in your terminal: @@ -21,8 +21,8 @@ To get gclient, run the following command in your terminal: Then add the `depot_tools` folder to your `PATH` environment variable. Note that openscreen does not use other features of `depot_tools` like `repo` or -`drover`. However, some `git-cl` functions *do* work, like `git cl try`, `git cl -lint` and `git cl upload.` +`drover`. However, some `git-cl` functions *do* work, like `git cl try`, +`git cl format`, `git cl lint`, and `git cl upload.` ## Checking out code @@ -44,7 +44,7 @@ and at their appropriate revisions. ## Syncing your local checkout -To update your local checkout from the openscreen master repository, just run +To update your local checkout from the openscreen reference repository, just run ```bash cd ~/my_project_dir/openscreen @@ -93,7 +93,8 @@ Setting the `gn` argument "is_gcc=true" on Linux enables building using gcc instead. ```bash - gn gen out/Default --args="is_gcc=true" + mkdir out/debug-gcc + gn gen out/debug-gcc --args="is_gcc=true" ``` Note that g++ version 7 or newer must be installed. On Debian flavors you can @@ -114,7 +115,7 @@ installed. Setting the `gn` argument "is_debug=true" enables debug build. ```bash - gn gen out/Default --args="is_debug=true" + gn gen out/debug --args="is_debug=true" ``` To install debug information for libstdc++ 8 on Debian flavors, you can run: @@ -129,30 +130,34 @@ Running `gn args` opens an editor that allows to create a list of arguments passed to every invocation of `gn gen`. ```bash - gn args out/Default + gn args out/debug ``` # Building targets -## Building the OSP demo +## Cast Streaming sender and receiver -The following commands will build a sample executable and run it. +TODO(jophba): Fill in details + +## OSP demo + +The following commands will build the Open Screen Protocol demo and run it. ``` bash - mkdir out/Default - gn gen out/Default # Creates the build directory and necessary ninja files - ninja -C out/Default demo # Builds the executable with ninja - ./out/Default/demo # Runs the executable + mkdir out/debug + gn gen out/debug # Creates the build directory and necessary ninja files + ninja -C out/debug osp_demo # Builds the executable with ninja + ./out/debug/osp_demo # Runs the executable ``` The `-C` argument to `ninja` works just like it does for GNU Make: it specifies the working directory for the build. So the same could be done as follows: ``` bash - ./gn gen out/Default - cd out/Default - ninja - ./demo + ./gn gen out/debug + cd out/debug + ninja osp_demo + ./osp_demo ``` After editing a file, only `ninja` needs to be rerun, not `gn`. If you have @@ -163,80 +168,41 @@ Unless you like to wait longer than necessary for builds to complete, run This will automatically parallelize the build for your system, depending on number of processor cores, RAM, etc. -For details on running `demo`, see its [README.md](demo/README.md). +For details on running `osp_demo`, see its [README.md](osp/demo/README.md). ## Building other targets -Running `ninja -C out/Default gn_all` will build all non-test targets in the +Running `ninja -C out/debug gn_all` will build all non-test targets in the repository. -`gn ls --type=executable out/Default/` will list all of the executable targets +`gn ls --type=executable out/debug` will list all of the executable targets that can be built. -If you want to customize the build further, you can run `gn args out/Default` to -pull up an editor for build flags. `gn args --list out/Default` prints all of +If you want to customize the build further, you can run `gn args out/debug` to +pull up an editor for build flags. `gn args --list out/debug` prints all of the build flags available. ## Building and running unit tests ```bash - ninja -C out/Default unittests - ./out/Default/unittests -``` - -## Building and running fuzzers - -In order to build fuzzers, you need the GN arg `use_libfuzzer=true`. It's also -recommended to build with `is_asan=true` to catch additional problems. Building -and running then might look like: -```bash - gn gen out/libfuzzer --args="use_libfuzzer=true is_asan=true is_debug=false" - ninja -C out/libfuzzer some_fuzz_target - out/libfuzzer/some_fuzz_target <args> <corpus_dir> [additional corpus dirs] + ninja -C out/debug openscreen_unittests + ./out/debug/openscreen_unittests ``` -The arguments to the fuzzer binary should be whatever is listed in the GN target -description (e.g. `-max_len=1500`). These arguments may be automatically -scraped by Chromium's ClusterFuzz tool when it runs fuzzers, but they are not -built into the target. You can also look at the file -`out/libfuzzer/some_fuzz_target.options` for what arguments should be used. The -`corpus_dir` is listed as `seed_corpus` in the GN definition of the fuzzer -target. - -# Continuous build and try jobs - -openscreen uses [LUCI builders](https://ci.chromium.org/p/openscreen/builders) -to monitor the build and test health of the library. Current builders include: - -| Name | Arch | OS | Toolchain | Build | Notes | -|------------------------|--------|--------------------|-----------|---------|------------------------| -| linux64_debug | x86-64 | Ubuntu Linux 16.04 | clang | debug | ASAN enabled | -| linux64_gcc_debug | x86-64 | Ubuntu Linux 18.04 | gcc-7 | debug | | -| linux64_tsan | x86-64 | Ubuntu Linux 16.04 | clang | release | TSAN enabled | -| mac_debug | x86-64 | Mac OS X/Xcode | clang | debug | | -| chromium_linux64_debug | x86-64 | Ubuntu Linux 16.04 | clang | debug | built within chromium | -| chromium_mac_debug | x86-64 | Mac OS X/Xcode | clang | debug | built within chromium | -| linux64_coverage_debug | x86-64 | Ubuntu Linux 16.04 | clang | debug | used for code coverage | - -You can run a patch through the try job queue (which tests it on all -non-chromium builders) using `git cl try`, or through Gerrit (details below). - -The chromium builders compile openscreen HEAD vs. chromium HEAD. They run as -experimental trybots and continuous-integration FYI bots. - -# Submitting changes +# Contributing changes -openscreen library code should follow the [Open Screen Library Style +Open Screen library code should follow the [Open Screen Library Style Guide](docs/style_guide.md). -openscreen uses [Chromium Gerrit](https://chromium-review.googlesource.com/) for -patch management and code review (for better or worse). +This library uses [Chromium Gerrit](https://chromium-review.googlesource.com/) for +patch management and code review (for better or worse). You will need to register +for an account at `chromium-review.googlesource.com` to upload patches for review. The following sections contain some tips about dealing with Gerrit for code reviews, specifically when pushing patches for review, getting patches reviewed, and committing patches. -## Uploading a patch for review +# Uploading a patch for review The `git cl` tool handles details of interacting with Gerrit (the Chromium code review tool) and is recommended for pushing patches for review. Once you have @@ -248,7 +214,7 @@ committed changes locally, simply run: ``` The first command will will auto-format the code changes. Then, the second -command runs the `PRESUBMIT.sh` script to check style and, if it passes, a +command runs the `PRESUBMIT.py` script to check style and, if it passes, a newcode review will be posted on `chromium-review.googlesource.com`. If you make additional commits to your local branch, then running `git cl @@ -282,82 +248,16 @@ Send your patch to one or more committers in the file for code review. All patches must receive at least one LGTM by a committer before it can be submitted. -## Submission +## Submitting patches After your patch has received one or more LGTM commit it by clicking the `SUBMIT` button (or, confusingly, `COMMIT QUEUE +2`) in Gerrit. This will run your patch through the builders again before committing to the main openscreen repository. -<!-- TODO(mfoltz): split up README.md into more manageable files. --> -## Working with ARM/ARM64/the Raspberry PI +# Additional resources -openscreen supports cross compilation for both arm32 and arm64 platforms, by -using the `gn args` parameter `target_cpu="arm"` or `target_cpu="arm64"` -respectively. Note that quotes are required around the target arch value. - -Setting an arm(64) target_cpu causes GN to pull down a sysroot from openscreen's -public cloud storage bucket. Google employees may update the sysroots stored -by requesting access to the Open Screen pantheon project and uploading a new -tar.xz to the openscreen-sysroots bucket. - -NOTE: The "arm" image is taken from Chromium's debian arm image, however it has -been manually patched to include support for libavcodec and libsdl2. To update -this image, the new image must be manually patched to include the necessary -header and library dependencies. Note that if the versions of libavcodec and -libsdl2 are too out of sync from the copies in the sysroot, compilation will -succeed, but you may experience issues decoding content. - -To install the last known good version of the libavcodec and libsdl packages -on a Raspberry Pi, you can run the following command: - -```bash -sudo ./cast/standalone_receiver/install_demo_deps_raspian.sh -``` - -NOTE: until [Issue 106](http://crbug.com/openscreen/106) is resolved, you may -experience issues streaming to a Raspberry Pi if multiple network interfaces -(e.g. WiFi + Ethernet) are enabled. The workaround is to disable either the WiFi -or ethernet connection. - -## Code Coverage - -Code coverage can be checked using clang's source-based coverage tools. You -must use the GN argument `use_coverage=true`. It's recommended to do this in a -separate output directory since the added instrumentation will affect -performance and generate an output file every time a binary is run. You can -read more about this in [clang's -documentation](http://clang.llvm.org/docs/SourceBasedCodeCoverage.html) but the -bare minimum steps are also outlined below. You will also need to download the -pre-built clang coverage tools, which are not downloaded by default. The -easiest way to do this is to set a custom variable in your `.gclient` file. -Under the "openscreen" solution, add: -```python - "custom_vars": { - "checkout_clang_coverage_tools": True, - }, -``` -then run `gclient runhooks`. You can also run the python command from the -`clang_coverage_tools` hook in `//DEPS` yourself or even download the tools -manually -([link](https://storage.googleapis.com/chromium-browser-clang-staging/)). - -Once you have your GN directory (we'll call it `out/coverage`) and have -downloaded the tools, do the following to generate an HTML coverage report: -```bash -out/coverage/openscreen_unittests -third_party/llvm-build/Release+Asserts/bin/llvm-profdata merge -sparse default.profraw -o foo.profdata -third_party/llvm-build/Release+Asserts/bin/llvm-cov show out/coverage/openscreen_unittests -instr-profile=foo.profdata -format=html -output-dir=<out dir> [filter paths] -``` -There are a few things to note here: - - `default.profraw` is generated by running the instrumented code, but - `foo.profdata` can be any path you want. - - `<out dir>` should be an empty directory for placing the generated HTML - files. You can view the report at `<out dir>/index.html`. - - `[filter paths]` is a list of paths to which you want to limit the coverage - report. For example, you may want to limit it to cast/ or even - cast/streaming/. If this list is empty, all data will be in the report. - -The same process can be used to check the coverage of a fuzzer's corpus. Just -add `-runs=0` to the fuzzer arguments to make sure it only runs the existing -corpus then exits. +* [Continuous builders](docs/continuous_build.md) +* [Building and running fuzz tests](docs/fuzzing.md) +* [Running on a Raspberry PI](docs/raspberry_pi.md) +* [Unit test code coverage](docs/code_coverage.md) diff --git a/build/code_coverage/merge_lib.py b/build/code_coverage/merge_lib.py index 4b956d06..ec951dc1 100644 --- a/build/code_coverage/merge_lib.py +++ b/build/code_coverage/merge_lib.py @@ -1,4 +1,4 @@ -#!/usr/bin/env/python +#!/usr/bin/env python3 # Copyright 2020 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. @@ -325,4 +325,3 @@ def get_shards_to_retry(bad_profiles): assert is_task_id(task_id) bad_shard_ids.add(task_id) return bad_shard_ids - diff --git a/build/code_coverage/merge_results.py b/build/code_coverage/merge_results.py index 67e63365..40bf7ca3 100644 --- a/build/code_coverage/merge_results.py +++ b/build/code_coverage/merge_results.py @@ -1,4 +1,4 @@ -#!/usr/bin/env/python +#!/usr/bin/env python3 # Copyright 2020 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/build/code_coverage/merge_steps.py b/build/code_coverage/merge_steps.py index f1140938..af876af9 100644 --- a/build/code_coverage/merge_steps.py +++ b/build/code_coverage/merge_steps.py @@ -1,4 +1,4 @@ -#!/usr/bin/env/python +#!/usr/bin/env python3 # Copyright 2020 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/build/config/data_headers_template.gni b/build/config/data_headers_template.gni index b50c4991..9a241694 100644 --- a/build/config/data_headers_template.gni +++ b/build/config/data_headers_template.gni @@ -6,6 +6,10 @@ # into C++ header files as constexpr char[] raw strings with variable names # taken directly from the original file name. +# The root directory must be defined outside of the template for use while +# embedded. +openscreen_root = rebase_path("../../", "//") + template("data_headers") { action_foreach(target_name) { forward_variables_from(invoker, @@ -14,7 +18,7 @@ template("data_headers") { "sources", "testonly", ]) - script = "../../tools/convert_to_data_file.py" + script = "//${openscreen_root}/tools/convert_to_data_file.py" outputs = [ "{{source_gen_dir}}/{{source_name_part}}_data.h" ] args = [ namespace, diff --git a/build/config/external_libraries.gni b/build/config/external_libraries.gni index a451add7..aa2364e9 100644 --- a/build/config/external_libraries.gni +++ b/build/config/external_libraries.gni @@ -2,11 +2,10 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. +# NOTE: Only add *_dirs declarations if the libraries have been installed at +# non-standard locations. See the related markdown for more information: +# [external_libraries.md](external_libraries.md). declare_args() { - # FFMPEG: If installed on the system, set have_ffmpeg to true. This also - # requires the FFMPEG headers be installed. On Debian-like systems, this can - # be done by running `cast/standalone_receiver/install_demo_deps_debian.sh` - # to install both FFMPEG and libSDL. have_ffmpeg = false ffmpeg_libs = [ "avcodec", @@ -14,35 +13,26 @@ declare_args() { "avutil", "swresample", ] - ffmpeg_include_dirs = [] # Add only if headers are at non-standard locations. - ffmpeg_lib_dirs = [] # Add only if libraries are at non-standard locations. + ffmpeg_include_dirs = [] + ffmpeg_lib_dirs = [] - # libopus: If installed on the system, set have_libopus to true. This also - # requires the libopus headers be installed. For example, on Debian-like - # systems, the following should install everything needed: - # - # sudo apt-get install libopus0 libopus-dev have_libopus = false libopus_libs = [ "opus" ] - libopus_include_dirs = [] # Add only if headers are at non-standard locations. - libopus_lib_dirs = [] # Add only if libraries are at non-standard locations. + libopus_include_dirs = [] + libopus_lib_dirs = [] - # libsdl2: If installed on the system, set have_libsdl2 to true. This also - # requires the libSDL2 headers be installed. On Debian-like systems, this can - # be done by running `cast/standalone_receiver/install_demo_deps_debian.sh` - # to install both FFMPEG and libSDL. have_libsdl2 = false libsdl2_libs = [ "SDL2" ] - libsdl2_include_dirs = [] # Add only if headers are at non-standard locations. - libsdl2_lib_dirs = [] # Add only if libraries are at non-standard locations. + libsdl2_include_dirs = [] + libsdl2_lib_dirs = [] - # libvpx: If installed on the system, set have_libvpx to true. This also - # requires the libvpx headers be installed. For example, on Debian-like - # systems, the following should install everything needed: - # - # sudo apt-get install libvpx5 libvpx-dev have_libvpx = false libvpx_libs = [ "vpx" ] - libvpx_include_dirs = [] # Add only if headers are at non-standard locations. - libvpx_lib_dirs = [] # Add only if libraries are at non-standard locations. + libvpx_include_dirs = [] + libvpx_lib_dirs = [] + + have_libaom = false + libaom_libs = [ "aom" ] + libaom_include_dirs = [] + libaom_lib_dirs = [] } diff --git a/build/config/external_libraries.md b/build/config/external_libraries.md new file mode 100644 index 00000000..30d7502e --- /dev/null +++ b/build/config/external_libraries.md @@ -0,0 +1,65 @@ +# External Libraries in Open Screen + +Currently, external libraries are used exclusively by the standalone sender and +receiver applications, for compiling in dependencies used for video decoding and +playback. + +The decision to link external libraries is made manually by setting the GN args. +For example, a developer wanting to link all the libraries for the standalone +sender and receiver executables might add the following to `gn args out/Default`: + +```python +is_debug=true +have_ffmpeg=true +have_libsdl2=true +have_libopus=true +have_libvpx=true +``` + +On some versions of Debian, the following apt-get command will install all of +the necessary external libraries for Open Screen: + +```sh +sudo apt-get install libsdl2-2.0 libsdl2-dev libavcodec libavcodec-dev + libavformat libavformat-dev libavutil libavutil-dev + libswresample libswresample-dev libopus0 libopus-dev + libvpx5 libvpx-dev +``` + +Similarly, on some versions of Raspian, the following command will install the +necessary external libraries, at least for the standalone receiver. Note that +this command is based off of the packages linked in the [sysroot](sysroot.gni): + +```sh +sudo apt-get install libavcodec58=7:4.1.4* libavcodec-dev=7:4.1.4* + libsdl2-2.0-0=2.0.9* libsdl2-dev=2.0.9* + libavformat-dev=7:4.1.4* +``` + +Note: release of these operating systems may require slightly different +packages, so these `sh` commands are merely a potential starting point. + +Finally, note that generally the headers for packages must also be installed. +In Debian Linux flavors, this usually means that the `*-dev` version of each +package must also be installed. In the example above, this looks like having +both `libavcodec` and `libavcodec-dev`. + +## Standalone Sender + +The standalone sender uses FFMPEG, LibOpus, and LibVpx for encoding video and +audio for sending. When the build has determined that [have_external_libs]( + ../../cast/standalone_sender/BUILD.gn +) is set to true, meaning that all of these libraries are installed, then +the VP8 and Opus encoders are enabled and actual video files can be sent +to standalone receiver instances. Without these dependencies, the standalone +sender cannot properly function (contrasted with the standalone receiver, +which can use a dummy player). + +## Standalone Receiver + +The standalone receiver also uses FFMPEG, for decoding the video stream encoded +by the sender, and also uses LibSDL2 to create a surface for decoding video. +Unlike the sender, the standalone receiver can work without having +its [have_external_libs](../.../cast/standalone_receiver/BUILD.gn) set to true, +through the use of its +[Dummy Player](../../cast/standalone_receiver/dummy_player.h). diff --git a/build/config/sysroot.gni b/build/config/sysroot.gni index 339bfd9b..587de5e4 100644 --- a/build/config/sysroot.gni +++ b/build/config/sysroot.gni @@ -29,7 +29,6 @@ if (use_sysroot) { if (exec_script("//build/scripts/dir_exists.py", [ rebase_path(sysroot) ], "string") != "True") { - print("Missing or outdated sysroot for $current_cpu, downloading latest...") exec_script("//build/scripts/install-sysroot.py", [ "$current_cpu", diff --git a/build/scripts/dir_exists.py b/build/scripts/dir_exists.py index 1e633d22..16400f58 100755 --- a/build/scripts/dir_exists.py +++ b/build/scripts/dir_exists.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2019 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be diff --git a/build/scripts/install-sysroot.py b/build/scripts/install-sysroot.py index 374598f1..898cc7ce 100755 --- a/build/scripts/install-sysroot.py +++ b/build/scripts/install-sysroot.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # Copyright 2019 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be @@ -65,8 +65,8 @@ def GetSha1(filename): def GetSysrootDict(target_platform, target_arch): - """Gets the sysroot information for a given platform and arch from the sysroots.json - file.""" + """Gets the sysroot information for a given platform and arch from the + sysroots.json file.""" if target_arch not in VALID_ARCHS: raise Error('Unknown architecture: %s' % target_arch) @@ -92,7 +92,8 @@ def DownloadFile(url, local_path): raise Error('Failed to download %s' % url) def ValidateFile(local_path, expected_sum): - """Generates the SHA1 hash of a local file to compare with an expected hashsum.""" + """Generates the SHA1 hash of a local file to compare with an expected + hashsum.""" sha1sum = GetSha1(local_path) if sha1sum != expected_sum: raise Error('Tarball sha1sum is wrong.' diff --git a/build/scripts/sysroot_ld_path.py b/build/scripts/sysroot_ld_path.py index 85873812..4502f258 100755 --- a/build/scripts/sysroot_ld_path.py +++ b/build/scripts/sysroot_ld_path.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2019 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be diff --git a/cast/README.md b/cast/README.md index cb63bb33..8af2b47d 100644 --- a/cast/README.md +++ b/cast/README.md @@ -6,13 +6,17 @@ applications and streaming to Cast-compatible devices. ## Using the standalone implementations To run the standalone sender and receivers together, first you need to install -the following dependencies: FFMPEG, LibVPX, LibOpus, LibSDL2, as well as their -headers (frequently in a separate -dev package). From here, you just need a -video to use with the cast_sender, as the cast_receiver can generate a -self-signed certificate and private key for each session. You can also generate -your own RSA private key and either create or have the receiver automatically -create a self signed certificate with that key. If the receiver generates a root -certificate, it will print out the location of that certificate to stdout. +the following dependencies: FFMPEG, LibVPX, LibOpus, LibSDL2, LibAOM as well as +their headers (frequently in a separate -dev package). Currently, it is advised +that most Linux users compile LibAOM from source, using the instructions at +https://aomedia.googlesource.com/aom/. Older versions found in many package +management systems have blocking performance issues, causing AV1 encoding to be +completely unusable. From here, you just need a video to use with the +cast_sender, as the cast_receiver can generate a self-signed certificate and +private key for each session. You can also generate your own RSA private key and +either create or have the receiver automatically create a self signed +certificate with that key. If the receiver generates a root certificate, it will +print out the location of that certificate to stdout. Note that we assume that the private key is a PEM-encoded RSA private key, and the certificate is X509 PEM-encoded. The certificate must also have @@ -33,12 +37,12 @@ the cast_receiver with `-g`, and both should be written out to files: These generated credentials can be passed in to start a session, e.g. ``` -./out/Default/cast_receiver -d generated_root_cast_receiver.crt -p generated_root_cast_receiver.key lo0 -x +./out/Default/cast_receiver -d generated_root_cast_receiver.crt -p generated_root_cast_receiver.key lo0 ``` And then passed to the cast sender to connect and start a streaming session: ``` - $ ./out/Default/cast_sender -d generated_root_cast_receiver.crt ~/video-1080-mp4.mp4 + $ ./out/Default/cast_sender -d generated_root_cast_receiver.crt lo0 ~/video-1080-mp4.mp4 ``` When running on Mac OS X, also pass the `-x` flag to the cast receiver to diff --git a/cast/cast_core/api/BUILD.gn b/cast/cast_core/api/BUILD.gn new file mode 100644 index 00000000..dd9ff283 --- /dev/null +++ b/cast/cast_core/api/BUILD.gn @@ -0,0 +1,171 @@ +# Copyright 2021 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import("//third_party/grpc/grpc_library.gni") +import("//third_party/protobuf/proto_library.gni") + +# NOTE: Our local lite versions of the builtin protos have to retain their +# "google/protobuf" path in order to generate certain correct symbols. However, +# this leads to include confusion with the default committed full versions. The +# work-around is to force an extra include path to reach our local compiled +# versions. +config("force_local_well_known_protos") { + include_dirs = [ "$target_gen_dir" ] +} + +proto_library("base_protos") { + generate_python = false + proto_in_dir = "//third_party/protobuf/src" + proto_out_dir = rebase_path(".", "//") + sources = [ "//third_party/protobuf/src/google/protobuf/duration.proto" ] + cc_generator_options = "lite" + extra_configs = [ ":force_local_well_known_protos" ] +} + +template("cast_core_proto_library_base") { + target(invoker.target_type, target_name) { + proto_in_dir = "//" + rebase_path("../../..", "//") + generate_python = false + + # NOTE: For using built-in proto files like empty.proto. + import_dirs = [ "//third_party/protobuf/src" ] + + forward_variables_from(invoker, + [ + "deps", + "sources", + ]) + if (!defined(deps)) { + deps = [] + } + deps += [ ":base_protos" ] + extra_configs = [ ":force_local_well_known_protos" ] + } +} + +# For .proto files without RPC definitions. +template("cast_core_proto_library") { + cast_core_proto_library_base(target_name) { + target_type = "proto_library" + forward_variables_from(invoker, + [ + "deps", + "sources", + ]) + } +} + +# For .proto files with RPC definitions. +template("cast_core_grpc_library") { + cast_core_proto_library_base(target_name) { + target_type = "grpc_library" + forward_variables_from(invoker, + [ + "deps", + "sources", + ]) + } +} + +group("api") { + public_deps = [ + ":api_bindings_proto", + ":application_config_proto", + ":cast_audio_channel_service_proto", + ":cast_core_service_proto", + ":cast_message_proto", + ":core_application_service_proto", + ":message_channel_proto", + ":metrics_recorder_proto", + ":platform_service_proto", + ":runtime_application_service_proto", + ":runtime_message_port_application_service_proto", + ":runtime_metadata_proto", + ":runtime_service_proto", + ":service_info_proto", + ":url_rewrite_proto", + ] +} + +cast_core_proto_library("api_bindings_proto") { + sources = [ "bindings/api_bindings.proto" ] + deps = [ ":message_channel_proto" ] +} + +cast_core_proto_library("application_config_proto") { + sources = [ "common/application_config.proto" ] +} + +cast_core_proto_library("runtime_metadata_proto") { + sources = [ "common/runtime_metadata.proto" ] +} + +cast_core_proto_library("service_info_proto") { + sources = [ "common/service_info.proto" ] +} + +cast_core_grpc_library("cast_core_service_proto") { + sources = [ "core/cast_core_service.proto" ] + deps = [ ":runtime_metadata_proto" ] +} + +cast_core_grpc_library("platform_service_proto") { + sources = [ "platform/platform_service.proto" ] + deps = [ ":service_info_proto" ] +} + +cast_core_grpc_library("cast_audio_channel_service_proto") { + sources = [ "runtime/cast_audio_channel_service.proto" ] +} + +cast_core_grpc_library("runtime_service_proto") { + sources = [ "runtime/runtime_service.proto" ] + deps = [ + ":application_config_proto", + ":service_info_proto", + ":url_rewrite_proto", + ] +} + +cast_core_proto_library("cast_message_proto") { + sources = [ "v2/cast_message.proto" ] +} + +cast_core_grpc_library("core_application_service_proto") { + sources = [ "v2/core_application_service.proto" ] + deps = [ + ":api_bindings_proto", + ":application_config_proto", + ":cast_message_proto", + ":message_channel_proto", + ":service_info_proto", + ":url_rewrite_proto", + ] +} + +cast_core_grpc_library("runtime_application_service_proto") { + sources = [ "v2/runtime_application_service.proto" ] + deps = [ + ":cast_message_proto", + ":message_channel_proto", + ":url_rewrite_proto", + ] +} + +cast_core_grpc_library("runtime_message_port_application_service_proto") { + sources = [ "v2/runtime_message_port_application_service.proto" ] + deps = [ ":message_channel_proto" ] +} + +cast_core_proto_library("url_rewrite_proto") { + sources = [ "v2/url_rewrite.proto" ] +} + +cast_core_proto_library("message_channel_proto") { + sources = [ "web/message_channel.proto" ] +} + +cast_core_grpc_library("metrics_recorder_proto") { + sources = [ "metrics/metrics_recorder.proto" ] +} diff --git a/cast/cast_core/api/bindings/api_bindings.proto b/cast/cast_core/api/bindings/api_bindings.proto index 0fc9c43d..f275e122 100644 --- a/cast/cast_core/api/bindings/api_bindings.proto +++ b/cast/cast_core/api/bindings/api_bindings.proto @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.bindings; @@ -16,6 +16,8 @@ message ApiBinding { string before_load_script = 1; } +message GetAllRequest {} + message GetAllResponse { repeated ApiBinding bindings = 1; } @@ -24,3 +26,5 @@ message ConnectRequest { string port_name = 1; cast.web.MessagePortDescriptor port = 2; } + +message ConnectResponse {} diff --git a/cast/cast_core/api/common/application_config.proto b/cast/cast_core/api/common/application_config.proto index cb426824..d49d077e 100644 --- a/cast/cast_core/api/common/application_config.proto +++ b/cast/cast_core/api/common/application_config.proto @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.common; diff --git a/cast/cast_core/api/common/runtime_metadata.proto b/cast/cast_core/api/common/runtime_metadata.proto index b8cc0919..734ad36f 100644 --- a/cast/cast_core/api/common/runtime_metadata.proto +++ b/cast/cast_core/api/common/runtime_metadata.proto @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.common; @@ -56,9 +56,6 @@ message RuntimeCapabilities { ApplicationCapabilities native_application_capabilities = 2; } - // Flags if heartbeat is supported. - bool heartbeat_supported = 3; - // Flags if metrics recording is supported. bool metrics_recorder_supported = 4; } diff --git a/cast/cast_core/api/common/service_info.proto b/cast/cast_core/api/common/service_info.proto index e8dc7dd1..2d3539b9 100644 --- a/cast/cast_core/api/common/service_info.proto +++ b/cast/cast_core/api/common/service_info.proto @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.common; diff --git a/cast/cast_core/api/core/cast_core_service.proto b/cast/cast_core/api/core/cast_core_service.proto index af8c0ad8..28ec7e05 100644 --- a/cast/cast_core/api/core/cast_core_service.proto +++ b/cast/cast_core/api/core/cast_core_service.proto @@ -2,14 +2,12 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.core; -import "google/protobuf/empty.proto"; import "cast/cast_core/api/common/runtime_metadata.proto"; -import "cast/cast_core/api/common/service_info.proto"; option optimize_for = LITE_RUNTIME; @@ -21,42 +19,24 @@ service CastCoreService { // Unregisters a Cast Runtime. Usually called by platform. rpc UnregisterRuntime(UnregisterRuntimeRequest) returns (UnregisterRuntimeResponse); - - // Called by the Runtime when it starts up. - rpc RuntimeStarted(RuntimeStartedNotification) - returns (google.protobuf.Empty); - - // Called when the runtime is shutdown. May be called for an active Cast - // session. - rpc RuntimeStopped(RuntimeStoppedNotification) - returns (google.protobuf.Empty); } message RegisterRuntimeRequest { - // Platform-generated runtime ID associated with this runtime. Uniqueness is - // guaranteed by the CastCore service. - string runtime_id = 1; + // DEPRECATED. + string runtime_id = 1 [deprecated = true]; // Metadata about the runtime. cast.common.RuntimeMetadata runtime_metadata = 2; } -message RegisterRuntimeResponse {} - -message UnregisterRuntimeRequest { - // Runtime ID. +message RegisterRuntimeResponse { + // A randomly generated runtime ID. Cast Core will use this ID to reference a + // particular Runtime. string runtime_id = 1; } -message UnregisterRuntimeResponse {} - -message RuntimeStartedNotification { +message UnregisterRuntimeRequest { // Runtime ID. string runtime_id = 1; - // Runtime service info. - cast.common.ServiceInfo runtime_service_info = 2; } -message RuntimeStoppedNotification { - // Runtime ID. - string runtime_id = 1; -} +message UnregisterRuntimeResponse {} diff --git a/cast/cast_core/api/metrics/metrics_recorder.proto b/cast/cast_core/api/metrics/metrics_recorder.proto index d7a04950..16c2ee04 100644 --- a/cast/cast_core/api/metrics/metrics_recorder.proto +++ b/cast/cast_core/api/metrics/metrics_recorder.proto @@ -2,24 +2,24 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.metrics; -import "google/protobuf/empty.proto"; - option optimize_for = LITE_RUNTIME; service MetricsRecorderService { // Record a set of|Event| - rpc Record(RecordRequest) returns (google.protobuf.Empty); + rpc Record(RecordRequest) returns (RecordResponse); } message RecordRequest { repeated Event event = 1; } +message RecordResponse {} + // This repliciates the Fuchsia approach to Cast metrics; for documentation on // event structure, refer to // fuchsia.googlesource.com/fuchsia/+/master/sdk/fidl/fuchsia.legacymetrics/event.fidl @@ -33,7 +33,7 @@ message Event { message UserActionEvent { string name = 1; - optional int64 time = 2; + int64 time = 2; } message Histogram { @@ -50,5 +50,5 @@ message HistogramBucket { message ImplementationDefinedEvent { bytes data = 1; - optional string name = 2; + string name = 2; } diff --git a/cast/cast_core/api/platform/platform_service.proto b/cast/cast_core/api/platform/platform_service.proto index 7e2ad5f5..6ecb7d49 100644 --- a/cast/cast_core/api/platform/platform_service.proto +++ b/cast/cast_core/api/platform/platform_service.proto @@ -2,11 +2,13 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.platform; +import "cast/cast_core/api/common/service_info.proto"; + option optimize_for = LITE_RUNTIME; // Platform service. Implemented and hosted by Platform. @@ -23,7 +25,10 @@ service PlatformService { } message StartRuntimeRequest { + // Cast Runtime ID assigned in CastCoreService.RegisterRuntime. string runtime_id = 1; + // gRPC endpoint Cast Runtime must run on. + cast.common.ServiceInfo runtime_service_info = 2; } message StartRuntimeResponse {} diff --git a/cast/cast_core/api/runtime/cast_audio_decoder_service.proto b/cast/cast_core/api/runtime/cast_audio_channel_service.proto index f6916cb4..61c9c658 100644 --- a/cast/cast_core/api/runtime/cast_audio_decoder_service.proto +++ b/cast/cast_core/api/runtime/cast_audio_channel_service.proto @@ -2,16 +2,182 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.media; import "google/protobuf/duration.proto"; -import "google/protobuf/empty.proto"; option optimize_for = LITE_RUNTIME; +// Cast audio service hosted by Cast Core. +// +// It defines a state machine with the following states: +// - Uninitialized +// - Playing +// - Stopped +// - Paused +// +// Note that the received ordering between different RPC calls is not +// guaranteed to match the sent order. +service CastAudioChannelService { + // Initializes the service and places the pipeline into the 'Stopped' state. + // This must be the first call received by the server, and no other calls + // may be sent prior to receiving this call's response. + rpc Initialize(InitializeRequest) returns (InitializeResponse); + + // Returns the minimum buffering delay (min_delay) required by Cast. This is + // a constant value and only needs to be queried once for each service. + // During a StartRequest or ResumeRequest, the system timestamp must be + // greater than this delay and the current time in order for the buffer to be + // successfully rendered on remote devices. + rpc GetMinimumBufferDelay(GetMinimumBufferDelayRequest) + returns (GetMinimumBufferDelayResponse); + + // Update the pipeline state. + // + // StartRequest: + // Places pipeline into 'Playing' state. Playback will start at the + // specified buffer and system timestamp. + // + // May only be called in the 'Stopped' state, and following this call the + // state machine will be in the 'Playing' state. + // + // StopRequest + // Stops media playback and drops all pushed buffers which have not yet been + // played. + // + // May only be called in the 'Playing' or 'Paused' states, and following + // this call the state machine will be in the 'Stopped' state. + // + // PauseRequest + // Pauses media playback. + // + // May only be called in the 'Playing' state, and following this call the + // state machine will be in the 'Paused' state. + // + // ResumeRequest + // Resumes media playback at the specified buffer and system timestamp. + // + // May only be called in the 'Paused' state, and following this call the + // state machine will be in the 'Playing'' state. + // + // TimestampUpdateRequest + // Sends a timestamp update for a specified buffer for audio + // synchronization. This should be called when operating in + // CAST_AUDIO_DECODER_MODE_MULTIROOM_ONLY when the runtime has detected a + // discrepancy in the system clock or pipeline delay from the original + // playback schedule. See example below: + // + // Assume all buffers have duration of 100us. + // + // StartRequest(id=1, system_timestamp=0); + // -> Cast expects id=1 to play at 0, id=2 at 100us, id=3 at 200 us... + // + // TimestampUpdateRequest(id=4, system_timestamp=405us); + // -> Cast expects id=4 to play at 405, id=5 at 505us, id=6 at 605 us... + // + // May be called from any state. + // + // A state transition may only occur after a successful PushBuffer() + // call has been made with a valid configuration. + rpc StateChange(StateChangeRequest) returns (StateChangeResponse); + + // Sets the volume multiplier for this audio stream. + // The multiplier is in the range [0.0, 1.0]. If not called, a default + // multiplier of 1.0 is assumed. + // + // May be called in any state, and following this call the state machine + // will be in the same state. + rpc SetVolume(SetVolumeRequest) returns (SetVolumeResponse); + + // Sets the playback rate for this audio stream. + // + // May be called in any state, and following this call the state machine + // will be in the same state. + rpc SetPlaybackRate(SetPlaybackRateRequest) returns (SetPlaybackRateResponse); + + // Sends decoded bits and responses to the audio service. The client must + // wait for a response from the server before sending another + // PushBufferRequest. + // + // May only be called in the 'Playing' or 'Paused' states, and following + // this call the state machine will remain the same state. + // + rpc PushBuffer(PushBufferRequest) returns (PushBufferResponse); + + // Returns the current media time that has been rendered. + rpc GetMediaTime(GetMediaTimeRequest) returns (GetMediaTimeResponse); +} + +message InitializeRequest { + // Cast session ID. + string cast_session_id = 1; + + // Configures how the server should operate. + CastAudioDecoderMode mode = 2; +} + +message InitializeResponse {} + +message GetMinimumBufferDelayRequest {} + +message GetMinimumBufferDelayResponse { + // The minimum buffering delay in microseconds. + int64 delay_micros = 1; +} + +message StateChangeRequest { + oneof request { + StartRequest start = 1; + StopRequest stop = 2; + PauseRequest pause = 3; + ResumeRequest resume = 4; + TimestampUpdateRequest timestamp_update = 5; + } +} + +message StateChangeResponse { + // Pipeline state after state change. + PipelineState state = 1; +} + +message SetVolumeRequest { + // The multiplier is in the range [0.0, 1.0]. + float multiplier = 1; +} + +message SetVolumeResponse {} + +message SetPlaybackRateRequest { + // Playback rate greater than 0. + double rate = 1; +} + +message SetPlaybackRateResponse {} + +message PushBufferRequest { + AudioDecoderBuffer buffer = 1; + + // Audio configuration for this buffer and all subsequent buffers. This + // field must be populated for the first request or if there is an audio + // configuration change. + AudioConfiguration audio_config = 2; +} + +message PushBufferResponse { + // The total number of decoded bytes. + int64 decoded_bytes = 1; +} + +message GetMediaTimeRequest {} + +message GetMediaTimeResponse { + // The current media time that has been rendered. + MediaTime media_time = 1; +} + enum PipelineState { PIPELINE_STATE_UNINITIALIZED = 0; PIPELINE_STATE_STOPPED = 1; @@ -142,19 +308,6 @@ message TimestampInfo { int64 buffer_id = 2; } -message InitializeRequest { - // Cast session ID. - string cast_session_id = 1; - - // Configures how the server should operate. - CastAudioDecoderMode mode = 2; -} - -message GetMinimumBufferingDelayResponse { - // The minimum buffering delay in microseconds. - int64 delay_micros = 1; -} - message StartRequest { // The start presentation timestamp in microseconds. int64 pts_micros = 1; @@ -179,148 +332,3 @@ message ResumeRequest { message TimestampUpdateRequest { TimestampInfo timestamp_info = 1; } - -message StateChangeRequest { - oneof request { - StartRequest start = 1; - StopRequest stop = 2; - PauseRequest pause = 3; - ResumeRequest resume = 4; - TimestampUpdateRequest timestamp_update = 5; - } -} - -message StateChangeResponse { - // Pipeline state after state change. - PipelineState state = 1; -} - -message PushBufferRequest { - AudioDecoderBuffer buffer = 1; - - // Audio configuration for this buffer and all subsequent buffers. This - // field must be populated for the first request or if there is an audio - // configuration change. - AudioConfiguration audio_config = 2; -} - -message PushBufferResponse { - // The total number of decoded bytes. - int64 decoded_bytes = 1; -} - -message SetVolumeRequest { - // The multiplier is in the range [0.0, 1.0]. - float multiplier = 1; -} -message SetPlaybackRateRequest { - // Playback rate greater than 0. - double rate = 1; -} - -message GetMediaTimeResponse { - // The current media time that has been rendered. - MediaTime media_time = 1; -} - -// Cast audio service hosted by Cast Core. -// -// It defines a state machine with the following states: -// - Uninitialized -// - Playing -// - Stopped -// - Paused -// -// Note that the received ordering between different RPC calls is not -// guaranteed to match the sent order. -service CastRuntimeAudioChannel { - // Initializes the service and places the pipeline into the 'Stopped' state. - // This must be the first call received by the server, and no other calls - // may be sent prior to receiving this call's response. - rpc Initialize(InitializeRequest) returns (google.protobuf.Empty); - - // Returns the minimum buffering delay (min_delay) required by Cast. This is - // a constant value and only needs to be queried once for each service. - // During a StartRequest or ResumeRequest, the system timestamp must be - // greater than this delay and the current time in order for the buffer to be - // successfully rendered on remote devices. - rpc GetMinimumBufferDelay(google.protobuf.Empty) - returns (GetMinimumBufferingDelayResponse); - - // Update the pipeline state. - // - // StartRequest: - // Places pipeline into 'Playing' state. Playback will start at the - // specified buffer and system timestamp. - // - // May only be called in the 'Stopped' state, and following this call the - // state machine will be in the 'Playing' state. - // - // StopRequest - // Stops media playback and drops all pushed buffers which have not yet been - // played. - // - // May only be called in the 'Playing' or 'Paused' states, and following - // this call the state machine will be in the 'Stopped' state. - // - // PauseRequest - // Pauses media playback. - // - // May only be called in the 'Playing' state, and following this call the - // state machine will be in the 'Paused' state. - // - // ResumeRequest - // Resumes media playback at the specified buffer and system timestamp. - // - // May only be called in the 'Paused' state, and following this call the - // state machine will be in the 'Playing'' state. - // - // TimestampUpdateRequest - // Sends a timestamp update for a specified buffer for audio - // synchronization. This should be called when operating in - // CAST_AUDIO_DECODER_MODE_MULTIROOM_ONLY when the runtime has detected a - // discrepancy in the system clock or pipeline delay from the original - // playback schedule. See example below: - // - // Assume all buffers have duration of 100us. - // - // StartRequest(id=1, system_timestamp=0); - // -> Cast expects id=1 to play at 0, id=2 at 100us, id=3 at 200 us... - // - // TimestampUpdateRequest(id=4, system_timestamp=405us); - // -> Cast expects id=4 to play at 405, id=5 at 505us, id=6 at 605 us... - // - // May be called from any state. - // - // A state transition may only occur after a successful PushBuffer() - // call has been made with a valid configuration. - rpc StateChange(StateChangeRequest) returns (StateChangeResponse); - - // Sets the volume multiplier for this audio stream. - // The multiplier is in the range [0.0, 1.0]. If not called, a default - // multiplier of 1.0 is assumed. - // - // May be called in any state, and following this call the state machine - // will be in the same state. - rpc SetVolume(SetVolumeRequest) returns (google.protobuf.Empty); - - // Sets the playback rate for this audio stream. - // - // May be called in any state, and following this call the state machine - // will be in the same state. - rpc SetPlayback(SetPlaybackRateRequest) returns (google.protobuf.Empty); - - // Sends decoded bits and responses to the audio service. The client must - // wait for a response from the server before sending another - // PushBufferRequest. - // - // May only be called in the 'Playing' or 'Paused' states, and following - // this call the state machine will remain the same state. - // - // TODO(b/178523159): validate that this isn't a performance bottleneck as a - // non-streaming API. If it is, we should make this a bidirectional stream. - rpc PushBuffer(PushBufferRequest) returns (PushBufferResponse); - - // Returns the current media time that has been rendered. - rpc GetMediaTime(google.protobuf.Empty) returns (GetMediaTimeResponse); -} diff --git a/cast/cast_core/api/runtime/runtime_service.proto b/cast/cast_core/api/runtime/runtime_service.proto index 084852eb..0ea47daa 100644 --- a/cast/cast_core/api/runtime/runtime_service.proto +++ b/cast/cast_core/api/runtime/runtime_service.proto @@ -2,14 +2,15 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.runtime; import "google/protobuf/duration.proto"; -import "google/protobuf/empty.proto"; +import "cast/cast_core/api/common/application_config.proto"; import "cast/cast_core/api/common/service_info.proto"; +import "cast/cast_core/api/v2/url_rewrite.proto"; option optimize_for = LITE_RUNTIME; @@ -17,9 +18,12 @@ option optimize_for = LITE_RUNTIME; // // This service is called by CastCore after Runtime starts up. service RuntimeService { + // Loads a Cast application. The runtime must start its + // RuntimeApplicationService on runtime_application_service_info. + rpc LoadApplication(LoadApplicationRequest) returns (LoadApplicationResponse); + // Launches a Cast application. The application must connect to the - // CoreApplicationService based on cast_protocol and - // core_application_endpoint, and provide its endpoint. + // CoreApplicationService via core_application_service_info. rpc LaunchApplication(LaunchApplicationRequest) returns (LaunchApplicationResponse); @@ -38,24 +42,57 @@ service RuntimeService { // Provides information need by the runtime to start recording metrics via // the core. rpc StartMetricsRecorder(StartMetricsRecorderRequest) - returns (google.protobuf.Empty); + returns (StartMetricsRecorderResponse); // Stops the metrics recorder, which may also attempt to flush. - rpc StopMetricsRecorder(google.protobuf.Empty) - returns (google.protobuf.Empty); + rpc StopMetricsRecorder(StopMetricsRecorderRequest) + returns (StopMetricsRecorderResponse); } -message StartMetricsRecorderRequest { - // Metrics service info. - cast.common.ServiceInfo metrics_recorder_service_info = 1; +message LoadApplicationRequest { + // Cast application config. + cast.common.ApplicationConfig application_config = 1; + // Initial rules to rewrite URLs and headers. + cast.v2.UrlRequestRewriteRules url_rewrite_rules = 2; + // Cast session id used to setup a connection and pull the config from core + // application service. + string cast_session_id = 3; + // RuntimeApplication service info. The endpoint is generated by Cast Core and + // must be used by the Runtime to bind the RuntimeApplication service. + cast.common.ServiceInfo runtime_application_service_info = 4; +} + +// Info relevant to a V2 channel between the runtime and cast core. +message V2ChannelInfo { + // If set, only messages within these namespaces will be sent to the runtime. + // If empty, all V2 messages will be sent to the runtime regardless of + // namespace. + repeated string requested_namespaces = 1; +} + +// Info relevant to a MessagePort channel between the runtime and cast core. +message MessagePortInfo {} + +message LoadApplicationResponse { + // One of these fields must be set. This specifies what type of communication + // channel should be used to communicate between the runtime and cast core for + // the given application. + oneof channel_type { + V2ChannelInfo v2_info = 1; + MessagePortInfo message_port_info = 2; + } } message LaunchApplicationRequest { // CoreApplication service info. cast.common.ServiceInfo core_application_service_info = 1; - // Cast session id used to setup a connection and pull the config from core - // application service. - string cast_session_id = 2; + // DEPRECATED + string cast_session_id = 2 [deprecated = true]; + // DEPRECATED + cast.common.ServiceInfo runtime_application_service_info = 3 + [deprecated = true]; + // CastMedia service info for this application in CastCore. + cast.common.ServiceInfo cast_media_service_info = 4; } // Returned by the runtime in response to a launch application request. @@ -82,3 +119,14 @@ message HeartbeatRequest { } message HeartbeatResponse {} + +message StartMetricsRecorderRequest { + // Metrics service info. + cast.common.ServiceInfo metrics_recorder_service_info = 1; +} + +message StartMetricsRecorderResponse {} + +message StopMetricsRecorderRequest {} + +message StopMetricsRecorderResponse {} diff --git a/cast/cast_core/api/v2/cast_message.proto b/cast/cast_core/api/v2/cast_message.proto index 72575264..639c3aaf 100644 --- a/cast/cast_core/api/v2/cast_message.proto +++ b/cast/cast_core/api/v2/cast_message.proto @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.v2; @@ -10,7 +10,7 @@ package cast.v2; option optimize_for = LITE_RUNTIME; // Cast V2 request definition. -message CastMessage { +message CastMessageRequest { // Cast sender ID; distinct from virtual connection source ID. string sender_id = 1; // Cast namespace. diff --git a/cast/cast_core/api/v2/core_application_service.proto b/cast/cast_core/api/v2/core_application_service.proto index 7933393b..3a31f7b2 100644 --- a/cast/cast_core/api/v2/core_application_service.proto +++ b/cast/cast_core/api/v2/core_application_service.proto @@ -2,12 +2,11 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.v2; -import "google/protobuf/empty.proto"; import "cast/cast_core/api/bindings/api_bindings.proto"; import "cast/cast_core/api/common/application_config.proto"; import "cast/cast_core/api/common/service_info.proto"; @@ -25,32 +24,41 @@ service CoreApplicationService { // the gRPC status code. rpc GetConfig(GetConfigRequest) returns (GetConfigResponse); + // DEPRECATED // Send a Cast V2 message to core application. - rpc SendCastMessage(CastMessage) returns (CastMessageResponse); + rpc SendCastMessage(CastMessageRequest) returns (CastMessageResponse); // Notifies Cast Core on the application state changes. The callback must be // called by the Runtime whenever the internal state of the application // changes. Cast Core may discard any resources associated with the // application upon failures. - rpc OnApplicationStatus(ApplicationStatus) returns (google.protobuf.Empty); + rpc SetApplicationStatus(ApplicationStatusRequest) + returns (ApplicationStatusResponse); + // DEPRECATED // Posts messages between MessagePorts. MessagePorts are connected using other // services (e.g. ApiBindings), then registered with the // MessageConnectorService to communicate over IPC. rpc PostMessage(cast.web.Message) returns (cast.web.MessagePortStatus); + // DEPRECATED // Gets the list of bindings to early-inject into javascript at page load. - rpc GetAll(google.protobuf.Empty) returns (cast.bindings.GetAllResponse); + rpc GetAll(cast.bindings.GetAllRequest) + returns (cast.bindings.GetAllResponse); + // DEPRECATED // Connects to a binding returned by GetAll. - rpc Connect(cast.bindings.ConnectRequest) returns (google.protobuf.Empty); + rpc Connect(cast.bindings.ConnectRequest) + returns (cast.bindings.ConnectResponse); + + // GetWebUIResource request + rpc GetWebUIResource(GetWebUIResourceRequest) + returns (GetWebUIResourceResponse); } message GetConfigRequest { // Cast session ID. string cast_session_id = 1; - // RuntimeApplication service info. - cast.common.ServiceInfo runtime_application_service_info = 2; } message GetConfigResponse { @@ -63,7 +71,7 @@ message GetConfigResponse { } // Contains information about an application status in the runtime. -message ApplicationStatus { +message ApplicationStatusRequest { // The Cast session ID whose application status changed. string cast_session_id = 1; @@ -94,3 +102,15 @@ message ApplicationStatus { // |stop_reason| is HTTP_ERROR. int32 http_response_code = 4; } + +message ApplicationStatusResponse {} + +message GetWebUIResourceRequest { + // Resource identifier. It can either be name of the resource or a url. + string resource_id = 1; +} + +message GetWebUIResourceResponse { + // Path to the resource file on device. + string resource_path = 1; +} diff --git a/cast/cast_core/api/v2/core_message_port_application_service.proto b/cast/cast_core/api/v2/core_message_port_application_service.proto new file mode 100644 index 00000000..9dcd918e --- /dev/null +++ b/cast/cast_core/api/v2/core_message_port_application_service.proto @@ -0,0 +1,30 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// **** DO NOT EDIT - this file was automatically generated. **** +syntax = "proto3"; + +package cast.v2; + +import "cast/cast_core/api/bindings/api_bindings.proto"; +import "cast/cast_core/api/web/message_channel.proto"; + +option optimize_for = LITE_RUNTIME; + +// This service runs in Cast Core for a particular app. It uses a MessagePort to +// communicate with the app. +service CoreMessagePortApplicationService { + // Posts messages between MessagePorts. MessagePorts are connected using other + // services (e.g. ApiBindings), then registered with the + // MessageConnectorService to communicate over IPC. + rpc PostMessage(cast.web.Message) returns (cast.web.MessagePortStatus); + + // Gets the list of bindings to early-inject into javascript at page load. + rpc GetAll(cast.bindings.GetAllRequest) + returns (cast.bindings.GetAllResponse); + + // Connects to a binding returned by GetAll. + rpc Connect(cast.bindings.ConnectRequest) + returns (cast.bindings.ConnectResponse); +} diff --git a/cast/cast_core/api/v2/core_v2_application_service.proto b/cast/cast_core/api/v2/core_v2_application_service.proto new file mode 100644 index 00000000..38a599bf --- /dev/null +++ b/cast/cast_core/api/v2/core_v2_application_service.proto @@ -0,0 +1,19 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// **** DO NOT EDIT - this file was automatically generated. **** +syntax = "proto3"; + +package cast.v2; + +import "cast/cast_core/api/v2/cast_message.proto"; + +option optimize_for = LITE_RUNTIME; + +// This service runs in Cast Core for a particular app. It uses the V2 protocol +// to communicate with the app. +service CoreV2ApplicationService { + // Send a Cast V2 message to core application. + rpc SendCastMessage(CastMessageRequest) returns (CastMessageResponse); +} diff --git a/cast/cast_core/api/v2/runtime_application_service.proto b/cast/cast_core/api/v2/runtime_application_service.proto index d80cf61c..f105f189 100644 --- a/cast/cast_core/api/v2/runtime_application_service.proto +++ b/cast/cast_core/api/v2/runtime_application_service.proto @@ -2,12 +2,11 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.v2; -import "google/protobuf/empty.proto"; import "cast/cast_core/api/v2/cast_message.proto"; import "cast/cast_core/api/v2/url_rewrite.proto"; import "cast/cast_core/api/web/message_channel.proto"; @@ -19,16 +18,9 @@ option optimize_for = LITE_RUNTIME; // This service is implemented by the Runtime and represents services // specific to a Cast application. service RuntimeApplicationService { - // Notifies the runtime that a new Cast V2 virtual connection has been opened. - rpc OnVirtualConnectionOpen(VirtualConnectionInfo) - returns (google.protobuf.Empty); - - // Notifies the runtime that a Cast V2 virtual connection has been closed. - rpc OnVirtualConnectionClosed(VirtualConnectionInfo) - returns (google.protobuf.Empty); - + // DEPRECATED // Sends a Cast message to the runtime. - rpc SendCastMessage(CastMessage) returns (CastMessageResponse); + rpc SendCastMessage(CastMessageRequest) returns (CastMessageResponse); // Set the URL rewrite rules that the Runtime will use to contact the MSP // This is called when the rewrite rules are changed @@ -36,6 +28,7 @@ service RuntimeApplicationService { rpc SetUrlRewriteRules(SetUrlRewriteRulesRequest) returns (SetUrlRewriteRulesResponse); + // DEPRECATED // "MessageConnectorService" provides the transport for MessagePorts. // MessagePorts are connected using other services (e.g. ApiBindings), then // registered with the MessageConnectorService to communicate over IPC @@ -49,16 +42,3 @@ message SetUrlRewriteRulesRequest { } message SetUrlRewriteRulesResponse {} - -// Request by the sender to open or close a virtual connection to the Cast -// runtime. -message VirtualConnectionInfo { - // The source of the virtual connection request. Connections from the - // sender platform use an id of 'sender-0' and connections from applications - // use a unique ID. - string source_id = 1; - // The destination of the connection request. Connections to the Cast - // receiver platform use an id of 'receiver-0' and connections to applications - // use the Cast session id. - string destination_id = 2; -} diff --git a/cast/cast_core/api/v2/runtime_message_port_application_service.proto b/cast/cast_core/api/v2/runtime_message_port_application_service.proto new file mode 100644 index 00000000..9687fe50 --- /dev/null +++ b/cast/cast_core/api/v2/runtime_message_port_application_service.proto @@ -0,0 +1,21 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// **** DO NOT EDIT - this file was automatically generated. **** +syntax = "proto3"; + +package cast.v2; + +import "cast/cast_core/api/web/message_channel.proto"; + +option optimize_for = LITE_RUNTIME; + +// This service runs in the runtime for a particular app. It uses a MessagePort +// to communicate with Cast Core. +service RuntimeMessagePortApplicationService { + // "MessageConnectorService" provides the transport for MessagePorts. + // MessagePorts are connected using other services (e.g. ApiBindings), then + // registered with the MessageConnectorService to communicate over IPC + rpc PostMessage(cast.web.Message) returns (cast.web.MessagePortStatus); +} diff --git a/cast/cast_core/api/v2/runtime_v2_application_service.proto b/cast/cast_core/api/v2/runtime_v2_application_service.proto new file mode 100644 index 00000000..b5050a71 --- /dev/null +++ b/cast/cast_core/api/v2/runtime_v2_application_service.proto @@ -0,0 +1,19 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// **** DO NOT EDIT - this file was automatically generated. **** +syntax = "proto3"; + +package cast.v2; + +import "cast/cast_core/api/v2/cast_message.proto"; + +option optimize_for = LITE_RUNTIME; + +// This service runs in the runtime for a particular app. It uses the V2 +// protocol to communicate with Cast Core. +service RuntimeV2ApplicationService { + // Sends a Cast V2 message to the runtime. + rpc SendCastMessage(CastMessageRequest) returns (CastMessageResponse); +} diff --git a/cast/cast_core/api/v2/url_rewrite.proto b/cast/cast_core/api/v2/url_rewrite.proto index 6f637c8d..e9987007 100644 --- a/cast/cast_core/api/v2/url_rewrite.proto +++ b/cast/cast_core/api/v2/url_rewrite.proto @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.v2; diff --git a/cast/cast_core/api/web/message_channel.proto b/cast/cast_core/api/web/message_channel.proto index 4e3f1ded..dc59b13b 100644 --- a/cast/cast_core/api/web/message_channel.proto +++ b/cast/cast_core/api/web/message_channel.proto @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// **** DO NOT EDIT - this .proto was automatically generated. **** +// **** DO NOT EDIT - this file was automatically generated. **** syntax = "proto3"; package cast.web; diff --git a/cast/common/BUILD.gn b/cast/common/BUILD.gn index ae32bf4a..154d436e 100644 --- a/cast/common/BUILD.gn +++ b/cast/common/BUILD.gn @@ -83,15 +83,15 @@ source_set("public") { sources = [ "public/cast_socket.h", "public/message_port.h", - "public/service_info.cc", - "public/service_info.h", + "public/receiver_info.cc", + "public/receiver_info.h", ] deps = [ - "../../discovery:dnssd", "../../discovery:public", "../../platform", "../../third_party/abseil", + "../../util", ] } @@ -107,6 +107,7 @@ if (!build_with_chromium) { ":public", "../../discovery:dnssd", "../../discovery:public", + "../../platform:standalone_impl", "../../testing/util", "../../third_party/googletest:gtest", ] @@ -129,7 +130,7 @@ source_set("test_helpers") { ":certificate", ":channel", ":public", - "../../discovery:dnssd", + "../../discovery:public", "../../platform:test", "../../testing/util", "../../third_party/abseil", @@ -153,7 +154,7 @@ source_set("unittests") { "channel/message_framer_unittest.cc", "channel/namespace_router_unittest.cc", "channel/virtual_connection_router_unittest.cc", - "public/service_info_unittest.cc", + "public/receiver_info_unittest.cc", ] deps = [ diff --git a/cast/common/certificate/cast_cert_validator.cc b/cast/common/certificate/cast_cert_validator.cc index b66f8859..f8ee66f7 100644 --- a/cast/common/certificate/cast_cert_validator.cc +++ b/cast/common/certificate/cast_cert_validator.cc @@ -103,18 +103,18 @@ CastDeviceCertPolicy GetAudioPolicy(const std::vector<X509*>& path) { int pos = X509_get_ext_by_NID(cert, NID_certificate_policies, -1); if (pos != -1) { X509_EXTENSION* policies_extension = X509_get_ext(cert, pos); - const uint8_t* in = policies_extension->value->data; - CERTIFICATEPOLICIES* policies = d2i_CERTIFICATEPOLICIES( - nullptr, &in, policies_extension->value->length); + const ASN1_STRING* value = X509_EXTENSION_get_data(policies_extension); + const uint8_t* in = ASN1_STRING_get0_data(value); + CERTIFICATEPOLICIES* policies = + d2i_CERTIFICATEPOLICIES(nullptr, &in, ASN1_STRING_length(value)); if (policies) { // Check for |audio_only_policy_oid| in the set of policies. uint32_t policy_count = sk_POLICYINFO_num(policies); for (uint32_t i = 0; i < policy_count; ++i) { POLICYINFO* info = sk_POLICYINFO_value(policies, i); - if (info->policyid->length == - static_cast<int>(audio_only_policy_oid.length) && - memcmp(info->policyid->data, audio_only_policy_oid.data, + if (OBJ_length(info->policyid) == audio_only_policy_oid.length && + memcmp(OBJ_get0_data(info->policyid), audio_only_policy_oid.data, audio_only_policy_oid.length) == 0) { policy = CastDeviceCertPolicy::kAudioOnly; break; @@ -162,10 +162,17 @@ Error VerifyDeviceCert(const std::vector<std::string>& der_certs, // CertVerificationContextImpl. X509_NAME* target_subject = X509_get_subject_name(result_path.target_cert.get()); - std::string common_name(target_subject->canon_enclen, 0); - int len = X509_NAME_get_text_by_NID(target_subject, NID_commonName, - &common_name[0], common_name.size()); - if (len == 0) { + int len = + X509_NAME_get_text_by_NID(target_subject, NID_commonName, nullptr, 0); + if (len <= 0) { + return Error::Code::kErrCertsRestrictions; + } + // X509_NAME_get_text_by_NID writes one more byte than it reports, for a + // trailing NUL. + std::string common_name(len + 1, 0); + len = X509_NAME_get_text_by_NID(target_subject, NID_commonName, + &common_name[0], common_name.size()); + if (len <= 0) { return Error::Code::kErrCertsRestrictions; } common_name.resize(len); diff --git a/cast/common/certificate/cast_cert_validator_internal.cc b/cast/common/certificate/cast_cert_validator_internal.cc index 764ac3e4..073b76ac 100644 --- a/cast/common/certificate/cast_cert_validator_internal.cc +++ b/cast/common/certificate/cast_cert_validator_internal.cc @@ -18,6 +18,7 @@ #include <utility> #include <vector> +#include "absl/strings/str_cat.h" #include "cast/common/certificate/types.h" #include "util/crypto/pem_helpers.h" #include "util/osp_logging.h" @@ -407,29 +408,30 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, result_path->intermediate_certs; target_cert.reset(ParseX509Der(der_certs[0])); if (!target_cert) { - OSP_DVLOG << "FindCertificatePath: Invalid target certificate"; - return Error::Code::kErrCertsParse; + return Error(Error::Code::kErrCertsParse, + "FindCertificatePath: Invalid target certificate"); } for (size_t i = 1; i < der_certs.size(); ++i) { intermediate_certs.emplace_back(ParseX509Der(der_certs[i])); if (!intermediate_certs.back()) { - OSP_DVLOG - << "FindCertificatePath: Failed to parse intermediate certificate " - << i << " of " << der_certs.size(); - return Error::Code::kErrCertsParse; + return Error( + Error::Code::kErrCertsParse, + absl::StrCat( + "FindCertificatePath: Failed to parse intermediate certificate ", + i, " of ", der_certs.size())); } } // Basic checks on the target certificate. - Error::Code error = VerifyCertTime(target_cert.get(), time); - if (error != Error::Code::kNone) { - OSP_DVLOG << "FindCertificatePath: Failed to verify certificate time"; - return error; + Error::Code valid_time = VerifyCertTime(target_cert.get(), time); + if (valid_time != Error::Code::kNone) { + return Error(valid_time, + "FindCertificatePath: Failed to verify certificate time"); } bssl::UniquePtr<EVP_PKEY> public_key{X509_get_pubkey(target_cert.get())}; if (!VerifyPublicKeyLength(public_key.get())) { - OSP_DVLOG << "FindCertificatePath: Failed with invalid public key length"; - return Error::Code::kErrCertsVerifyGeneric; + return Error(Error::Code::kErrCertsVerifyGeneric, + "FindCertificatePath: Failed with invalid public key length"); } const X509_ALGOR* sig_alg; X509_get0_signature(nullptr, &sig_alg, target_cert.get()); @@ -438,14 +440,14 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, } bssl::UniquePtr<ASN1_BIT_STRING> key_usage = GetKeyUsage(target_cert.get()); if (!key_usage) { - OSP_DVLOG << "FindCertificatePath: Failed with no key usage"; - return Error::Code::kErrCertsRestrictions; + return Error(Error::Code::kErrCertsRestrictions, + "FindCertificatePath: Failed with no key usage"); } int bit = ASN1_BIT_STRING_get_bit(key_usage.get(), KeyUsageBits::kDigitalSignature); if (bit == 0) { - OSP_DVLOG << "FindCertificatePath: Failed to get digital signature"; - return Error::Code::kErrCertsRestrictions; + return Error(Error::Code::kErrCertsRestrictions, + "FindCertificatePath: Failed to get digital signature"); } X509* path_head = target_cert.get(); @@ -478,8 +480,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, Error::Code last_error = Error::Code::kNone; for (;;) { X509_NAME* target_issuer_name = X509_get_issuer_name(path_head); - OSP_DVLOG << "FindCertificatePath: Target certificate issuer name: " - << X509_NAME_oneline(target_issuer_name, 0, 0); + OSP_VLOG << "FindCertificatePath: Target certificate issuer name: " + << X509_NAME_oneline(target_issuer_name, 0, 0); // The next issuer certificate to add to the current path. X509* next_issuer = nullptr; @@ -488,8 +490,8 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, X509* trust_store_cert = trust_store->certs[i].get(); X509_NAME* trust_store_cert_name = X509_get_subject_name(trust_store_cert); - OSP_DVLOG << "FindCertificatePath: Trust store certificate issuer name: " - << X509_NAME_oneline(trust_store_cert_name, 0, 0); + OSP_VLOG << "FindCertificatePath: Trust store certificate issuer name: " + << X509_NAME_oneline(trust_store_cert_name, 0, 0); if (X509_NAME_cmp(trust_store_cert_name, target_issuer_name) == 0) { CertPathStep& next_step = path[--path_index]; next_step.cert = trust_store_cert; @@ -524,9 +526,9 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, if (path_index == first_index) { // There are no more paths to try. Ensure an error is returned. if (last_error == Error::Code::kNone) { - OSP_DVLOG << "FindCertificatePath: Failed after trying all " - "certificate paths, no matches"; - return Error::Code::kErrCertsVerifyUntrustedCert; + return Error(Error::Code::kErrCertsVerifyUntrustedCert, + "FindCertificatePath: Failed after trying all " + "certificate paths, no matches"); } return last_error; } else { @@ -556,7 +558,7 @@ Error FindCertificatePath(const std::vector<std::string>& der_certs, result_path->path.push_back(path[i].cert); } - OSP_DVLOG + OSP_VLOG << "FindCertificatePath: Succeeded at validating receiver certificates"; return Error::Code::kNone; } diff --git a/cast/common/certificate/types.cc b/cast/common/certificate/types.cc index d891c0a9..2c8fecc1 100644 --- a/cast/common/certificate/types.cc +++ b/cast/common/certificate/types.cc @@ -55,7 +55,7 @@ bool DateTimeFromSeconds(uint64_t seconds, DateTime* time) { #if defined(_WIN32) // NOTE: This is for compiling in Chromium and is not validated in any direct // libcast Windows build. - if (!gmtime_s(&tm, &sec)) { + if (gmtime_s(&tm, &sec)) { return false; } #else diff --git a/cast/common/channel/cast_socket.cc b/cast/common/channel/cast_socket.cc index 0479c997..e06cde46 100644 --- a/cast/common/channel/cast_socket.cc +++ b/cast/common/channel/cast_socket.cc @@ -14,6 +14,8 @@ namespace cast { using ::cast::channel::CastMessage; using message_serialization::DeserializeResult; +CastSocket::Client::~Client() = default; + CastSocket::CastSocket(std::unique_ptr<TlsConnection> connection, Client* client) : connection_(std::move(connection)), diff --git a/cast/common/channel/cast_socket_message_port.cc b/cast/common/channel/cast_socket_message_port.cc index 0c51304b..bdc33f20 100644 --- a/cast/common/channel/cast_socket_message_port.cc +++ b/cast/common/channel/cast_socket_message_port.cc @@ -93,7 +93,6 @@ void CastSocketMessagePort::OnMessage(VirtualConnectionRouter* router, return; } - OSP_DVLOG << "Received a cast socket message"; if (!client_) { OSP_DLOG_WARN << "Dropping message due to nullptr client_"; return; diff --git a/cast/common/channel/connection_namespace_handler.cc b/cast/common/channel/connection_namespace_handler.cc index d3b2ea8b..c50b97be 100644 --- a/cast/common/channel/connection_namespace_handler.cc +++ b/cast/common/channel/connection_namespace_handler.cc @@ -221,8 +221,8 @@ void ConnectionNamespaceHandler::HandleConnect(CastSocket* socket, data.ip_fragment = {}; } - OSP_DVLOG << "Connection opened: " << virtual_conn.local_id << ", " - << virtual_conn.peer_id << ", " << virtual_conn.socket_id; + OSP_VLOG << "Connection opened: " << virtual_conn.local_id << ", " + << virtual_conn.peer_id << ", " << virtual_conn.socket_id; // NOTE: Only send a response for senders that actually sent a version. This // maintains compatibility with older senders that don't send a version and @@ -242,9 +242,9 @@ void ConnectionNamespaceHandler::HandleClose(CastSocket* socket, ToCastSocketId(socket)}; const auto reason = GetCloseReason(parsed_message); if (RemoveConnection(conn, reason)) { - OSP_DVLOG << "Connection closed (reason: " << reason - << "): " << conn.local_id << ", " << conn.peer_id << ", " - << conn.socket_id; + OSP_VLOG << "Connection closed (reason: " << reason + << "): " << conn.local_id << ", " << conn.peer_id << ", " + << conn.socket_id; } } diff --git a/cast/common/channel/message_util.cc b/cast/common/channel/message_util.cc index 92ea5007..f7f790bf 100644 --- a/cast/common/channel/message_util.cc +++ b/cast/common/channel/message_util.cc @@ -162,5 +162,12 @@ std::string MakeUniqueSessionId(const char* prefix) { return oss.str(); } +bool HasType(const Json::Value& object, CastMessageType type) { + OSP_DCHECK(object.isObject()); + const Json::Value& value = + object.get(kMessageKeyType, Json::Value::nullSingleton()); + return value.isString() && value.asString() == CastMessageTypeToString(type); +} + } // namespace cast } // namespace openscreen diff --git a/cast/common/channel/message_util.h b/cast/common/channel/message_util.h index 8e8fe823..6eef5b11 100644 --- a/cast/common/channel/message_util.h +++ b/cast/common/channel/message_util.h @@ -9,6 +9,11 @@ #include "absl/strings/string_view.h" #include "cast/common/channel/proto/cast_channel.pb.h" +#include "util/enum_name_table.h" + +namespace Json { +class Value; +} namespace openscreen { namespace cast { @@ -158,63 +163,35 @@ enum class AppAvailabilityResult { std::string ToString(AppAvailabilityResult availability); -// TODO(crbug.com/openscreen/111): When this and/or other enums need the -// string->enum mapping, import EnumTable from Chromium's -// //components/cast_channel/enum_table.h. -inline constexpr const char* CastMessageTypeToString(CastMessageType type) { - switch (type) { - case CastMessageType::kPing: - return "PING"; - case CastMessageType::kPong: - return "PONG"; - case CastMessageType::kRpc: - return "RPC"; - case CastMessageType::kGetAppAvailability: - return "GET_APP_AVAILABILITY"; - case CastMessageType::kGetStatus: - return "GET_STATUS"; - case CastMessageType::kConnect: - return "CONNECT"; - case CastMessageType::kCloseConnection: - return "CLOSE"; - case CastMessageType::kBroadcast: - return "APPLICATION_BROADCAST"; - case CastMessageType::kLaunch: - return "LAUNCH"; - case CastMessageType::kStop: - return "STOP"; - case CastMessageType::kReceiverStatus: - return "RECEIVER_STATUS"; - case CastMessageType::kMediaStatus: - return "MEDIA_STATUS"; - case CastMessageType::kLaunchError: - return "LAUNCH_ERROR"; - case CastMessageType::kOffer: - return "OFFER"; - case CastMessageType::kAnswer: - return "ANSWER"; - case CastMessageType::kCapabilitiesResponse: - return "CAPABILITIES_RESPONSE"; - case CastMessageType::kStatusResponse: - return "STATUS_RESPONSE"; - case CastMessageType::kMultizoneStatus: - return "MULTIZONE_STATUS"; - case CastMessageType::kInvalidPlayerState: - return "INVALID_PLAYER_STATE"; - case CastMessageType::kLoadFailed: - return "LOAD_FAILED"; - case CastMessageType::kLoadCancelled: - return "LOAD_CANCELLED"; - case CastMessageType::kInvalidRequest: - return "INVALID_REQUEST"; - case CastMessageType::kPresentation: - return "PRESENTATION"; - case CastMessageType::kGetCapabilities: - return "GET_CAPABILITIES"; - case CastMessageType::kOther: - default: - return "OTHER"; - } +static const EnumNameTable<CastMessageType, 25> kCastMessageTypeNames{ + {{"PING", CastMessageType::kPing}, + {"PONG", CastMessageType::kPong}, + {"RPC", CastMessageType::kRpc}, + {"GET_APP_AVAILABILITY", CastMessageType::kGetAppAvailability}, + {"GET_STATUS", CastMessageType::kGetStatus}, + {"CONNECT", CastMessageType::kConnect}, + {"CLOSE", CastMessageType::kCloseConnection}, + {"APPLICATION_BROADCAST", CastMessageType::kBroadcast}, + {"LAUNCH", CastMessageType::kLaunch}, + {"STOP", CastMessageType::kStop}, + {"RECEIVER_STATUS", CastMessageType::kReceiverStatus}, + {"MEDIA_STATUS", CastMessageType::kMediaStatus}, + {"LAUNCH_ERROR", CastMessageType::kLaunchError}, + {"OFFER", CastMessageType::kOffer}, + {"ANSWER", CastMessageType::kAnswer}, + {"CAPABILITIES_RESPONSE", CastMessageType::kCapabilitiesResponse}, + {"STATUS_RESPONSE", CastMessageType::kStatusResponse}, + {"MULTIZONE_STATUS", CastMessageType::kMultizoneStatus}, + {"INVALID_PLAYER_STATE", CastMessageType::kInvalidPlayerState}, + {"LOAD_FAILED", CastMessageType::kLoadFailed}, + {"LOAD_CANCELLED", CastMessageType::kLoadCancelled}, + {"INVALID_REQUEST", CastMessageType::kInvalidRequest}, + {"PRESENTATION", CastMessageType::kPresentation}, + {"GET_CAPABILITIES", CastMessageType::kGetCapabilities}, + {"OTHER", CastMessageType::kOther}}}; + +inline const char* CastMessageTypeToString(CastMessageType type) { + return GetEnumName(kCastMessageTypeNames, type).value("OTHER"); } inline bool IsAuthMessage(const ::cast::channel::CastMessage& message) { @@ -242,6 +219,8 @@ inline bool IsTransportNamespace(absl::string_view namespace_) { // |prefix| of "sender" will result in a string like "sender-12345". std::string MakeUniqueSessionId(const char* prefix); +// Returns true if the type field in |object| is set to the given |type|. +bool HasType(const Json::Value& object, CastMessageType type); } // namespace cast } // namespace openscreen diff --git a/cast/common/discovery/e2e_test/tests.cc b/cast/common/discovery/e2e_test/tests.cc index 7c294418..3f316ae3 100644 --- a/cast/common/discovery/e2e_test/tests.cc +++ b/cast/common/discovery/e2e_test/tests.cc @@ -11,7 +11,7 @@ // ASSERTS due to asynchronous concerns around test failures. // Although this causes the entire test binary to fail instead of // just a single test, it makes debugging easier/possible. -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" #include "discovery/common/config.h" #include "discovery/common/reporting_client.h" #include "discovery/public/dns_sd_service_factory.h" @@ -44,12 +44,12 @@ constexpr milliseconds kCheckLoopSleepTime(100); constexpr int kMaxCheckLoopIterations = 25; // Publishes new service instances. -class Publisher : public discovery::DnsSdServicePublisher<ServiceInfo> { +class Publisher : public discovery::DnsSdServicePublisher<ReceiverInfo> { public: explicit Publisher(discovery::DnsSdService* service) // NOLINT - : DnsSdServicePublisher<ServiceInfo>(service, - kCastV2ServiceId, - ServiceInfoToDnsSdInstance) { + : DnsSdServicePublisher<ReceiverInfo>(service, + kCastV2ServiceId, + ReceiverInfoToDnsSdInstance) { OSP_LOG_INFO << "Initializing Publisher...\n"; } @@ -71,40 +71,40 @@ class Publisher : public discovery::DnsSdServicePublisher<ServiceInfo> { }; // Receives incoming services and outputs their results to stdout. -class ServiceReceiver : public discovery::DnsSdServiceWatcher<ServiceInfo> { +class ServiceReceiver : public discovery::DnsSdServiceWatcher<ReceiverInfo> { public: explicit ServiceReceiver(discovery::DnsSdService* service) // NOLINT - : discovery::DnsSdServiceWatcher<ServiceInfo>( + : discovery::DnsSdServiceWatcher<ReceiverInfo>( service, kCastV2ServiceId, - DnsSdInstanceEndpointToServiceInfo, + DnsSdInstanceEndpointToReceiverInfo, [this]( - std::vector<std::reference_wrapper<const ServiceInfo>> infos) { + std::vector<std::reference_wrapper<const ReceiverInfo>> infos) { ProcessResults(std::move(infos)); }) { OSP_LOG_INFO << "Initializing ServiceReceiver..."; } - bool IsServiceFound(const ServiceInfo& check_service) { - return std::find_if(service_infos_.begin(), service_infos_.end(), - [&check_service](const ServiceInfo& info) { + bool IsServiceFound(const ReceiverInfo& check_service) { + return std::find_if(receiver_infos_.begin(), receiver_infos_.end(), + [&check_service](const ReceiverInfo& info) { return info.friendly_name == check_service.friendly_name; - }) != service_infos_.end(); + }) != receiver_infos_.end(); } - void EraseReceivedServices() { service_infos_.clear(); } + void EraseReceivedServices() { receiver_infos_.clear(); } private: void ProcessResults( - std::vector<std::reference_wrapper<const ServiceInfo>> infos) { - service_infos_.clear(); - for (const ServiceInfo& info : infos) { - service_infos_.push_back(info); + std::vector<std::reference_wrapper<const ReceiverInfo>> infos) { + receiver_infos_.clear(); + for (const ReceiverInfo& info : infos) { + receiver_infos_.push_back(info); } } - std::vector<ServiceInfo> service_infos_; + std::vector<ReceiverInfo> receiver_infos_; }; class FailOnErrorReporting : public discovery::ReportingClient { @@ -125,16 +125,7 @@ discovery::Config GetConfigSettings() { // Get the loopback interface to run on. InterfaceInfo loopback = GetLoopbackInterfaceForTesting().value(); OSP_LOG_INFO << "Selected network interface for testing: " << loopback; - discovery::Config::NetworkInfo::AddressFamilies address_families = - discovery::Config::NetworkInfo::kNoAddressFamily; - if (loopback.GetIpAddressV4()) { - address_families |= discovery::Config::NetworkInfo::kUseIpV4; - } - if (loopback.GetIpAddressV6()) { - address_families |= discovery::Config::NetworkInfo::kUseIpV6; - } - - return discovery::Config{{{std::move(loopback), address_families}}}; + return discovery::Config{{std::move(loopback)}}; } class DiscoveryE2ETest : public testing::Test { @@ -154,8 +145,8 @@ class DiscoveryE2ETest : public testing::Test { } protected: - ServiceInfo GetInfo(int id) { - ServiceInfo hosted_service; + ReceiverInfo GetInfo(int id) { + ReceiverInfo hosted_service; hosted_service.port = 1234; hosted_service.unique_id = "id" + std::to_string(id); hosted_service.model_name = "openscreen-Model" + std::to_string(id); @@ -188,8 +179,8 @@ class DiscoveryE2ETest : public testing::Test { OSP_DCHECK(dnssd_service_.get()); OSP_DCHECK(publisher_.get()); - std::vector<ServiceInfo> record_set{std::move(records)...}; - for (ServiceInfo& record : record_set) { + std::vector<ReceiverInfo> record_set{std::move(records)...}; + for (ReceiverInfo& record : record_set) { task_runner_->PostTask([this, r = std::move(record)]() { auto error = publisher_->UpdateRegistration(r); OSP_CHECK(error.ok()) << "\tFailed to update service instance '" @@ -203,8 +194,8 @@ class DiscoveryE2ETest : public testing::Test { OSP_DCHECK(dnssd_service_.get()); OSP_DCHECK(publisher_.get()); - std::vector<ServiceInfo> record_set{std::move(records)...}; - for (ServiceInfo& record : record_set) { + std::vector<ReceiverInfo> record_set{std::move(records)...}; + for (ReceiverInfo& record : record_set) { task_runner_->PostTask([this, r = std::move(record)]() { auto error = publisher_->Register(r); OSP_CHECK(error.ok()) << "\tFailed to publish service instance '" @@ -239,20 +230,20 @@ class DiscoveryE2ETest : public testing::Test { << "Could not find " << waiting_on << " service instances!"; } - void CheckForClaimedIds(ServiceInfo service_info, + void CheckForClaimedIds(ReceiverInfo receiver_info, std::atomic_bool* has_been_seen) { OSP_DCHECK(dnssd_service_.get()); task_runner_->PostTask( - [this, info = std::move(service_info), has_been_seen]() mutable { + [this, info = std::move(receiver_info), has_been_seen]() mutable { CheckForClaimedIds(std::move(info), has_been_seen, 0); }); } - void CheckForPublishedService(ServiceInfo service_info, + void CheckForPublishedService(ReceiverInfo receiver_info, std::atomic_bool* has_been_seen) { OSP_DCHECK(dnssd_service_.get()); task_runner_->PostTask( - [this, info = std::move(service_info), has_been_seen]() mutable { + [this, info = std::move(receiver_info), has_been_seen]() mutable { CheckForPublishedService(std::move(info), has_been_seen, 0, true); }); } @@ -260,11 +251,11 @@ class DiscoveryE2ETest : public testing::Test { // TODO(issuetracker.google.com/159256503): Change this to use a polling // method to wait until the service disappears rather than immediately failing // if it exists, so waits throughout this file can be removed. - void CheckNotPublishedService(ServiceInfo service_info, + void CheckNotPublishedService(ReceiverInfo receiver_info, std::atomic_bool* has_been_seen) { OSP_DCHECK(dnssd_service_.get()); task_runner_->PostTask( - [this, info = std::move(service_info), has_been_seen]() mutable { + [this, info = std::move(receiver_info), has_been_seen]() mutable { CheckForPublishedService(std::move(info), has_been_seen, 0, false); }); } @@ -275,37 +266,38 @@ class DiscoveryE2ETest : public testing::Test { std::unique_ptr<Publisher> publisher_; private: - void CheckForClaimedIds(ServiceInfo service_info, + void CheckForClaimedIds(ReceiverInfo receiver_info, std::atomic_bool* has_been_seen, int attempts) { - if (publisher_->IsInstanceIdClaimed(service_info.GetInstanceId())) { + if (publisher_->IsInstanceIdClaimed(receiver_info.GetInstanceId())) { // TODO(crbug.com/openscreen/110): Log the published service instance. *has_been_seen = true; return; } OSP_CHECK_LE(attempts++, kMaxCheckLoopIterations) - << "Service " << service_info.friendly_name << " publication failed."; + << "Service " << receiver_info.friendly_name << " publication failed."; task_runner_->PostTaskWithDelay( - [this, info = std::move(service_info), has_been_seen, + [this, info = std::move(receiver_info), has_been_seen, attempts]() mutable { CheckForClaimedIds(std::move(info), has_been_seen, attempts); }, kCheckLoopSleepTime); } - void CheckForPublishedService(ServiceInfo service_info, + void CheckForPublishedService(ReceiverInfo receiver_info, std::atomic_bool* has_been_seen, int attempts, bool expect_to_be_present) { - if (!receiver_->IsServiceFound(service_info)) { + if (!receiver_->IsServiceFound(receiver_info)) { if (attempts++ > kMaxCheckLoopIterations) { OSP_CHECK(!expect_to_be_present) - << "Service " << service_info.friendly_name << " discovery failed."; + << "Service " << receiver_info.friendly_name + << " discovery failed."; return; } task_runner_->PostTaskWithDelay( - [this, info = std::move(service_info), has_been_seen, attempts, + [this, info = std::move(receiver_info), has_been_seen, attempts, expect_to_be_present]() mutable { CheckForPublishedService(std::move(info), has_been_seen, attempts, expect_to_be_present); @@ -315,7 +307,8 @@ class DiscoveryE2ETest : public testing::Test { // TODO(crbug.com/openscreen/110): Log the discovered service instance. *has_been_seen = true; } else { - OSP_LOG_FATAL << "Found instance '" << service_info.friendly_name << "'!"; + OSP_LOG_FATAL << "Found instance '" << receiver_info.friendly_name + << "'!"; } } }; diff --git a/cast/common/public/DEPS b/cast/common/public/DEPS index c098d4db..d31bade2 100644 --- a/cast/common/public/DEPS +++ b/cast/common/public/DEPS @@ -4,5 +4,6 @@ include_rules = [ # Dependencies on the implementation are not allowed in public/. '-cast/common', '+cast/common/public', - '+discovery/dnssd/public' + '+discovery/dnssd/public', + '+discovery/mdns/public' ] diff --git a/cast/common/public/cast_socket.h b/cast/common/public/cast_socket.h index 5c0b8775..330b1962 100644 --- a/cast/common/public/cast_socket.h +++ b/cast/common/public/cast_socket.h @@ -28,13 +28,15 @@ class CastSocket : public TlsConnection::Client { public: class Client { public: - virtual ~Client() = default; // Called when a terminal error on |socket| has occurred. virtual void OnError(CastSocket* socket, Error error) = 0; virtual void OnMessage(CastSocket* socket, ::cast::channel::CastMessage message) = 0; + + protected: + virtual ~Client(); }; CastSocket(std::unique_ptr<TlsConnection> connection, Client* client); diff --git a/cast/common/public/message_port.h b/cast/common/public/message_port.h index 23094740..91eabcd1 100644 --- a/cast/common/public/message_port.h +++ b/cast/common/public/message_port.h @@ -20,11 +20,13 @@ class MessagePort { public: class Client { public: - virtual ~Client() = default; virtual void OnMessage(const std::string& source_sender_id, const std::string& message_namespace, const std::string& message) = 0; virtual void OnError(Error error) = 0; + + protected: + virtual ~Client() = default; }; virtual ~MessagePort() = default; diff --git a/cast/common/public/service_info.cc b/cast/common/public/receiver_info.cc index 732688f8..ec45efea 100644 --- a/cast/common/public/service_info.cc +++ b/cast/common/public/receiver_info.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" #include <cctype> #include <cinttypes> @@ -11,49 +11,47 @@ #include "absl/strings/numbers.h" #include "absl/strings/str_replace.h" +#include "discovery/mdns/public/mdns_constants.h" #include "util/osp_logging.h" namespace openscreen { namespace cast { namespace { -// Maximum size for registered MDNS service instance names. -const size_t kMaxDeviceNameSize = 63; - -// Maximum size for the device model prefix at start of MDNS service instance +// Maximum size for the receiver model prefix at start of MDNS service instance // names. Any model names that are larger than this size will be truncated. -const size_t kMaxDeviceModelSize = 20; +const size_t kMaxReceiverModelSize = 20; -// Build the MDNS instance name for service. This will be the device model (up -// to 20 bytes) appended with the virtual device ID (device UUID) and optionally -// appended with extension at the end to resolve name conflicts. The total MDNS -// service instance name is kept below 64 bytes so it can easily fit into a -// single domain name label. +// Build the MDNS instance name for service. This will be the receiver model (up +// to 20 bytes) appended with the virtual receiver ID (receiver UUID) and +// optionally appended with extension at the end to resolve name conflicts. The +// total MDNS service instance name is kept below 64 bytes so it can easily fit +// into a single domain name label. // // NOTE: This value is based on what is currently done by Eureka, not what is // called out in the CastV2 spec. Eureka uses |model|-|uuid|, so the same // convention will be followed here. That being said, the Eureka receiver does // not use the instance ID in any way, so the specific calculation used should // not be important. -std::string CalculateInstanceId(const ServiceInfo& info) { - // First set the device model, truncated to 20 bytes at most. Replace any - // whitespace characters (" ") with hyphens ("-") in the device model before +std::string CalculateInstanceId(const ReceiverInfo& info) { + // First set the receiver model, truncated to 20 bytes at most. Replace any + // whitespace characters (" ") with hyphens ("-") in the receiver model before // truncation. std::string instance_name = absl::StrReplaceAll(info.model_name, {{" ", "-"}}); - instance_name = std::string(instance_name, 0, kMaxDeviceModelSize); + instance_name = std::string(instance_name, 0, kMaxReceiverModelSize); - // Append the virtual device ID to the instance name separated by a single - // '-' character if not empty. Strip all hyphens from the device ID prior + // Append the receiver ID to the instance name separated by a single + // '-' character if not empty. Strip all hyphens from the receiver ID prior // to appending it. - std::string device_id = absl::StrReplaceAll(info.unique_id, {{"-", ""}}); + std::string receiver_id = absl::StrReplaceAll(info.unique_id, {{"-", ""}}); if (!instance_name.empty()) { instance_name.push_back('-'); } - instance_name.append(device_id); + instance_name.append(receiver_id); - return std::string(instance_name, 0, kMaxDeviceNameSize); + return std::string(instance_name, 0, discovery::kMaxLabelLength); } // Returns the value for the provided |key| in the |txt| record if it exists; @@ -71,7 +69,7 @@ std::string GetStringFromRecord(const discovery::DnsSdTxtRecord& txt, } // namespace -const std::string& ServiceInfo::GetInstanceId() const { +const std::string& ReceiverInfo::GetInstanceId() const { if (instance_id_ == std::string("")) { instance_id_ = CalculateInstanceId(*this); } @@ -79,7 +77,7 @@ const std::string& ServiceInfo::GetInstanceId() const { return instance_id_; } -bool ServiceInfo::IsValid() const { +bool ReceiverInfo::IsValid() const { return ( discovery::IsInstanceValid(GetInstanceId()) && port != 0 && !unique_id.empty() && @@ -98,7 +96,7 @@ bool ServiceInfo::IsValid() const { friendly_name)); } -discovery::DnsSdInstance ServiceInfoToDnsSdInstance(const ServiceInfo& info) { +discovery::DnsSdInstance ReceiverInfoToDnsSdInstance(const ReceiverInfo& info) { OSP_DCHECK(discovery::IsServiceValid(kCastV2ServiceId)); OSP_DCHECK(discovery::IsDomainValid(kCastV2DomainId)); @@ -121,13 +119,13 @@ discovery::DnsSdInstance ServiceInfoToDnsSdInstance(const ServiceInfo& info) { kCastV2DomainId, std::move(txt), info.port); } -ErrorOr<ServiceInfo> DnsSdInstanceEndpointToServiceInfo( +ErrorOr<ReceiverInfo> DnsSdInstanceEndpointToReceiverInfo( const discovery::DnsSdInstanceEndpoint& endpoint) { if (endpoint.service_id() != kCastV2ServiceId) { - return {Error::Code::kParameterInvalid, "Not a Cast device."}; + return {Error::Code::kParameterInvalid, "Not a Cast receiver."}; } - ServiceInfo record; + ReceiverInfo record; for (const IPAddress& address : endpoint.addresses()) { if (!record.v4_address && address.IsV4()) { record.v4_address = address; @@ -148,7 +146,7 @@ ErrorOr<ServiceInfo> DnsSdInstanceEndpointToServiceInfo( record.unique_id = GetStringFromRecord(endpoint.txt(), kUniqueIdKey); if (record.unique_id.empty()) { return {Error::Code::kParameterInvalid, - "Missing device unique ID in record."}; + "Missing receiver unique ID in record."}; } // Cast protocol version supported. Begins at 2 and is incremented by 1 with @@ -169,15 +167,15 @@ ErrorOr<ServiceInfo> DnsSdInstanceEndpointToServiceInfo( } record.protocol_version = static_cast<uint8_t>(version); - // A bitset of device capabilities. + // A bitset of receiver capabilities. a_decimal_number = GetStringFromRecord(endpoint.txt(), kCapabilitiesKey); if (a_decimal_number.empty()) { return {Error::Code::kParameterInvalid, - "Missing device capabilities in record."}; + "Missing receiver capabilities in record."}; } if (!absl::SimpleAtoi(a_decimal_number, &record.capabilities)) { return {Error::Code::kParameterInvalid, - "Invalid device capabilities field in record."}; + "Invalid receiver capabilities field in record."}; } // Receiver status flag. @@ -194,11 +192,11 @@ ErrorOr<ServiceInfo> DnsSdInstanceEndpointToServiceInfo( // [Optional] Receiver model name. record.model_name = GetStringFromRecord(endpoint.txt(), kModelNameKey); - // The friendly name of the device. + // The friendly name of the receiver. record.friendly_name = GetStringFromRecord(endpoint.txt(), kFriendlyNameKey); if (record.friendly_name.empty()) { return {Error::Code::kParameterInvalid, - "Missing device friendly name in record."}; + "Missing receiver friendly name in record."}; } return record; diff --git a/cast/common/public/service_info.h b/cast/common/public/receiver_info.h index 301ef99f..c4e82c82 100644 --- a/cast/common/public/service_info.h +++ b/cast/common/public/receiver_info.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef CAST_COMMON_PUBLIC_SERVICE_INFO_H_ -#define CAST_COMMON_PUBLIC_SERVICE_INFO_H_ +#ifndef CAST_COMMON_PUBLIC_RECEIVER_INFO_H_ +#define CAST_COMMON_PUBLIC_RECEIVER_INFO_H_ #include <memory> #include <string> @@ -53,14 +53,13 @@ constexpr uint64_t kIsDevModeEnabled = 1 << 4; constexpr uint64_t kNoCapabilities = 0; -// This is the top-level service info class for CastV2. It describes a specific +// This is the top-level receiver info class for CastV2. It describes a specific // service instance. -// TODO(crbug.com/openscreen/112): Rename this to CastReceiverInfo or similar. -struct ServiceInfo { - // returns the instance id associated with this ServiceInfo instance. +struct ReceiverInfo { + // returns the instance id associated with this ReceiverInfo instance. const std::string& GetInstanceId() const; - // Returns whether all fields of this ServiceInfo are valid. + // Returns whether all fields of this ReceiverInfo are valid. bool IsValid() const; // Addresses for the service. Present if an address of this address type @@ -88,17 +87,17 @@ struct ServiceInfo { // Status of the service instance. ReceiverStatus status = ReceiverStatus::kIdle; - // The model name of the device, e.g. “Eureka v1”, “Mollie”. + // The model name of the receiver, e.g. “Eureka v1”, “Mollie”. std::string model_name; - // The friendly name of the device, e.g. “Living Room TV". + // The friendly name of the receiver, e.g. “Living Room TV". std::string friendly_name; private: mutable std::string instance_id_ = ""; }; -inline bool operator==(const ServiceInfo& lhs, const ServiceInfo& rhs) { +inline bool operator==(const ReceiverInfo& lhs, const ReceiverInfo& rhs) { return lhs.v4_address == rhs.v4_address && lhs.v6_address == rhs.v6_address && lhs.port == rhs.port && lhs.unique_id == rhs.unique_id && lhs.protocol_version == rhs.protocol_version && @@ -107,18 +106,19 @@ inline bool operator==(const ServiceInfo& lhs, const ServiceInfo& rhs) { lhs.friendly_name == rhs.friendly_name; } -inline bool operator!=(const ServiceInfo& lhs, const ServiceInfo& rhs) { +inline bool operator!=(const ReceiverInfo& lhs, const ReceiverInfo& rhs) { return !(lhs == rhs); } // Functions responsible for converting between CastV2 and DNS-SD // representations of a service instance. -discovery::DnsSdInstance ServiceInfoToDnsSdInstance(const ServiceInfo& service); +discovery::DnsSdInstance ReceiverInfoToDnsSdInstance( + const ReceiverInfo& service); -ErrorOr<ServiceInfo> DnsSdInstanceEndpointToServiceInfo( +ErrorOr<ReceiverInfo> DnsSdInstanceEndpointToReceiverInfo( const discovery::DnsSdInstanceEndpoint& endpoint); } // namespace cast } // namespace openscreen -#endif // CAST_COMMON_PUBLIC_SERVICE_INFO_H_ +#endif // CAST_COMMON_PUBLIC_RECEIVER_INFO_H_ diff --git a/cast/common/public/service_info_unittest.cc b/cast/common/public/receiver_info_unittest.cc index 08401a45..a7b16e2d 100644 --- a/cast/common/public/service_info_unittest.cc +++ b/cast/common/public/receiver_info_unittest.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" #include <cstdio> #include <sstream> @@ -20,13 +20,13 @@ constexpr NetworkInterfaceIndex kNetworkInterface = 0; } -TEST(ServiceInfoTests, ConvertValidFromDnsSd) { +TEST(ReceiverInfoTests, ConvertValidFromDnsSd) { std::string instance = "InstanceId"; discovery::DnsSdTxtRecord txt = CreateValidTxt(); discovery::DnsSdInstanceEndpoint record( instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, kEndpointV4, kEndpointV6); - ErrorOr<ServiceInfo> info = DnsSdInstanceEndpointToServiceInfo(record); + ErrorOr<ReceiverInfo> info = DnsSdInstanceEndpointToReceiverInfo(record); ASSERT_TRUE(info.is_value()) << info; EXPECT_EQ(info.value().unique_id, kTestUniqueId); EXPECT_TRUE(info.value().v4_address); @@ -44,7 +44,7 @@ TEST(ServiceInfoTests, ConvertValidFromDnsSd) { record = discovery::DnsSdInstanceEndpoint(instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, kEndpointV4); - info = DnsSdInstanceEndpointToServiceInfo(record); + info = DnsSdInstanceEndpointToReceiverInfo(record); ASSERT_TRUE(info.is_value()); EXPECT_EQ(info.value().unique_id, kTestUniqueId); EXPECT_TRUE(info.value().v4_address); @@ -60,7 +60,7 @@ TEST(ServiceInfoTests, ConvertValidFromDnsSd) { record = discovery::DnsSdInstanceEndpoint(instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, kEndpointV6); - info = DnsSdInstanceEndpointToServiceInfo(record); + info = DnsSdInstanceEndpointToReceiverInfo(record); ASSERT_TRUE(info.is_value()); EXPECT_EQ(info.value().unique_id, kTestUniqueId); EXPECT_FALSE(info.value().v4_address); @@ -74,42 +74,42 @@ TEST(ServiceInfoTests, ConvertValidFromDnsSd) { EXPECT_EQ(info.value().friendly_name, kFriendlyName); } -TEST(ServiceInfoTests, ConvertInvalidFromDnsSd) { +TEST(ReceiverInfoTests, ConvertInvalidFromDnsSd) { std::string instance = "InstanceId"; discovery::DnsSdTxtRecord txt = CreateValidTxt(); txt.ClearValue(kUniqueIdKey); discovery::DnsSdInstanceEndpoint record( instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, kEndpointV4, kEndpointV6); - EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); + EXPECT_TRUE(DnsSdInstanceEndpointToReceiverInfo(record).is_error()); txt = CreateValidTxt(); txt.ClearValue(kVersionKey); record = discovery::DnsSdInstanceEndpoint( instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, kEndpointV4, kEndpointV6); - EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); + EXPECT_TRUE(DnsSdInstanceEndpointToReceiverInfo(record).is_error()); txt = CreateValidTxt(); txt.ClearValue(kCapabilitiesKey); record = discovery::DnsSdInstanceEndpoint( instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, kEndpointV4, kEndpointV6); - EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); + EXPECT_TRUE(DnsSdInstanceEndpointToReceiverInfo(record).is_error()); txt = CreateValidTxt(); txt.ClearValue(kStatusKey); record = discovery::DnsSdInstanceEndpoint( instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, kEndpointV4, kEndpointV6); - EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); + EXPECT_TRUE(DnsSdInstanceEndpointToReceiverInfo(record).is_error()); txt = CreateValidTxt(); txt.ClearValue(kFriendlyNameKey); record = discovery::DnsSdInstanceEndpoint( instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, kEndpointV4, kEndpointV6); - EXPECT_TRUE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); + EXPECT_TRUE(DnsSdInstanceEndpointToReceiverInfo(record).is_error()); txt = CreateValidTxt(); txt.ClearValue(kModelNameKey); @@ -117,11 +117,11 @@ TEST(ServiceInfoTests, ConvertInvalidFromDnsSd) { instance, kCastV2ServiceId, kCastV2DomainId, txt, kNetworkInterface, kEndpointV4, kEndpointV6); // Note: Model name is an optional field. - EXPECT_FALSE(DnsSdInstanceEndpointToServiceInfo(record).is_error()); + EXPECT_FALSE(DnsSdInstanceEndpointToReceiverInfo(record).is_error()); } -TEST(ServiceInfoTests, ConvertValidToDnsSd) { - ServiceInfo info; +TEST(ReceiverInfoTests, ConvertValidToDnsSd) { + ReceiverInfo info; info.v4_address = kAddressV4; info.v6_address = kAddressV6; info.port = kPort; @@ -131,7 +131,7 @@ TEST(ServiceInfoTests, ConvertValidToDnsSd) { info.status = kStatusParsed; info.model_name = kModelName; info.friendly_name = kFriendlyName; - discovery::DnsSdInstance instance = ServiceInfoToDnsSdInstance(info); + discovery::DnsSdInstance instance = ReceiverInfoToDnsSdInstance(info); CompareTxtString(instance.txt(), kUniqueIdKey, kTestUniqueId); CompareTxtString(instance.txt(), kCapabilitiesKey, kCapabilitiesString); CompareTxtString(instance.txt(), kModelNameKey, kModelName); @@ -140,7 +140,7 @@ TEST(ServiceInfoTests, ConvertValidToDnsSd) { CompareTxtInt(instance.txt(), kStatusKey, kStatus); } -TEST(ServiceInfoTests, ParseServiceInfoFromRealTXT) { +TEST(ReceiverInfoTests, ParseReceiverInfoFromRealTXT) { constexpr struct { const char* key; const char* value; @@ -168,9 +168,9 @@ TEST(ServiceInfoTests, ParseServiceInfoFromRealTXT) { "InstanceId", kCastV2ServiceId, kCastV2DomainId, std::move(txt), kNetworkInterface, kEndpointV4, kEndpointV6); - const ErrorOr<ServiceInfo> result = - DnsSdInstanceEndpointToServiceInfo(record); - const ServiceInfo& info = result.value(); + const ErrorOr<ReceiverInfo> result = + DnsSdInstanceEndpointToReceiverInfo(record); + const ReceiverInfo& info = result.value(); EXPECT_EQ(info.unique_id, "4ef522244a5a877f35ddead7d98702e6"); EXPECT_EQ(info.protocol_version, 5); EXPECT_TRUE(info.capabilities & (kHasVideoOutput | kHasAudioOutput)); diff --git a/cast/common/public/testing/discovery_utils.h b/cast/common/public/testing/discovery_utils.h index b0d7d03b..6e805e74 100644 --- a/cast/common/public/testing/discovery_utils.h +++ b/cast/common/public/testing/discovery_utils.h @@ -7,7 +7,7 @@ #include <string> -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" #include "discovery/dnssd/public/dns_sd_txt_record.h" #include "gmock/gmock.h" #include "gtest/gtest.h" diff --git a/cast/protocol/BUILD.gn b/cast/protocol/BUILD.gn index 590caed3..a68bb989 100644 --- a/cast/protocol/BUILD.gn +++ b/cast/protocol/BUILD.gn @@ -63,11 +63,14 @@ source_set("unittests") { deps = [ ":castv2", + ":castv2_schema_headers", ":receiver_examples", ":streaming_examples", + "../../platform:base", "../../third_party/abseil", "../../third_party/googletest:gmock", "../../third_party/googletest:gtest", + "../../util:base", "//third_party/valijson", ] } diff --git a/cast/protocol/castv2/streaming_examples/answer.json b/cast/protocol/castv2/streaming_examples/answer.json index 73c45ea2..0e115fc1 100644 --- a/cast/protocol/castv2/streaming_examples/answer.json +++ b/cast/protocol/castv2/streaming_examples/answer.json @@ -16,7 +16,7 @@ }, "video": { "maxPixelsPerSecond": 62208000, - "minDimensions": {"width": 320, "height": 240, "frameRate": "23/3"}, + "minResolution": {"width": 320, "height": 240}, "maxDimensions": {"width": 1920, "height": 1080, "frameRate": "60"}, "minBitRate": 300000, "maxBitRate": 10000000, @@ -30,7 +30,6 @@ }, "receiverRtcpEventLog": [0, 1], "receiverRtcpDscp": [234, 567], - "receiverGetStatus": true, "rtpExtensions": ["adaptive_playout_delay"] } }
\ No newline at end of file diff --git a/cast/protocol/castv2/streaming_examples/offer.json b/cast/protocol/castv2/streaming_examples/offer.json index 339b6d15..b6162112 100644 --- a/cast/protocol/castv2/streaming_examples/offer.json +++ b/cast/protocol/castv2/streaming_examples/offer.json @@ -1,7 +1,6 @@ { "offer": { "castMode": "mirroring", - "receiverGetStatus": true, "supportedStreams": [ { "aesIvMask": "64A6AAC2821880145271BB15B0188821", @@ -37,9 +36,64 @@ "targetDelay": 400, "timeBase": "1/90000", "type": "video_source" + }, + { + "aesIvMask": "64A6AAC2821880145271BB15B0188821", + "aesKey": "65386FD9BCC30BC7FB6A4DD1D3B0FA5E", + "codecName": "h264", + "codecParameter": "avc1.4D4028", + "index": 2, + "maxBitRate": 4000000, + "maxFrameRate": "25", + "receiverRtcpEventLog": false, + "renderMode": "video", + "resolutions": [{"height": 720, "width": 1280}], + "rtpExtensions": ["adaptive_playout_delay"], + "rtpPayloadType": 97, + "rtpProfile": "cast", + "ssrc": 748229, + "targetDelay": 400, + "timeBase": "1/90000", + "type": "video_source" + }, + { + "aesIvMask": "64A6AAC2821880145271BB15B0188821", + "aesKey": "65386FD9BCC30BC7FB6A4DD1D3B0FA5E", + "codecName": "vp9", + "index": 2, + "maxBitRate": 5000000, + "maxFrameRate": "30000/1000", + "receiverRtcpEventLog": true, + "renderMode": "video", + "resolutions": [{"height": 1080, "width": 1920}], + "rtpExtensions": ["adaptive_playout_delay"], + "rtpPayloadType": 96, + "rtpProfile": "cast", + "ssrc": 748230, + "targetDelay": 400, + "timeBase": "1/90000", + "type": "video_source" + }, + { + "aesIvMask": "64A6AAC2821880145271BB15B0188821", + "aesKey": "65386FD9BCC30BC7FB6A4DD1D3B0FA5E", + "codecName": "av1", + "index": 3, + "maxBitRate": 5000000, + "maxFrameRate": "30000/1000", + "receiverRtcpEventLog": true, + "renderMode": "video", + "resolutions": [{"height": 1080, "width": 1920}], + "rtpExtensions": ["adaptive_playout_delay"], + "rtpPayloadType": 96, + "rtpProfile": "cast", + "ssrc": 748231, + "targetDelay": 400, + "timeBase": "1/90000", + "type": "video_source" } ] }, - "seqNum": 0, + "seqNum": 123, "type": "OFFER" -}
\ No newline at end of file +} diff --git a/cast/protocol/castv2/streaming_examples/rpc.json b/cast/protocol/castv2/streaming_examples/rpc.json index 6ebfc0e0..880ca641 100644 --- a/cast/protocol/castv2/streaming_examples/rpc.json +++ b/cast/protocol/castv2/streaming_examples/rpc.json @@ -1,5 +1,4 @@ { - "seqNum": 12345, "sessionId": 735189, "type": "RPC", "result": "ok", diff --git a/cast/protocol/castv2/streaming_schema.json b/cast/protocol/castv2/streaming_schema.json index 392d135c..4c78d526 100644 --- a/cast/protocol/castv2/streaming_schema.json +++ b/cast/protocol/castv2/streaming_schema.json @@ -27,7 +27,8 @@ "properties": { "index": {"type": "integer", "minimum": 0}, "type": {"type": "string", "enum": ["audio_source", "video_source"]}, - "codecName": {"type": "string"}, + "codecName": {"type": "string", "enum": ["aac", "opus", "h264", "vp8", "hevc", "vp9", "av1"]}, + "codecParameter": {"type": "string"}, "rtpProfile": {"type": "string", "enum": ["cast"]}, "rtpPayloadType": {"type": "integer", "minimum": 96, "maximum": 127}, "ssrc": {"$ref": "#/definitions/ssrc"}, @@ -115,16 +116,13 @@ "video_constraints": { "properties": { "maxPixelsPerSecond": {"type": "number", "minimum": 0}, - "minDimensions": {"$ref": "#/definitions/dimensions"}, + "minResolution": {"$ref": "#/definitions/resolution"}, "maxDimensions": {"$ref": "#/definitions/dimensions"}, "minBitRate": {"type": "integer", "minimum": 300000}, "maxBitRate": {"type": "integer", "minimum": 300000}, "maxDelay": {"$ref": "#/definitions/delay"} }, - "required": [ - "maxDimensions", - "maxBitRate" - ] + "required": ["maxDimensions", "maxBitRate"] }, "constraints": { "properties": { @@ -168,24 +166,10 @@ "type": "array", "items": {"type": "integer", "minimum": 0} }, - "receiverGetStatus": {"type": "boolean"}, "rtpExtensions": {"$ref": "#/definitions/rtp_extensions"} }, "required": ["udpPort", "sendIndexes", "ssrcs"] }, - "status_response": { - "properties": { - "wifiSpeed": { - "type": "array", - "items": {"type": "integer", "minimum": 0} - }, - "wifiFcsError": { - "type": "array", - "items": {"type": "integer", "minimum": 0} - }, - "wifiSnr": {"type": "number", "examples": ["3.23", "50.1"]} - } - }, "capabilities": { "$id": "#capabilities", "type": "object", @@ -221,14 +205,6 @@ "result": {"type": "string", "enum": ["ok", "error"]}, "seqNum": {"type": "integer", "minimum": 0}, "sessionId": {"type": "integer"}, - "get_status": { - "type": "array", - "items": { - "type": "string", - "enum": ["wifiFcsError", "wifiSnr", "wifiSpeed"] - } - }, - "status": {"$ref": "#/definitions/status_response"}, "type": { "type": "string", "enum": [ @@ -242,15 +218,36 @@ ] } }, - "required": ["type", "seqNum"], + "required": ["type"], "allOf": [ { "if": { - "properties": {"type": {"enum": ["ANSWER", "CAPABILITIES_RESPONSE", "STATUS_RESPONSE"]}} + "properties": { + "type": { + "enum": ["ANSWER", "CAPABILITIES_RESPONSE", "STATUS_RESPONSE"] + } + } }, "then": {"required": ["result"]} }, { + "if": { + "properties": { + "type": { + "enum": [ + "OFFER", + "ANSWER", + "GET_CAPABILITIES", + "CAPABILITIES_RESPONSE", + "GET_STATUS", + "STATUS_RESPONSE" + ] + } + } + }, + "then": {"required": ["seqNum"]} + }, + { "if": {"properties": {"type": {"const": "OFFER"}}}, "then": {"required": ["offer"]} }, @@ -270,18 +267,10 @@ "then": {"required": ["capabilities"]} }, { - "if": {"properties": {"type": {"const": "GET_STATUS"}}}, - "then": {"required": ["get_status"]} - }, - { - "if": {"properties": {"type": {"const": "STATUS_RESPONSE"}}}, - "then": {"required": ["status"]} - }, - { "if": { "properties": {"type": {"const": "RPC"}, "result": {"const": "ok"}} }, "then": {"required": ["rpc"]} } ] -}
\ No newline at end of file +} diff --git a/cast/protocol/castv2/validation.cc b/cast/protocol/castv2/validation.cc index 67a9b351..a87dd5e3 100644 --- a/cast/protocol/castv2/validation.cc +++ b/cast/protocol/castv2/validation.cc @@ -32,9 +32,6 @@ std::vector<Error> MapErrors(const valijson::ValidationResults& results) { errors.emplace_back(Error::Code::kJsonParseError, StringPrintf("Node: %s, Message: %s", context.c_str(), result.description.c_str())); - - OSP_DVLOG << "JsonCpp validation error: " - << errors.at(errors.size() - 1).message(); } return errors; } diff --git a/cast/protocol/castv2/validation_unittest.cc b/cast/protocol/castv2/validation_unittest.cc index 46ded40a..d6c0e70e 100644 --- a/cast/protocol/castv2/validation_unittest.cc +++ b/cast/protocol/castv2/validation_unittest.cc @@ -71,8 +71,6 @@ std::string BuildSchema(const char* definitions, } bool TestValidate(absl::string_view document, absl::string_view schema) { - OSP_DVLOG << "Validating document: \"" << document << "\" against schema: \"" - << schema << "\""; ErrorOr<Json::Value> document_root = json::Parse(document); EXPECT_TRUE(document_root.is_value()); ErrorOr<Json::Value> schema_root = json::Parse(schema); diff --git a/cast/receiver/application_agent.cc b/cast/receiver/application_agent.cc index df2b49e6..d5665391 100644 --- a/cast/receiver/application_agent.cc +++ b/cast/receiver/application_agent.cc @@ -11,6 +11,7 @@ #include "cast/common/public/cast_socket.h" #include "platform/base/tls_credentials.h" #include "platform/base/tls_listen_options.h" +#include "util/json/json_helpers.h" #include "util/json/json_serialization.h" #include "util/osp_logging.h" @@ -18,24 +19,6 @@ namespace openscreen { namespace cast { namespace { -// Parses the given string as a JSON object. If the parse fails, an empty object -// is returned. -Json::Value ParseAsObject(absl::string_view value) { - ErrorOr<Json::Value> parsed = json::Parse(value); - if (parsed.is_value() && parsed.value().isObject()) { - return std::move(parsed.value()); - } - return Json::Value(Json::objectValue); -} - -// Returns true if the type field in |object| is set to the given |type|. -bool HasType(const Json::Value& object, CastMessageType type) { - OSP_DCHECK(object.isObject()); - const Json::Value& value = - object.get(kMessageKeyType, Json::Value::nullSingleton()); - return value.isString() && value.asString() == CastMessageTypeToString(type); -} - // Returns the first app ID for the given |app|, or the empty string if there is // none. std::string GetFirstAppId(ApplicationAgent::Application* app) { @@ -142,25 +125,29 @@ void ApplicationAgent::OnMessage(VirtualConnectionRouter* router, return; } - const Json::Value request = ParseAsObject(message.payload_utf8()); + const ErrorOr<Json::Value> request = json::Parse(message.payload_utf8()); + if (request.is_error() || request.value().type() != Json::objectValue) { + return; + } + Json::Value response; if (ns == kHeartbeatNamespace) { - if (HasType(request, CastMessageType::kPing)) { + if (HasType(request.value(), CastMessageType::kPing)) { response = HandlePing(); } } else if (ns == kReceiverNamespace) { - if (request[kMessageKeyRequestId].isNull()) { - response = HandleInvalidCommand(request); - } else if (HasType(request, CastMessageType::kGetAppAvailability)) { - response = HandleGetAppAvailability(request); - } else if (HasType(request, CastMessageType::kGetStatus)) { - response = HandleGetStatus(request); - } else if (HasType(request, CastMessageType::kLaunch)) { - response = HandleLaunch(request, socket); - } else if (HasType(request, CastMessageType::kStop)) { - response = HandleStop(request); + if (request.value()[kMessageKeyRequestId].isNull()) { + response = HandleInvalidCommand(request.value()); + } else if (HasType(request.value(), CastMessageType::kGetAppAvailability)) { + response = HandleGetAppAvailability(request.value()); + } else if (HasType(request.value(), CastMessageType::kGetStatus)) { + response = HandleGetStatus(request.value()); + } else if (HasType(request.value(), CastMessageType::kLaunch)) { + response = HandleLaunch(request.value(), socket); + } else if (HasType(request.value(), CastMessageType::kStop)) { + response = HandleStop(request.value()); } else { - response = HandleInvalidCommand(request); + response = HandleInvalidCommand(request.value()); } } else { // Ignore messages for all other namespaces. diff --git a/cast/receiver/channel/receiver_socket_factory.cc b/cast/receiver/channel/receiver_socket_factory.cc index c8ddd691..5645bbda 100644 --- a/cast/receiver/channel/receiver_socket_factory.cc +++ b/cast/receiver/channel/receiver_socket_factory.cc @@ -9,6 +9,8 @@ namespace openscreen { namespace cast { +ReceiverSocketFactory::Client::~Client() = default; + ReceiverSocketFactory::ReceiverSocketFactory(Client* client, CastSocket::Client* socket_client) : client_(client), socket_client_(socket_client) { @@ -38,7 +40,6 @@ void ReceiverSocketFactory::OnConnected( void ReceiverSocketFactory::OnConnectionFailed( TlsConnectionFactory* factory, const IPEndpoint& remote_address) { - OSP_DVLOG << "Receiving connection from endpoint failed: " << remote_address; client_->OnError(this, Error(Error::Code::kConnectionFailed, "Accepting connection failed.")); } diff --git a/cast/receiver/public/receiver_socket_factory.h b/cast/receiver/public/receiver_socket_factory.h index 0e2e4e1c..612ffc47 100644 --- a/cast/receiver/public/receiver_socket_factory.h +++ b/cast/receiver/public/receiver_socket_factory.h @@ -5,6 +5,7 @@ #ifndef CAST_RECEIVER_PUBLIC_RECEIVER_SOCKET_FACTORY_H_ #define CAST_RECEIVER_PUBLIC_RECEIVER_SOCKET_FACTORY_H_ +#include <memory> #include <vector> #include "cast/common/public/cast_socket.h" @@ -22,6 +23,9 @@ class ReceiverSocketFactory final : public TlsConnectionFactory::Client { const IPEndpoint& endpoint, std::unique_ptr<CastSocket> socket) = 0; virtual void OnError(ReceiverSocketFactory* factory, Error error) = 0; + + protected: + virtual ~Client(); }; // |client| and |socket_client| must outlive |this|. diff --git a/cast/sender/BUILD.gn b/cast/sender/BUILD.gn index ae17f7cf..09e39453 100644 --- a/cast/sender/BUILD.gn +++ b/cast/sender/BUILD.gn @@ -13,6 +13,7 @@ source_set("channel") { ] deps = [ + "../../third_party/abseil", "../common:channel", "../common/certificate/proto:certificate_proto", "../common/channel/proto:channel_proto", @@ -65,9 +66,7 @@ source_set("test_helpers") { "../receiver:channel", ] - public_deps = [ - ":channel", - ] + public_deps = [ ":channel" ] } source_set("unittests") { diff --git a/cast/sender/cast_app_availability_tracker.cc b/cast/sender/cast_app_availability_tracker.cc index 0a018d19..7e980797 100644 --- a/cast/sender/cast_app_availability_tracker.cc +++ b/cast/sender/cast_app_availability_tracker.cc @@ -54,10 +54,10 @@ void CastAppAvailabilityTracker::UnregisterSource( } std::vector<CastMediaSource> CastAppAvailabilityTracker::UpdateAppAvailability( - const std::string& device_id, + const std::string& receiver_id, const std::string& app_id, AppAvailability availability) { - auto& availabilities = app_availabilities_[device_id]; + auto& availabilities = app_availabilities_[receiver_id]; auto it = availabilities.find(app_id); AppAvailabilityResult old_availability = it == availabilities.end() @@ -84,21 +84,22 @@ std::vector<CastMediaSource> CastAppAvailabilityTracker::UpdateAppAvailability( return affected_sources; } -std::vector<CastMediaSource> CastAppAvailabilityTracker::RemoveResultsForDevice( - const std::string& device_id) { - auto affected_sources = GetSupportedSources(device_id); - app_availabilities_.erase(device_id); +std::vector<CastMediaSource> +CastAppAvailabilityTracker::RemoveResultsForReceiver( + const std::string& receiver_id) { + auto affected_sources = GetSupportedSources(receiver_id); + app_availabilities_.erase(receiver_id); return affected_sources; } std::vector<CastMediaSource> CastAppAvailabilityTracker::GetSupportedSources( - const std::string& device_id) const { - auto it = app_availabilities_.find(device_id); + const std::string& receiver_id) const { + auto it = app_availabilities_.find(receiver_id); if (it == app_availabilities_.end()) { return std::vector<CastMediaSource>(); } - // Find all app IDs that are available on the device. + // Find all app IDs that are available on the receiver. std::vector<std::string> supported_app_ids; for (const auto& availability : it->second) { if (availability.second.availability == AppAvailabilityResult::kAvailable) { @@ -106,7 +107,7 @@ std::vector<CastMediaSource> CastAppAvailabilityTracker::GetSupportedSources( } } - // Find all registered sources whose query results contain the device ID. + // Find all registered sources whose query results contain the receiver ID. std::vector<CastMediaSource> sources; for (const auto& source : registered_sources_) { if (source.second.ContainsAnyAppIdFrom(supported_app_ids)) { @@ -117,9 +118,9 @@ std::vector<CastMediaSource> CastAppAvailabilityTracker::GetSupportedSources( } CastAppAvailabilityTracker::AppAvailability -CastAppAvailabilityTracker::GetAvailability(const std::string& device_id, +CastAppAvailabilityTracker::GetAvailability(const std::string& receiver_id, const std::string& app_id) const { - auto availabilities_it = app_availabilities_.find(device_id); + auto availabilities_it = app_availabilities_.find(receiver_id); if (availabilities_it == app_availabilities_.end()) { return {AppAvailabilityResult::kUnknown, Clock::time_point{}}; } @@ -142,10 +143,11 @@ std::vector<std::string> CastAppAvailabilityTracker::GetRegisteredApps() const { return registered_apps; } -std::vector<std::string> CastAppAvailabilityTracker::GetAvailableDevices( +std::vector<std::string> CastAppAvailabilityTracker::GetAvailableReceivers( const CastMediaSource& source) const { - std::vector<std::string> device_ids; - // For each device, check if there is at least one available app in |source|. + std::vector<std::string> receiver_ids; + // For each receiver, check if there is at least one available app in + // |source|. for (const auto& availabilities : app_availabilities_) { for (const std::string& app_id : source.app_ids()) { const auto& availabilities_map = availabilities.second; @@ -153,12 +155,12 @@ std::vector<std::string> CastAppAvailabilityTracker::GetAvailableDevices( if (availability_it != availabilities_map.end() && availability_it->second.availability == AppAvailabilityResult::kAvailable) { - device_ids.push_back(availabilities.first); + receiver_ids.push_back(availabilities.first); break; } } } - return device_ids; + return receiver_ids; } } // namespace cast diff --git a/cast/sender/cast_app_availability_tracker.h b/cast/sender/cast_app_availability_tracker.h index c0bded96..74d3bc25 100644 --- a/cast/sender/cast_app_availability_tracker.h +++ b/cast/sender/cast_app_availability_tracker.h @@ -16,8 +16,8 @@ namespace openscreen { namespace cast { -// Tracks device queries and their extracted Cast app IDs and their -// availabilities on discovered devices. +// Tracks receiver queries and their extracted Cast app IDs and their +// availabilities on discovered receivers. // Example usage: /// // (1) A page is interested in a Cast URL (e.g. by creating a @@ -28,24 +28,24 @@ namespace cast { // auto new_app_ids = tracker.RegisterSource(source.value()); // // (2) The set of app IDs returned by the tracker can then be used by the caller -// to send an app availability request to each of the discovered devices. +// to send an app availability request to each of the discovered receivers. // -// (3) Once the caller knows the availability value for a (device, app) pair, it -// may inform the tracker to update its results: +// (3) Once the caller knows the availability value for a (receiver, app) pair, +// it may inform the tracker to update its results: // auto affected_sources = -// tracker.UpdateAppAvailability(device_id, app_id, {availability, now}); +// tracker.UpdateAppAvailability(receiver_id, app_id, {availability, +// now}); // // (4) The tracker returns a subset of discovered sources that were affected by -// the update. The caller can then call |GetAvailableDevices()| to get the +// the update. The caller can then call |GetAvailableReceivers()| to get the // updated results for each affected source. // -// (5a): At any time, the caller may call |RemoveResultsForDevice()| to remove -// cached results pertaining to the device, when it detects that a device is +// (5a): At any time, the caller may call |RemoveResultsForReceiver()| to remove +// cached results pertaining to the receiver, when it detects that a receiver is // removed or no longer valid. // -// (5b): At any time, the caller may call |GetAvailableDevices()| (even before +// (5b): At any time, the caller may call |GetAvailableReceivers()| (even before // the source is registered) to determine if there are cached results available. -// TODO(crbug.com/openscreen/112): Device -> Receiver renaming. class CastAppAvailabilityTracker { public: // The result of an app availability request and the time when it is obtained. @@ -69,40 +69,40 @@ class CastAppAvailabilityTracker { void UnregisterSource(const std::string& source_id); void UnregisterSource(const CastMediaSource& source); - // Updates the availability of |app_id| on |device_id| to |availability|. + // Updates the availability of |app_id| on |receiver_id| to |availability|. // Returns a list of registered CastMediaSources for which the set of - // available devices might have been updated by this call. The caller should - // call |GetAvailableDevices| with the returned CastMediaSources to get the + // available receivers might have been updated by this call. The caller should + // call |GetAvailableReceivers| with the returned CastMediaSources to get the // updated lists. std::vector<CastMediaSource> UpdateAppAvailability( - const std::string& device_id, + const std::string& receiver_id, const std::string& app_id, AppAvailability availability); - // Removes all results associated with |device_id|, i.e. when the device + // Removes all results associated with |receiver_id|, i.e. when the receiver // becomes invalid. Returns a list of registered CastMediaSources for which - // the set of available devices might have been updated by this call. The - // caller should call |GetAvailableDevices| with the returned CastMediaSources - // to get the updated lists. - std::vector<CastMediaSource> RemoveResultsForDevice( - const std::string& device_id); + // the set of available receivers might have been updated by this call. The + // caller should call |GetAvailableReceivers| with the returned + // CastMediaSources to get the updated lists. + std::vector<CastMediaSource> RemoveResultsForReceiver( + const std::string& receiver_id); - // Returns a list of registered CastMediaSources supported by |device_id|. + // Returns a list of registered CastMediaSources supported by |receiver_id|. std::vector<CastMediaSource> GetSupportedSources( - const std::string& device_id) const; + const std::string& receiver_id) const; - // Returns the availability for |app_id| on |device_id| and the time at which - // the availability was determined. If availability is kUnknown, then the time - // may be null (e.g. if an availability request was never sent). - AppAvailability GetAvailability(const std::string& device_id, + // Returns the availability for |app_id| on |receiver_id| and the time at + // which the availability was determined. If availability is kUnknown, then + // the time may be null (e.g. if an availability request was never sent). + AppAvailability GetAvailability(const std::string& receiver_id, const std::string& app_id) const; // Returns a list of registered app IDs. std::vector<std::string> GetRegisteredApps() const; - // Returns a list of device IDs compatible with |source|, using the current + // Returns a list of receiver IDs compatible with |source|, using the current // availability info. - std::vector<std::string> GetAvailableDevices( + std::vector<std::string> GetAvailableReceivers( const CastMediaSource& source) const; private: @@ -115,7 +115,7 @@ class CastAppAvailabilityTracker { // App IDs tracked and the number of registered sources containing them. std::map<std::string, int> registration_count_by_app_id_; - // IDs and app availabilities of known devices. + // IDs and app availabilities of known receivers. std::map<std::string, AppAvailabilityMap> app_availabilities_; }; diff --git a/cast/sender/cast_app_availability_tracker_unittest.cc b/cast/sender/cast_app_availability_tracker_unittest.cc index b45d3563..1b721577 100644 --- a/cast/sender/cast_app_availability_tracker_unittest.cc +++ b/cast/sender/cast_app_availability_tracker_unittest.cc @@ -97,62 +97,63 @@ TEST_F(CastAppAvailabilityTrackerTest, UpdateAppAvailability) { // |source3| not affected. EXPECT_THAT( tracker_.UpdateAppAvailability( - "deviceId1", "AAA", {AppAvailabilityResult::kAvailable, Now()}), + "receiverId1", "AAA", {AppAvailabilityResult::kAvailable, Now()}), CastMediaSourcesEqual(std::vector<CastMediaSource>())); - std::vector<std::string> devices_1 = {"deviceId1"}; - std::vector<std::string> devices_1_2 = {"deviceId1", "deviceId2"}; + std::vector<std::string> receivers_1 = {"receiverId1"}; + std::vector<std::string> receivers_1_2 = {"receiverId1", "receiverId2"}; std::vector<CastMediaSource> sources_1 = {source1}; std::vector<CastMediaSource> sources_1_2 = {source1, source2}; - // Tracker returns available devices even though sources aren't registered. - EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source1)); - EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source2)); - EXPECT_TRUE(tracker_.GetAvailableDevices(source3).empty()); + // Tracker returns available receivers even though sources aren't registered. + EXPECT_EQ(receivers_1, tracker_.GetAvailableReceivers(source1)); + EXPECT_EQ(receivers_1, tracker_.GetAvailableReceivers(source2)); + EXPECT_TRUE(tracker_.GetAvailableReceivers(source3).empty()); tracker_.RegisterSource(source1); // Only |source1| is registered for this app. EXPECT_THAT( tracker_.UpdateAppAvailability( - "deviceId2", "AAA", {AppAvailabilityResult::kAvailable, Now()}), + "receiverId2", "AAA", {AppAvailabilityResult::kAvailable, Now()}), CastMediaSourcesEqual(sources_1)); - EXPECT_THAT(tracker_.GetAvailableDevices(source1), - UnorderedElementsAreArray(devices_1_2)); - EXPECT_THAT(tracker_.GetAvailableDevices(source2), - UnorderedElementsAreArray(devices_1_2)); - EXPECT_TRUE(tracker_.GetAvailableDevices(source3).empty()); + EXPECT_THAT(tracker_.GetAvailableReceivers(source1), + UnorderedElementsAreArray(receivers_1_2)); + EXPECT_THAT(tracker_.GetAvailableReceivers(source2), + UnorderedElementsAreArray(receivers_1_2)); + EXPECT_TRUE(tracker_.GetAvailableReceivers(source3).empty()); tracker_.RegisterSource(source2); EXPECT_THAT( tracker_.UpdateAppAvailability( - "deviceId2", "AAA", {AppAvailabilityResult::kUnavailable, Now()}), + "receiverId2", "AAA", {AppAvailabilityResult::kUnavailable, Now()}), CastMediaSourcesEqual(sources_1_2)); - EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source1)); - EXPECT_EQ(devices_1, tracker_.GetAvailableDevices(source2)); - EXPECT_TRUE(tracker_.GetAvailableDevices(source3).empty()); + EXPECT_EQ(receivers_1, tracker_.GetAvailableReceivers(source1)); + EXPECT_EQ(receivers_1, tracker_.GetAvailableReceivers(source2)); + EXPECT_TRUE(tracker_.GetAvailableReceivers(source3).empty()); } -TEST_F(CastAppAvailabilityTrackerTest, RemoveResultsForDevice) { +TEST_F(CastAppAvailabilityTrackerTest, RemoveResultsForReceiver) { CastMediaSource source1("cast:AAA?clientId=1", {"AAA"}); - tracker_.UpdateAppAvailability("deviceId1", "AAA", + tracker_.UpdateAppAvailability("receiverId1", "AAA", {AppAvailabilityResult::kAvailable, Now()}); EXPECT_EQ(AppAvailabilityResult::kAvailable, - tracker_.GetAvailability("deviceId1", "AAA").availability); + tracker_.GetAvailability("receiverId1", "AAA").availability); - std::vector<std::string> expected_device_ids = {"deviceId1"}; - EXPECT_EQ(expected_device_ids, tracker_.GetAvailableDevices(source1)); + std::vector<std::string> expected_receiver_ids = {"receiverId1"}; + EXPECT_EQ(expected_receiver_ids, tracker_.GetAvailableReceivers(source1)); - // Unrelated device ID. - tracker_.RemoveResultsForDevice("deviceId2"); + // Unrelated receiver ID. + tracker_.RemoveResultsForReceiver("receiverId2"); EXPECT_EQ(AppAvailabilityResult::kAvailable, - tracker_.GetAvailability("deviceId1", "AAA").availability); - EXPECT_EQ(expected_device_ids, tracker_.GetAvailableDevices(source1)); + tracker_.GetAvailability("receiverId1", "AAA").availability); + EXPECT_EQ(expected_receiver_ids, tracker_.GetAvailableReceivers(source1)); - tracker_.RemoveResultsForDevice("deviceId1"); + tracker_.RemoveResultsForReceiver("receiverId1"); EXPECT_EQ(AppAvailabilityResult::kUnknown, - tracker_.GetAvailability("deviceId1", "AAA").availability); - EXPECT_EQ(std::vector<std::string>{}, tracker_.GetAvailableDevices(source1)); + tracker_.GetAvailability("receiverId1", "AAA").availability); + EXPECT_EQ(std::vector<std::string>{}, + tracker_.GetAvailableReceivers(source1)); } } // namespace cast diff --git a/cast/sender/cast_app_discovery_service_impl.cc b/cast/sender/cast_app_discovery_service_impl.cc index 4ca9a016..1e428a79 100644 --- a/cast/sender/cast_app_discovery_service_impl.cc +++ b/cast/sender/cast_app_discovery_service_impl.cc @@ -41,10 +41,10 @@ CastAppDiscoveryServiceImpl::StartObservingAvailability( const std::string& source_id = source.source_id(); // Return cached results immediately, if available. - std::vector<std::string> cached_device_ids = - availability_tracker_.GetAvailableDevices(source); - if (!cached_device_ids.empty()) { - callback(source, GetReceiversByIds(cached_device_ids)); + std::vector<std::string> cached_receiver_ids = + availability_tracker_.GetAvailableReceivers(source); + if (!cached_receiver_ids.empty()) { + callback(source, GetReceiversByIds(cached_receiver_ids)); } auto& callbacks = avail_queries_[source_id]; @@ -76,54 +76,54 @@ void CastAppDiscoveryServiceImpl::Refresh() { } void CastAppDiscoveryServiceImpl::AddOrUpdateReceiver( - const ServiceInfo& receiver) { - const std::string& device_id = receiver.unique_id; - receivers_by_id_[device_id] = receiver; + const ReceiverInfo& receiver) { + const std::string& receiver_id = receiver.unique_id; + receivers_by_id_[receiver_id] = receiver; // Any queries that currently contain this receiver should be updated. UpdateAvailabilityQueries( - availability_tracker_.GetSupportedSources(device_id)); + availability_tracker_.GetSupportedSources(receiver_id)); for (const std::string& app_id : availability_tracker_.GetRegisteredApps()) { - RequestAppAvailability(device_id, app_id); + RequestAppAvailability(receiver_id, app_id); } } -void CastAppDiscoveryServiceImpl::RemoveReceiver(const ServiceInfo& receiver) { - const std::string& device_id = receiver.unique_id; - receivers_by_id_.erase(device_id); +void CastAppDiscoveryServiceImpl::RemoveReceiver(const ReceiverInfo& receiver) { + const std::string& receiver_id = receiver.unique_id; + receivers_by_id_.erase(receiver_id); UpdateAvailabilityQueries( - availability_tracker_.RemoveResultsForDevice(device_id)); + availability_tracker_.RemoveResultsForReceiver(receiver_id)); } void CastAppDiscoveryServiceImpl::RequestAppAvailability( - const std::string& device_id, + const std::string& receiver_id, const std::string& app_id) { - if (ShouldRefreshAppAvailability(device_id, app_id, clock_())) { + if (ShouldRefreshAppAvailability(receiver_id, app_id, clock_())) { platform_client_->RequestAppAvailability( - device_id, app_id, - [self = weak_factory_.GetWeakPtr(), device_id]( + receiver_id, app_id, + [self = weak_factory_.GetWeakPtr(), receiver_id]( const std::string& app_id, AppAvailabilityResult availability) { if (self) { - self->UpdateAppAvailability(device_id, app_id, availability); + self->UpdateAppAvailability(receiver_id, app_id, availability); } }); } } void CastAppDiscoveryServiceImpl::UpdateAppAvailability( - const std::string& device_id, + const std::string& receiver_id, const std::string& app_id, AppAvailabilityResult availability) { - if (receivers_by_id_.find(device_id) == receivers_by_id_.end()) { + if (receivers_by_id_.find(receiver_id) == receivers_by_id_.end()) { return; } - OSP_DVLOG << "App " << app_id << " on receiver " << device_id << " is " + OSP_DVLOG << "App " << app_id << " on receiver " << receiver_id << " is " << ToString(availability); UpdateAvailabilityQueries(availability_tracker_.UpdateAppAvailability( - device_id, app_id, {availability, clock_()})); + receiver_id, app_id, {availability, clock_()})); } void CastAppDiscoveryServiceImpl::UpdateAvailabilityQueries( @@ -133,20 +133,20 @@ void CastAppDiscoveryServiceImpl::UpdateAvailabilityQueries( auto it = avail_queries_.find(source_id); if (it == avail_queries_.end()) continue; - std::vector<std::string> device_ids = - availability_tracker_.GetAvailableDevices(source); - std::vector<ServiceInfo> receivers = GetReceiversByIds(device_ids); + std::vector<std::string> receiver_ids = + availability_tracker_.GetAvailableReceivers(source); + std::vector<ReceiverInfo> receivers = GetReceiversByIds(receiver_ids); for (const auto& callback : it->second) { callback.callback(source, receivers); } } } -std::vector<ServiceInfo> CastAppDiscoveryServiceImpl::GetReceiversByIds( - const std::vector<std::string>& device_ids) const { - std::vector<ServiceInfo> receivers; - for (const std::string& device_id : device_ids) { - auto entry = receivers_by_id_.find(device_id); +std::vector<ReceiverInfo> CastAppDiscoveryServiceImpl::GetReceiversByIds( + const std::vector<std::string>& receiver_ids) const { + std::vector<ReceiverInfo> receivers; + for (const std::string& receiver_id : receiver_ids) { + auto entry = receivers_by_id_.find(receiver_id); if (entry != receivers_by_id_.end()) { receivers.push_back(entry->second); } @@ -155,13 +155,14 @@ std::vector<ServiceInfo> CastAppDiscoveryServiceImpl::GetReceiversByIds( } bool CastAppDiscoveryServiceImpl::ShouldRefreshAppAvailability( - const std::string& device_id, + const std::string& receiver_id, const std::string& app_id, Clock::time_point now) const { // TODO(btolsch): Consider an exponential backoff mechanism instead. // Receivers will typically respond with "unavailable" immediately after boot // and then become available 10-30 seconds later. - auto availability = availability_tracker_.GetAvailability(device_id, app_id); + auto availability = + availability_tracker_.GetAvailability(receiver_id, app_id); switch (availability.availability) { case AppAvailabilityResult::kAvailable: return false; diff --git a/cast/sender/cast_app_discovery_service_impl.h b/cast/sender/cast_app_discovery_service_impl.h index fa577808..49943994 100644 --- a/cast/sender/cast_app_discovery_service_impl.h +++ b/cast/sender/cast_app_discovery_service_impl.h @@ -9,7 +9,7 @@ #include <string> #include <vector> -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" #include "cast/sender/cast_app_availability_tracker.h" #include "cast/sender/cast_platform_client.h" #include "cast/sender/public/cast_app_discovery_service.h" @@ -33,12 +33,12 @@ class CastAppDiscoveryServiceImpl : public CastAppDiscoveryService { const CastMediaSource& source, AvailabilityCallback callback) override; - // Reissues app availability requests for currently registered (device_id, + // Reissues app availability requests for currently registered (receiver_id, // app_id) pairs whose status is kUnavailable or kUnknown. void Refresh() override; - void AddOrUpdateReceiver(const ServiceInfo& receiver); - void RemoveReceiver(const ServiceInfo& receiver); + void AddOrUpdateReceiver(const ReceiverInfo& receiver); + void RemoveReceiver(const ReceiverInfo& receiver); private: struct AvailabilityCallbackEntry { @@ -47,32 +47,32 @@ class CastAppDiscoveryServiceImpl : public CastAppDiscoveryService { }; // Issues an app availability request for |app_id| to the receiver given by - // |device_id|. - void RequestAppAvailability(const std::string& device_id, + // |receiver_id|. + void RequestAppAvailability(const std::string& receiver_id, const std::string& app_id); - // Updates the availability result for |device_id| and |app_id| with |result|, - // and notifies callbacks with updated availability query results. - void UpdateAppAvailability(const std::string& device_id, + // Updates the availability result for |receiver_id| and |app_id| with + // |result|, and notifies callbacks with updated availability query results. + void UpdateAppAvailability(const std::string& receiver_id, const std::string& app_id, AppAvailabilityResult result); // Updates the availability query results for |sources|. void UpdateAvailabilityQueries(const std::vector<CastMediaSource>& sources); - std::vector<ServiceInfo> GetReceiversByIds( - const std::vector<std::string>& device_ids) const; + std::vector<ReceiverInfo> GetReceiversByIds( + const std::vector<std::string>& receiver_ids) const; // Returns true if an app availability request should be issued for - // |device_id| and |app_id|. |now| is used for checking whether previously + // |receiver_id| and |app_id|. |now| is used for checking whether previously // cached results should be refreshed. - bool ShouldRefreshAppAvailability(const std::string& device_id, + bool ShouldRefreshAppAvailability(const std::string& receiver_id, const std::string& app_id, Clock::time_point now) const; void RemoveAvailabilityCallback(uint32_t id) override; - std::map<std::string, ServiceInfo> receivers_by_id_; + std::map<std::string, ReceiverInfo> receivers_by_id_; // Registered availability queries and their associated callbacks keyed by // media source IDs. diff --git a/cast/sender/cast_app_discovery_service_impl_unittest.cc b/cast/sender/cast_app_discovery_service_impl_unittest.cc index a2eeb046..60175f59 100644 --- a/cast/sender/cast_app_discovery_service_impl_unittest.cc +++ b/cast/sender/cast_app_discovery_service_impl_unittest.cc @@ -9,7 +9,7 @@ #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_socket_error_handler.h" #include "cast/common/channel/virtual_connection_router.h" -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" #include "cast/sender/testing/test_helpers.h" #include "gtest/gtest.h" #include "platform/test/fake_clock.h" @@ -32,7 +32,7 @@ class CastAppDiscoveryServiceImplTest : public ::testing::Test { receiver_.v4_address = fake_cast_socket_pair_.remote_endpoint.address; receiver_.port = fake_cast_socket_pair_.remote_endpoint.port; - receiver_.unique_id = "deviceId1"; + receiver_.unique_id = "receiverId1"; receiver_.friendly_name = "Some Name"; } @@ -42,23 +42,23 @@ class CastAppDiscoveryServiceImplTest : public ::testing::Test { return fake_cast_socket_pair_.mock_peer_client; } - void AddOrUpdateReceiver(const ServiceInfo& receiver, int32_t socket_id) { + void AddOrUpdateReceiver(const ReceiverInfo& receiver, int32_t socket_id) { platform_client_.AddOrUpdateReceiver(receiver, socket_id); app_discovery_service_.AddOrUpdateReceiver(receiver); } CastAppDiscoveryService::Subscription StartObservingAvailability( const CastMediaSource& source, - std::vector<ServiceInfo>* save_receivers) { + std::vector<ReceiverInfo>* save_receivers) { return app_discovery_service_.StartObservingAvailability( source, [save_receivers](const CastMediaSource& source, - const std::vector<ServiceInfo>& receivers) { + const std::vector<ReceiverInfo>& receivers) { *save_receivers = receivers; }); } CastAppDiscoveryService::Subscription StartSourceA1Query( - std::vector<ServiceInfo>* receivers, + std::vector<ReceiverInfo>* receivers, int* request_id, std::string* sender_id) { auto subscription = StartObservingAvailability(source_a_1_, receivers); @@ -91,18 +91,18 @@ class CastAppDiscoveryServiceImplTest : public ::testing::Test { CastMediaSource source_a_2_{"cast:AAA?clientId=2", {"AAA"}}; CastMediaSource source_b_1_{"cast:BBB?clientId=1", {"BBB"}}; - ServiceInfo receiver_; + ReceiverInfo receiver_; }; TEST_F(CastAppDiscoveryServiceImplTest, StartObservingAvailability) { - std::vector<ServiceInfo> receivers1; + std::vector<ReceiverInfo> receivers1; int request_id; std::string sender_id; auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id); // Same app ID should not trigger another request. EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0); - std::vector<ServiceInfo> receivers2; + std::vector<ReceiverInfo> receivers2; auto subscription2 = StartObservingAvailability(source_a_2_, &receivers2); CastMessage availability_response = @@ -110,8 +110,8 @@ TEST_F(CastAppDiscoveryServiceImplTest, StartObservingAvailability) { EXPECT_TRUE(peer_socket().Send(availability_response).ok()); ASSERT_EQ(receivers1.size(), 1u); ASSERT_EQ(receivers2.size(), 1u); - EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); - EXPECT_EQ(receivers2[0].unique_id, "deviceId1"); + EXPECT_EQ(receivers1[0].unique_id, "receiverId1"); + EXPECT_EQ(receivers2[0].unique_id, "receiverId1"); // No more updates for |source_a_1_| (i.e. |receivers1|). subscription1.Reset(); @@ -119,11 +119,11 @@ TEST_F(CastAppDiscoveryServiceImplTest, StartObservingAvailability) { app_discovery_service_.RemoveReceiver(receiver_); ASSERT_EQ(receivers1.size(), 1u); EXPECT_EQ(receivers2.size(), 0u); - EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + EXPECT_EQ(receivers1[0].unique_id, "receiverId1"); } TEST_F(CastAppDiscoveryServiceImplTest, ReAddAvailQueryUsesCachedValue) { - std::vector<ServiceInfo> receivers1; + std::vector<ReceiverInfo> receivers1; int request_id; std::string sender_id; auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id); @@ -132,7 +132,7 @@ TEST_F(CastAppDiscoveryServiceImplTest, ReAddAvailQueryUsesCachedValue) { CreateAppAvailableResponseChecked(request_id, sender_id, "AAA"); EXPECT_TRUE(peer_socket().Send(availability_response).ok()); ASSERT_EQ(receivers1.size(), 1u); - EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + EXPECT_EQ(receivers1[0].unique_id, "receiverId1"); subscription1.Reset(); receivers1.clear(); @@ -141,11 +141,11 @@ TEST_F(CastAppDiscoveryServiceImplTest, ReAddAvailQueryUsesCachedValue) { EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0); subscription1 = StartObservingAvailability(source_a_1_, &receivers1); ASSERT_EQ(receivers1.size(), 1u); - EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + EXPECT_EQ(receivers1[0].unique_id, "receiverId1"); } TEST_F(CastAppDiscoveryServiceImplTest, AvailQueryUpdatedOnReceiverUpdate) { - std::vector<ServiceInfo> receivers1; + std::vector<ReceiverInfo> receivers1; int request_id; std::string sender_id; auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id); @@ -155,7 +155,7 @@ TEST_F(CastAppDiscoveryServiceImplTest, AvailQueryUpdatedOnReceiverUpdate) { CreateAppAvailableResponseChecked(request_id, sender_id, "AAA"); EXPECT_TRUE(peer_socket().Send(availability_response).ok()); ASSERT_EQ(receivers1.size(), 1u); - EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + EXPECT_EQ(receivers1[0].unique_id, "receiverId1"); // Updating |receiver_| causes |source_a_1_| query to be updated, but it's too // soon for a new message to be sent. @@ -169,9 +169,9 @@ TEST_F(CastAppDiscoveryServiceImplTest, AvailQueryUpdatedOnReceiverUpdate) { } TEST_F(CastAppDiscoveryServiceImplTest, Refresh) { - std::vector<ServiceInfo> receivers1; + std::vector<ReceiverInfo> receivers1; auto subscription1 = StartObservingAvailability(source_a_1_, &receivers1); - std::vector<ServiceInfo> receivers2; + std::vector<ReceiverInfo> receivers2; auto subscription2 = StartObservingAvailability(source_b_1_, &receivers2); // Adding a receiver after app registered causes two separate app availability @@ -207,7 +207,7 @@ TEST_F(CastAppDiscoveryServiceImplTest, Refresh) { EXPECT_TRUE(peer_socket().Send(availability_response).ok()); ASSERT_EQ(receivers1.size(), 1u); ASSERT_EQ(receivers2.size(), 0u); - EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + EXPECT_EQ(receivers1[0].unique_id, "receiverId1"); // Not enough time has passed for a refresh. clock_.Advance(std::chrono::seconds(30)); @@ -236,7 +236,7 @@ TEST_F(CastAppDiscoveryServiceImplTest, .WillOnce([&request_idA, &sender_id](CastSocket*, CastMessage message) { VerifyAppAvailabilityRequest(message, "AAA", &request_idA, &sender_id); }); - std::vector<ServiceInfo> receivers1; + std::vector<ReceiverInfo> receivers1; auto subscription1 = StartObservingAvailability(source_a_1_, &receivers1); int request_idB = -1; @@ -244,7 +244,7 @@ TEST_F(CastAppDiscoveryServiceImplTest, .WillOnce([&request_idB, &sender_id](CastSocket*, CastMessage message) { VerifyAppAvailabilityRequest(message, "BBB", &request_idB, &sender_id); }); - std::vector<ServiceInfo> receivers2; + std::vector<ReceiverInfo> receivers2; auto subscription2 = StartObservingAvailability(source_b_1_, &receivers2); // Add a new receiver with a corresponding socket. @@ -252,8 +252,8 @@ TEST_F(CastAppDiscoveryServiceImplTest, {{192, 168, 1, 19}, 2345}); CastSocket* socket2 = fake_sockets2.socket.get(); router_.TakeSocket(&mock_error_handler_, std::move(fake_sockets2.socket)); - ServiceInfo receiver2; - receiver2.unique_id = "deviceId2"; + ReceiverInfo receiver2; + receiver2.unique_id = "receiverId2"; receiver2.v4_address = fake_sockets2.remote_endpoint.address; receiver2.port = fake_sockets2.remote_endpoint.port; @@ -283,7 +283,7 @@ TEST_F(CastAppDiscoveryServiceImplTest, } TEST_F(CastAppDiscoveryServiceImplTest, StartObservingAvailabilityCachedValue) { - std::vector<ServiceInfo> receivers1; + std::vector<ReceiverInfo> receivers1; int request_id; std::string sender_id; auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id); @@ -292,19 +292,19 @@ TEST_F(CastAppDiscoveryServiceImplTest, StartObservingAvailabilityCachedValue) { CreateAppAvailableResponseChecked(request_id, sender_id, "AAA"); EXPECT_TRUE(peer_socket().Send(availability_response).ok()); ASSERT_EQ(receivers1.size(), 1u); - EXPECT_EQ(receivers1[0].unique_id, "deviceId1"); + EXPECT_EQ(receivers1[0].unique_id, "receiverId1"); // Same app ID should not trigger another request, but it should return // cached value. EXPECT_CALL(peer_client(), OnMessage(_, _)).Times(0); - std::vector<ServiceInfo> receivers2; + std::vector<ReceiverInfo> receivers2; auto subscription2 = StartObservingAvailability(source_a_2_, &receivers2); ASSERT_EQ(receivers2.size(), 1u); - EXPECT_EQ(receivers2[0].unique_id, "deviceId1"); + EXPECT_EQ(receivers2[0].unique_id, "receiverId1"); } TEST_F(CastAppDiscoveryServiceImplTest, AvailabilityUnknownOrUnavailable) { - std::vector<ServiceInfo> receivers1; + std::vector<ReceiverInfo> receivers1; int request_id; std::string sender_id; auto subscription1 = StartSourceA1Query(&receivers1, &request_id, &sender_id); diff --git a/cast/sender/cast_platform_client.cc b/cast/sender/cast_platform_client.cc index c321201a..b0b956a8 100644 --- a/cast/sender/cast_platform_client.cc +++ b/cast/sender/cast_platform_client.cc @@ -11,7 +11,7 @@ #include "absl/strings/str_cat.h" #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" #include "util/json/json_serialization.h" #include "util/osp_logging.h" #include "util/stringprintf.h" @@ -38,7 +38,7 @@ CastPlatformClient::~CastPlatformClient() { virtual_conn_router_->RemoveConnectionsByLocalId(sender_id_); virtual_conn_router_->RemoveHandlerForLocalId(sender_id_); - for (auto& pending_requests : pending_requests_by_device_id_) { + for (auto& pending_requests : pending_requests_by_receiver_id_) { for (auto& avail_request : pending_requests.second.availability) { avail_request.callback(avail_request.app_id, AppAvailabilityResult::kUnknown); @@ -47,11 +47,11 @@ CastPlatformClient::~CastPlatformClient() { } absl::optional<int> CastPlatformClient::RequestAppAvailability( - const std::string& device_id, + const std::string& receiver_id, const std::string& app_id, AppAvailabilityCallback callback) { - auto entry = socket_id_by_device_id_.find(device_id); - if (entry == socket_id_by_device_id_.end()) { + auto entry = socket_id_by_receiver_id_.find(receiver_id); + if (entry == socket_id_by_receiver_id_.end()) { callback(app_id, AppAvailabilityResult::kUnknown); return absl::nullopt; } @@ -62,7 +62,8 @@ absl::optional<int> CastPlatformClient::RequestAppAvailability( CreateAppAvailabilityRequest(sender_id_, request_id, app_id); OSP_DCHECK(message); - PendingRequests& pending_requests = pending_requests_by_device_id_[device_id]; + PendingRequests& pending_requests = + pending_requests_by_receiver_id_[receiver_id]; auto timeout = std::make_unique<Alarm>(clock_, task_runner_); timeout->ScheduleFromNow( [this, request_id]() { CancelAppAvailabilityRequest(request_id); }, @@ -82,28 +83,28 @@ absl::optional<int> CastPlatformClient::RequestAppAvailability( return request_id; } -void CastPlatformClient::AddOrUpdateReceiver(const ServiceInfo& device, +void CastPlatformClient::AddOrUpdateReceiver(const ReceiverInfo& receiver, int socket_id) { - socket_id_by_device_id_[device.unique_id] = socket_id; + socket_id_by_receiver_id_[receiver.unique_id] = socket_id; } -void CastPlatformClient::RemoveReceiver(const ServiceInfo& device) { +void CastPlatformClient::RemoveReceiver(const ReceiverInfo& receiver) { auto pending_requests_it = - pending_requests_by_device_id_.find(device.unique_id); - if (pending_requests_it != pending_requests_by_device_id_.end()) { + pending_requests_by_receiver_id_.find(receiver.unique_id); + if (pending_requests_it != pending_requests_by_receiver_id_.end()) { for (const AvailabilityRequest& availability : pending_requests_it->second.availability) { availability.callback(availability.app_id, AppAvailabilityResult::kUnknown); } - pending_requests_by_device_id_.erase(pending_requests_it); + pending_requests_by_receiver_id_.erase(pending_requests_it); } - socket_id_by_device_id_.erase(device.unique_id); + socket_id_by_receiver_id_.erase(receiver.unique_id); } void CastPlatformClient::CancelRequest(int request_id) { - for (auto entry = pending_requests_by_device_id_.begin(); - entry != pending_requests_by_device_id_.end(); ++entry) { + for (auto entry = pending_requests_by_receiver_id_.begin(); + entry != pending_requests_by_receiver_id_.end(); ++entry) { auto& pending_requests = entry->second; auto it = std::find_if(pending_requests.availability.begin(), pending_requests.availability.end(), @@ -128,7 +129,6 @@ void CastPlatformClient::OnMessage(VirtualConnectionRouter* router, } ErrorOr<Json::Value> dict_or_error = json::Parse(message.payload_utf8()); if (dict_or_error.is_error()) { - OSP_DVLOG << "Failed to deserialize CastMessage payload."; return; } @@ -137,22 +137,22 @@ void CastPlatformClient::OnMessage(VirtualConnectionRouter* router, MaybeGetInt(dict, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyRequestId)); if (request_id) { auto entry = std::find_if( - socket_id_by_device_id_.begin(), socket_id_by_device_id_.end(), + socket_id_by_receiver_id_.begin(), socket_id_by_receiver_id_.end(), [socket_id = ToCastSocketId(socket)](const std::pair<std::string, int>& entry) { return entry.second == socket_id; }); - if (entry != socket_id_by_device_id_.end()) { + if (entry != socket_id_by_receiver_id_.end()) { HandleResponse(entry->first, request_id.value(), dict); } } } -void CastPlatformClient::HandleResponse(const std::string& device_id, +void CastPlatformClient::HandleResponse(const std::string& receiver_id, int request_id, const Json::Value& message) { - auto entry = pending_requests_by_device_id_.find(device_id); - if (entry == pending_requests_by_device_id_.end()) { + auto entry = pending_requests_by_receiver_id_.find(receiver_id); + if (entry == pending_requests_by_receiver_id_.end()) { return; } PendingRequests& pending_requests = entry->second; @@ -178,7 +178,7 @@ void CastPlatformClient::HandleResponse(const std::string& device_id, } else if (result.value() == kMessageValueAppUnavailable) { availability_result = AppAvailabilityResult::kUnavailable; } else { - OSP_DVLOG << "Invalid availability result: " << result.value(); + OSP_VLOG << "Invalid availability result: " << result.value(); } it->callback(it->app_id, availability_result); } @@ -188,7 +188,7 @@ void CastPlatformClient::HandleResponse(const std::string& device_id, } void CastPlatformClient::CancelAppAvailabilityRequest(int request_id) { - for (auto& entry : pending_requests_by_device_id_) { + for (auto& entry : pending_requests_by_receiver_id_) { PendingRequests& pending_requests = entry.second; auto it = std::find_if(pending_requests.availability.begin(), pending_requests.availability.end(), diff --git a/cast/sender/cast_platform_client.h b/cast/sender/cast_platform_client.h index 8ea9a99a..c80a8b95 100644 --- a/cast/sender/cast_platform_client.h +++ b/cast/sender/cast_platform_client.h @@ -20,7 +20,7 @@ namespace openscreen { namespace cast { -struct ServiceInfo; +struct ReceiverInfo; class VirtualConnectionRouter; // This class handles Cast messages that generally relate to the "platform", in @@ -41,15 +41,15 @@ class CastPlatformClient final : public CastMessageHandler { ~CastPlatformClient() override; // Requests availability information for |app_id| from the receiver identified - // by |device_id|. |callback| will be called exactly once with a result. - absl::optional<int> RequestAppAvailability(const std::string& device_id, + // by |receiver_id|. |callback| will be called exactly once with a result. + absl::optional<int> RequestAppAvailability(const std::string& receiver_id, const std::string& app_id, AppAvailabilityCallback callback); // Notifies this object about general receiver connectivity or property // changes. - void AddOrUpdateReceiver(const ServiceInfo& device, int socket_id); - void RemoveReceiver(const ServiceInfo& device); + void AddOrUpdateReceiver(const ReceiverInfo& receiver, int socket_id); + void RemoveReceiver(const ReceiverInfo& receiver); void CancelRequest(int request_id); @@ -70,7 +70,7 @@ class CastPlatformClient final : public CastMessageHandler { CastSocket* socket, ::cast::channel::CastMessage message) override; - void HandleResponse(const std::string& device_id, + void HandleResponse(const std::string& receiver_id, int request_id, const Json::Value& message); @@ -82,9 +82,9 @@ class CastPlatformClient final : public CastMessageHandler { const std::string sender_id_; VirtualConnectionRouter* const virtual_conn_router_; - std::map<std::string /* device_id */, int> socket_id_by_device_id_; - std::map<std::string /* device_id */, PendingRequests> - pending_requests_by_device_id_; + std::map<std::string /* receiver_id */, int> socket_id_by_receiver_id_; + std::map<std::string /* receiver_id */, PendingRequests> + pending_requests_by_receiver_id_; const ClockNowFunctionPtr clock_; TaskRunner* const task_runner_; diff --git a/cast/sender/cast_platform_client_unittest.cc b/cast/sender/cast_platform_client_unittest.cc index ae721a15..702e35eb 100644 --- a/cast/sender/cast_platform_client_unittest.cc +++ b/cast/sender/cast_platform_client_unittest.cc @@ -9,7 +9,7 @@ #include "cast/common/channel/testing/fake_cast_socket.h" #include "cast/common/channel/testing/mock_socket_error_handler.h" #include "cast/common/channel/virtual_connection_router.h" -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" #include "cast/sender/testing/test_helpers.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -34,7 +34,7 @@ class CastPlatformClientTest : public ::testing::Test { receiver_.v4_address = IPAddress{192, 168, 0, 17}; receiver_.port = 4434; - receiver_.unique_id = "deviceId1"; + receiver_.unique_id = "receiverId1"; platform_client_.AddOrUpdateReceiver(receiver_, socket_->socket_id()); } @@ -51,7 +51,7 @@ class CastPlatformClientTest : public ::testing::Test { FakeClock clock_{Clock::now()}; FakeTaskRunner task_runner_{&clock_}; CastPlatformClient platform_client_{&router_, &FakeClock::now, &task_runner_}; - ServiceInfo receiver_; + ReceiverInfo receiver_; }; TEST_F(CastPlatformClientTest, AppAvailability) { @@ -64,7 +64,7 @@ TEST_F(CastPlatformClientTest, AppAvailability) { }); bool ran = false; platform_client_.RequestAppAvailability( - "deviceId1", "AAA", + "receiverId1", "AAA", [&ran](const std::string& app_id, AppAvailabilityResult availability) { EXPECT_EQ("AAA", app_id); EXPECT_EQ(availability, AppAvailabilityResult::kAvailable); @@ -92,7 +92,7 @@ TEST_F(CastPlatformClientTest, CancelRequest) { }); absl::optional<int> maybe_request_id = platform_client_.RequestAppAvailability( - "deviceId1", "AAA", + "receiverId1", "AAA", [](const std::string& app_id, AppAvailabilityResult availability) { EXPECT_TRUE(false); }); diff --git a/cast/sender/channel/cast_auth_util.cc b/cast/sender/channel/cast_auth_util.cc index cb1ced69..f0e1c36f 100644 --- a/cast/sender/channel/cast_auth_util.cc +++ b/cast/sender/channel/cast_auth_util.cc @@ -9,6 +9,7 @@ #include <algorithm> #include <memory> +#include "absl/strings/str_cat.h" #include "cast/common/certificate/cast_cert_validator.h" #include "cast/common/certificate/cast_cert_validator_internal.h" #include "cast/common/certificate/cast_crl.h" @@ -104,43 +105,56 @@ class CastNonce { std::chrono::seconds nonce_generation_time_; }; -// Maps Error::Code from certificate verification to Error. -// If crl_required is set to false, all revocation related errors are ignored. -Error MapToOpenscreenError(Error::Code error, bool crl_required) { - switch (error) { +// Maps an error from certificate verification to an error reported to the +// library client. If crl_required is set to false, all revocation related +// errors are ignored. +// +// TODO(https://issuetracker.google.com/issues/193164666): It would be simpler +// to just pass the underlying verification error directly to the client. +Error MapToOpenscreenError(Error verify_error, bool crl_required) { + switch (verify_error.code()) { case Error::Code::kErrCertsMissing: return Error(Error::Code::kCastV2PeerCertEmpty, - "Failed to locate certificates."); + absl::StrCat("Failed to locate certificates: ", + verify_error.message())); case Error::Code::kErrCertsParse: return Error(Error::Code::kErrCertsParse, - "Failed to parse certificates."); + absl::StrCat("Failed to parse certificates: ", + verify_error.message())); case Error::Code::kErrCertsDateInvalid: - return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, - "Failed date validity check."); + return Error( + Error::Code::kCastV2CertNotSignedByTrustedCa, + absl::StrCat("Failed date validity check: ", verify_error.message())); case Error::Code::kErrCertsVerifyGeneric: - return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, - "Failed with a generic certificate verification error."); + return Error( + Error::Code::kCastV2CertNotSignedByTrustedCa, + absl::StrCat("Failed with a generic certificate verification error: ", + verify_error.message())); case Error::Code::kErrCertsRestrictions: return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, - "Failed certificate restrictions."); + absl::StrCat("Failed certificate restrictions: ", + verify_error.message())); case Error::Code::kErrCertsVerifyUntrustedCert: return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, - "Failed with untrusted certificate."); + absl::StrCat("Failed with untrusted certificate: ", + verify_error.message())); case Error::Code::kErrCrlInvalid: // This error is only encountered if |crl_required| is true. OSP_DCHECK(crl_required); return Error(Error::Code::kErrCrlInvalid, - "Failed to provide a valid CRL."); + absl::StrCat("Failed to provide a valid CRL: ", + verify_error.message())); case Error::Code::kErrCertsRevoked: return Error(Error::Code::kErrCertsRevoked, - "Failed certificate revocation check."); + absl::StrCat("Failed certificate revocation check: ", + verify_error.message())); case Error::Code::kNone: return Error::None(); default: return Error(Error::Code::kCastV2CertNotSignedByTrustedCa, - "Failed verifying cast device certificate."); + absl::StrCat("Failed verifying cast device certificate: ", + verify_error.message())); } - return Error::None(); } Error VerifyAndMapDigestAlgorithm(HashAlgorithm response_digest_algorithm, @@ -170,6 +184,21 @@ AuthContext AuthContext::Create() { return AuthContext(CastNonce::Get()); } +// static +AuthContext AuthContext::CreateForTest(const std::string& nonce_data) { + std::string nonce; + if (nonce_data.empty()) { + nonce = std::string(kNonceSizeInBytes, '0'); + } else { + while (nonce.size() < kNonceSizeInBytes) { + nonce += nonce_data; + } + nonce.erase(kNonceSizeInBytes); + } + OSP_DCHECK_EQ(nonce.size(), kNonceSizeInBytes); + return AuthContext(nonce); +} + AuthContext::AuthContext(const std::string& nonce) : nonce_(nonce) {} AuthContext::~AuthContext() {} @@ -356,7 +385,7 @@ ErrorOr<CastDeviceCertPolicy> VerifyCredentialsImpl( &device_policy, crl.get(), crl_policy, cast_trust_store); // Handle and report errors. - Error result = MapToOpenscreenError(verify_result.code(), + Error result = MapToOpenscreenError(verify_result, crl_policy == CRLPolicy::kCrlRequired); if (!result.ok()) { return result; diff --git a/cast/sender/channel/cast_auth_util.h b/cast/sender/channel/cast_auth_util.h index d23ebd7e..9c0646ec 100644 --- a/cast/sender/channel/cast_auth_util.h +++ b/cast/sender/channel/cast_auth_util.h @@ -36,6 +36,9 @@ class AuthContext { // The same context must be used in the challenge and reply. static AuthContext Create(); + // Create a context with some seed nonce data for testing. + static AuthContext CreateForTest(const std::string& nonce_data); + // Verifies the nonce received in the response is equivalent to the one sent. // Returns success if |nonce_response| matches nonce_ Error VerifySenderNonce(const std::string& nonce_response, diff --git a/cast/sender/channel/sender_socket_factory.cc b/cast/sender/channel/sender_socket_factory.cc index e971976b..c0924836 100644 --- a/cast/sender/channel/sender_socket_factory.cc +++ b/cast/sender/channel/sender_socket_factory.cc @@ -16,6 +16,8 @@ using ::cast::channel::CastMessage; namespace openscreen { namespace cast { +SenderSocketFactory::Client::~Client() = default; + bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a, int b) { return a && a->socket->socket_id() < b; @@ -106,8 +108,6 @@ void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory, const IPEndpoint& remote_address) { auto it = FindPendingConnection(remote_address); if (it == pending_connections_.end()) { - OSP_DVLOG << "OnConnectionFailed reported for untracked address: " - << remote_address; return; } pending_connections_.erase(it); diff --git a/cast/sender/public/README.md b/cast/sender/public/README.md index b670d110..eb7527f5 100644 --- a/cast/sender/public/README.md +++ b/cast/sender/public/README.md @@ -1,5 +1,5 @@ # cast/sender/public This module contains an implementation of the Cast "sender", i.e. the client -that discovers Cast devices on the LAN and launches apps on them. +that discovers Cast receivers on the LAN and launches apps on them. diff --git a/cast/sender/public/cast_app_discovery_service.h b/cast/sender/public/cast_app_discovery_service.h index c05d66bd..2e095b6d 100644 --- a/cast/sender/public/cast_app_discovery_service.h +++ b/cast/sender/public/cast_app_discovery_service.h @@ -7,19 +7,19 @@ #include <vector> -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" namespace openscreen { namespace cast { class CastMediaSource; -// Interface for app discovery for Cast devices. +// Interface for app discovery for Cast receivers. class CastAppDiscoveryService { public: using AvailabilityCallback = std::function<void(const CastMediaSource& source, - const std::vector<ServiceInfo>& devices)>; + const std::vector<ReceiverInfo>& receivers)>; class Subscription { public: @@ -47,7 +47,7 @@ class CastAppDiscoveryService { // returned via |callback| until the returned Subscription is destroyed by the // caller. If there are cached results available, |callback| will be invoked // before this method returns. |callback| may be invoked with an empty list - // if all devices respond to the respective queries with "unavailable" or + // if all receivers respond to the respective queries with "unavailable" or // don't respond before a timeout. |callback| may be invoked successively // with the same list. virtual Subscription StartObservingAvailability( diff --git a/cast/sender/public/sender_socket_factory.h b/cast/sender/public/sender_socket_factory.h index f0247a28..0b9b05aa 100644 --- a/cast/sender/public/sender_socket_factory.h +++ b/cast/sender/public/sender_socket_factory.h @@ -7,6 +7,7 @@ #include <openssl/x509.h> +#include <memory> #include <set> #include <utility> #include <vector> @@ -33,6 +34,9 @@ class SenderSocketFactory final : public TlsConnectionFactory::Client, virtual void OnError(SenderSocketFactory* factory, const IPEndpoint& endpoint, Error error) = 0; + + protected: + virtual ~Client(); }; enum class DeviceMediaPolicy { diff --git a/cast/standalone_e2e.py b/cast/standalone_e2e.py new file mode 100755 index 00000000..0d8a7c98 --- /dev/null +++ b/cast/standalone_e2e.py @@ -0,0 +1,357 @@ +#!/usr/bin/env python3 +# Copyright 2021 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. +""" +This script is intended to cover end to end testing for the standalone sender +and receiver executables in cast. This ensures that the basic functionality of +these executables is not impaired, such as the TLS/UDP connections and encoding +and decoding video. +""" + +import argparse +import os +import pathlib +import logging +import subprocess +import sys +import time +import unittest +import ssl +from collections import namedtuple + +from enum import IntEnum, IntFlag +from urllib import request + +# Environment variables that can be overridden to set test properties. +ROOT_ENVVAR = 'OPENSCREEN_ROOT_DIR' +BUILD_ENVVAR = 'OPENSCREEN_BUILD_DIR' +LIBAOM_ENVVAR = 'OPENSCREEN_HAVE_LIBAOM' + +TEST_VIDEO_NAME = 'Contador_Glam.mp4' +# NOTE: we use the HTTP protocol instead of HTTPS due to certificate issues +# in the legacy urllib.request API. +TEST_VIDEO_URL = ('https://storage.googleapis.com/openscreen_standalone/' + + TEST_VIDEO_NAME) + +PROCESS_TIMEOUT = 15 # seconds + +# Open Screen test certificates expire after 3 days. We crop this slightly (by +# 8 hours) to account for potential errors in time calculations. +CERT_EXPIRY_AGE = (3 * 24 - 8) * 60 * 60 + +# These properties are based on compiled settings in Open Screen, and should +# not change without updating this file. +TEST_CERT_NAME = 'generated_root_cast_receiver.crt' +TEST_KEY_NAME = 'generated_root_cast_receiver.key' +SENDER_BINARY_NAME = 'cast_sender' +RECEIVER_BINARY_NAME = 'cast_receiver' + +EXPECTED_RECEIVER_MESSAGES = [ + "CastService is running.", "Found codec: opus (known to FFMPEG as opus)", + "Successfully negotiated a session, creating SDL players.", + "Receivers are currently destroying, resetting SDL players." +] + +class VideoCodec(IntEnum): + """There are different messages printed by the receiver depending on the codec + chosen. """ + Vp8 = 0 + Vp9 = 1 + Av1 = 2 + +VIDEO_CODEC_SPECIFIC_RECEIVER_MESSAGES = [ + "Found codec: vp8 (known to FFMPEG as vp8)", + "Found codec: vp9 (known to FFMPEG as vp9)", + "Found codec: libaom-av1 (known to FFMPEG as av1)" +] + +EXPECTED_SENDER_MESSAGES = [ + "Launching Mirroring App on the Cast Receiver", + "Max allowed media bitrate (audio + video) will be", + "Contador_Glam.mp4 (starts in one second)...", + "The video capturer has reached the end of the media stream.", + "The audio capturer has reached the end of the media stream.", + "Video complete. Exiting...", "Shutting down..." +] + +MISSING_LOG_MESSAGE = """Missing an expected message from either the sender +or receiver. This either means that one of the binaries misbehaved, or you +changed or deleted one of the log messages used for validation. Please ensure +that the necessary log messages are left unchanged, or update this +test suite's expectations.""" + +DESCRIPTION = """Runs end to end tests for the standalone Cast Streaming sender +and receiver. By default, this script assumes it is being ran from a current +working directory inside Open Screen's source directory, and uses +<root_dir>/out/Default as the build directory. To override these, set the +OPENSCREEN_ROOT_DIR and OPENSCREEN_BUILD_DIR environment variables. If the root +directory is set and the build directory is not, +<OPENSCREEN_ROOT_DIR>/out/Default will be used. In addition, if LibAOM is +installed, one can choose to run AV1 tests by defining the +OPENSCREEN_HAVE_LIBAOM environment variable. + +See below for the the help output generated by the `unittest` package.""" + + +def _set_log_level(is_verbose): + """Sets the logging level, either DEBUG or ERROR as appropriate.""" + level = logging.DEBUG if is_verbose else logging.INFO + logging.basicConfig(stream=sys.stdout, level=level) + + +def _get_loopback_adapter_name(): + """Retrieves the name of the loopback adapter (lo on Linux/lo0 on Mac).""" + if sys.platform == 'linux' or sys.platform == 'linux2': + return 'lo' + if sys.platform == 'darwin': + return 'lo0' + return None + + +def _get_file_age_in_seconds(path): + """Get the age of a given file in seconds""" + # Time is stored in seconds since epoch + file_last_modified = 0 + if path.exists(): + file_last_modified = path.stat().st_mtime + return time.time() - file_last_modified + + +def _get_build_paths(): + """Gets the root and build paths (either default or from the environment + variables), and sets related paths to binaries and files.""" + root_path = pathlib.Path( + os.environ[ROOT_ENVVAR] if os.getenv(ROOT_ENVVAR) else subprocess. + getoutput('git rev-parse --show-toplevel')) + assert root_path.exists(), 'Could not find openscreen root!' + + build_path = pathlib.Path(os.environ[BUILD_ENVVAR]) if os.getenv( + BUILD_ENVVAR) else root_path.joinpath('out', + 'Default').resolve() + assert build_path.exists(), 'Could not find openscreen build!' + + BuildPaths = namedtuple("BuildPaths", + "root build test_video cast_receiver cast_sender") + return BuildPaths(root = root_path, + build = build_path, + test_video = build_path.joinpath(TEST_VIDEO_NAME).resolve(), + cast_receiver = build_path.joinpath(RECEIVER_BINARY_NAME).resolve(), + cast_sender = build_path.joinpath(SENDER_BINARY_NAME).resolve() + ) + + +class TestFlags(IntFlag): + """ + Test flags, primarily used to control sender and receiver configuration + to test different features of the standalone libraries. + """ + UseRemoting = 1 + UseAndroidHack = 2 + + +class StandaloneCastTest(unittest.TestCase): + """ + Test class for setting up and running end to end tests on the + standalone sender and receiver binaries. This class uses the unittest + package, so methods that are executed as tests all have named prefixed + with "test_". + + This suite sets the current working directory to the root of the Open + Screen repository, and references all files from the root directory. + Generated certificates should always be in |cls.build_paths.root|. + """ + + @classmethod + def setUpClass(cls): + """Shared setup method for all tests, handles one-time updates.""" + cls.build_paths = _get_build_paths() + os.chdir(cls.build_paths.root) + cls.download_video() + cls.generate_certificates() + + @classmethod + def download_video(cls): + """Downloads the test video from Google storage.""" + if os.path.exists(cls.build_paths.test_video): + logging.debug('Video already exists, skipping download...') + return + + logging.debug('Downloading video from %s', TEST_VIDEO_URL) + with request.urlopen(TEST_VIDEO_URL, context=ssl.SSLContext()) as url: + with open(cls.build_paths.test_video, 'wb') as file: + file.write(url.read()) + + @classmethod + def generate_certificates(cls): + """Generates test certificates using the cast receiver.""" + cert_age = _get_file_age_in_seconds(pathlib.Path(TEST_CERT_NAME)) + key_age = _get_file_age_in_seconds(pathlib.Path(TEST_KEY_NAME)) + if cert_age < CERT_EXPIRY_AGE and key_age < CERT_EXPIRY_AGE: + logging.debug('Credentials are up to date...') + return + + logging.debug('Credentials out of date, generating new ones...') + try: + subprocess.check_output( + [ + cls.build_paths.cast_receiver, + '-g', # Generate certificate and private key. + '-v' # Enable verbose logging. + ], + stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + print('Generation failed with output: ', e.output.decode()) + raise + + def launch_receiver(self): + """Launches the receiver process with discovery disabled.""" + logging.debug('Launching the receiver application...') + loopback = _get_loopback_adapter_name() + self.assertTrue(loopback) + + #pylint: disable = consider-using-with + return subprocess.Popen( + [ + self.build_paths.cast_receiver, + '-d', + TEST_CERT_NAME, + '-p', + TEST_KEY_NAME, + '-x', # Skip discovery, only necessary on Mac OS X. + '-v', # Enable verbose logging. + loopback + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + def launch_sender(self, flags, codec=None): + """Launches the sender process, running the test video file once.""" + logging.debug('Launching the sender application...') + command = [ + self.build_paths.cast_sender, + '127.0.0.1:8010', + self.build_paths.test_video, + '-d', + TEST_CERT_NAME, + '-n' # Only play the video once, and then exit. + ] + if TestFlags.UseAndroidHack in flags: + command.append('-a') + if TestFlags.UseRemoting in flags: + command.append('-r') + + # The standalone sender sends VP8 if no codec command line argument is + # passed. + if codec: + command.append('-c') + if codec == VideoCodec.Vp8: + command.append('vp8') + elif codec == VideoCodec.Vp9: + command.append('vp9') + else: + self.assertTrue(codec == VideoCodec.Av1) + command.append('av1') + + #pylint: disable = consider-using-with + return subprocess.Popen(command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + def check_logs(self, logs, codec=None): + """Checks that the outputted logs contain expected behavior.""" + + # If a codec was not provided, we should make sure that the standalone + # sender sent VP8. + if codec == None: + codec = VideoCodec.Vp8 + + for message in (EXPECTED_RECEIVER_MESSAGES + + [VIDEO_CODEC_SPECIFIC_RECEIVER_MESSAGES[codec]]): + self.assertTrue( + message in logs[0], + 'Missing log message: {}.\n{}'.format(message, + MISSING_LOG_MESSAGE)) + for message in EXPECTED_SENDER_MESSAGES: + self.assertTrue( + message in logs[1], + 'Missing log message: {}.\n{}'.format(message, + MISSING_LOG_MESSAGE)) + for log, prefix in logs, ["[ERROR:", "[FATAL:"]: + self.assertTrue(prefix not in log, "Logs contained an error") + logging.debug('Finished validating log output') + + def get_output(self, flags, codec=None): + """Launches the sender and receiver, and handles exit output.""" + receiver_process = self.launch_receiver() + logging.debug('Letting the receiver start up...') + time.sleep(3) + sender_process = self.launch_sender(flags, codec) + + logging.debug('Launched sender PID %i and receiver PID %i...', + sender_process.pid, receiver_process.pid) + logging.debug('collating output...') + output = (receiver_process.communicate( + timeout=PROCESS_TIMEOUT)[1].decode('utf-8'), + sender_process.communicate( + timeout=PROCESS_TIMEOUT)[1].decode('utf-8')) + + # TODO(issuetracker.google.com/194292855): standalones should exit zero. + # Remoting causes the sender to exit with code -4. + if not TestFlags.UseRemoting in flags: + self.assertEqual(sender_process.returncode, 0, + 'sender had non-zero exit code') + return output + + def test_golden_case(self): + """Tests that when settings are normal, things work end to end.""" + output = self.get_output([]) + self.check_logs(output) + + def test_remoting(self): + """Tests that basic remoting works.""" + output = self.get_output(TestFlags.UseRemoting) + self.check_logs(output) + + def test_with_android_hack(self): + """Tests that things work when the Android RTP hack is enabled.""" + output = self.get_output(TestFlags.UseAndroidHack) + self.check_logs(output) + + def test_vp8_flag(self): + """Tests that the VP8 flag works with standard settings.""" + output = self.get_output([], VideoCodec.Vp8) + self.check_logs(output, VideoCodec.Vp8) + + def test_vp9_flag(self): + """Tests that the VP9 flag works with standard settings.""" + output = self.get_output([], VideoCodec.Vp9) + self.check_logs(output, VideoCodec.Vp9) + + @unittest.skipUnless(os.getenv(LIBAOM_ENVVAR), + 'Skipping AV1 test since LibAOM not installed.') + def test_av1_flag(self): + """Tests that the AV1 flag works with standard settings.""" + output = self.get_output([], VideoCodec.Av1) + self.check_logs(output, VideoCodec.Av1) + + +def parse_args(): + """Parses the command line arguments and sets up the logging module.""" + # NOTE for future developers: the `unittest` module will complain if it is + # passed any args that it doesn't understand. If any Open Screen-specific + # command line arguments are added in the future, they should be cropped + # from sys.argv before |unittest.main()| is called. + parser = argparse.ArgumentParser(description=DESCRIPTION) + parser.add_argument('-v', + '--verbose', + help='enable debug logging', + action='store_true') + + parsed_args = parser.parse_args(sys.argv[1:]) + _set_log_level(parsed_args.verbose) + + +if __name__ == '__main__': + parse_args() + unittest.main() diff --git a/cast/standalone_receiver/BUILD.gn b/cast/standalone_receiver/BUILD.gn index 74d53f65..23d394ac 100644 --- a/cast/standalone_receiver/BUILD.gn +++ b/cast/standalone_receiver/BUILD.gn @@ -8,18 +8,27 @@ import("//build_overrides/build.gni") # Define the executable target only when the build is configured to use the # standalone platform implementation; since this is itself a standalone # application. +# +# See [external_libraries.md](../../build/config/external_libraries.md) for more information. if (!build_with_chromium) { shared_sources = [ "cast_service.cc", "cast_service.h", "mirroring_application.cc", "mirroring_application.h", + "simple_remoting_receiver.cc", + "simple_remoting_receiver.h", "streaming_playback_controller.cc", "streaming_playback_controller.h", ] shared_deps = [ + "../../discovery:dnssd", + "../../discovery:public", + "../../platform:standalone_impl", "../common:public", + "../receiver:agent", + "../receiver:channel", "../streaming:receiver", ] @@ -63,11 +72,7 @@ if (!build_with_chromium) { executable("cast_receiver") { sources = [ "main.cc" ] - deps = [ - "../receiver:agent", - "../receiver:channel", - ] - + deps = shared_deps configs += [ "../common:certificate_config" ] if (have_external_libs) { diff --git a/cast/standalone_receiver/cast_service.cc b/cast/standalone_receiver/cast_service.cc index 75790197..92ffce93 100644 --- a/cast/standalone_receiver/cast_service.cc +++ b/cast/standalone_receiver/cast_service.cc @@ -4,12 +4,16 @@ #include "cast/standalone_receiver/cast_service.h" +#include <stdint.h> + +#include <array> #include <utility> #include "discovery/common/config.h" #include "platform/api/tls_connection_factory.h" #include "platform/base/interface_info.h" #include "platform/base/tls_listen_options.h" +#include "util/crypto/random_bytes.h" #include "util/osp_logging.h" #include "util/stringprintf.h" @@ -19,6 +23,7 @@ namespace cast { namespace { constexpr uint16_t kDefaultCastServicePort = 8010; +constexpr int kCastUniqueIdLength = 6; constexpr int kDefaultMaxBacklogSize = 64; const TlsListenOptions kDefaultListenOptions{kDefaultMaxBacklogSize}; @@ -32,59 +37,57 @@ IPEndpoint DetermineEndpoint(const InterfaceInfo& interface) { } discovery::Config MakeDiscoveryConfig(const InterfaceInfo& interface) { - discovery::Config config; - - discovery::Config::NetworkInfo::AddressFamilies supported_address_families = - discovery::Config::NetworkInfo::kNoAddressFamily; - if (interface.GetIpAddressV4()) { - supported_address_families |= discovery::Config::NetworkInfo::kUseIpV4; - } else if (interface.GetIpAddressV6()) { - supported_address_families |= discovery::Config::NetworkInfo::kUseIpV6; - } - config.network_info.push_back({interface, supported_address_families}); - - return config; + return discovery::Config{.network_info = {interface}}; } } // namespace -CastService::CastService(TaskRunner* task_runner, - const InterfaceInfo& interface, - GeneratedCredentials credentials, - const std::string& friendly_name, - const std::string& model_name, - bool enable_discovery) - : local_endpoint_(DetermineEndpoint(interface)), - credentials_(std::move(credentials)), - agent_(task_runner, credentials_.provider.get()), - mirroring_application_(task_runner, local_endpoint_.address, &agent_), +CastService::CastService(CastService::Configuration config) + : local_endpoint_(DetermineEndpoint(config.interface)), + credentials_(std::move(config.credentials)), + agent_(config.task_runner, credentials_.provider.get()), + mirroring_application_(config.task_runner, + local_endpoint_.address, + &agent_), socket_factory_(&agent_, agent_.cast_socket_client()), connection_factory_( - TlsConnectionFactory::CreateFactory(&socket_factory_, task_runner)), - discovery_service_(enable_discovery ? discovery::CreateDnsSdService( - task_runner, - this, - MakeDiscoveryConfig(interface)) - : LazyDeletedDiscoveryService()), + TlsConnectionFactory::CreateFactory(&socket_factory_, + config.task_runner)), + discovery_service_(config.enable_discovery + ? discovery::CreateDnsSdService( + config.task_runner, + this, + MakeDiscoveryConfig(config.interface)) + : LazyDeletedDiscoveryService()), discovery_publisher_( discovery_service_ - ? MakeSerialDelete<discovery::DnsSdServicePublisher<ServiceInfo>>( - task_runner, + ? MakeSerialDelete< + discovery::DnsSdServicePublisher<ReceiverInfo>>( + config.task_runner, discovery_service_.get(), kCastV2ServiceId, - ServiceInfoToDnsSdInstance) + ReceiverInfoToDnsSdInstance) : LazyDeletedDiscoveryPublisher()) { connection_factory_->SetListenCredentials(credentials_.tls_credentials); connection_factory_->Listen(local_endpoint_, kDefaultListenOptions); if (discovery_publisher_) { - ServiceInfo info; + ReceiverInfo info; info.port = local_endpoint_.port; - info.unique_id = HexEncode(interface.hardware_address); - info.friendly_name = friendly_name; - info.model_name = model_name; + if (config.interface.HasHardwareAddress()) { + info.unique_id = HexEncode(config.interface.hardware_address.data(), + config.interface.hardware_address.size()); + } else { + OSP_LOG_WARN << "Hardware address for interface " << config.interface.name + << " is empty. Generating a random unique_id."; + std::array<uint8_t, kCastUniqueIdLength> random_bytes; + GenerateRandomBytes(random_bytes.data(), kCastUniqueIdLength); + info.unique_id = HexEncode(random_bytes.data(), kCastUniqueIdLength); + } + info.friendly_name = config.friendly_name; + info.model_name = config.model_name; info.capabilities = kHasVideoOutput | kHasAudioOutput; - Error error = discovery_publisher_->Register(info); + const Error error = discovery_publisher_->Register(info); if (!error.ok()) { OnFatalError(std::move(error)); } diff --git a/cast/standalone_receiver/cast_service.h b/cast/standalone_receiver/cast_service.h index 99137de2..57bedcbe 100644 --- a/cast/standalone_receiver/cast_service.h +++ b/cast/standalone_receiver/cast_service.h @@ -8,7 +8,7 @@ #include <memory> #include <string> -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" #include "cast/receiver/application_agent.h" #include "cast/receiver/channel/static_credentials.h" #include "cast/receiver/public/receiver_socket_factory.h" @@ -41,19 +41,33 @@ namespace cast { // * Publishes over mDNS to be discoverable to all senders on the same LAN. class CastService final : public discovery::ReportingClient { public: - CastService(TaskRunner* task_runner, - const InterfaceInfo& interface, - GeneratedCredentials credentials, - const std::string& friendly_name, - const std::string& model_name, - bool enable_discovery = true); + struct Configuration { + // The task runner to be used for async calls. + TaskRunner* task_runner; + // The interface the cast service is running on. + InterfaceInfo interface; + + // The credentials that the cast service should use for TLS. + GeneratedCredentials credentials; + + // The friendly name to be used for broadcasting. + std::string friendly_name; + + // The model name to be used for broadcasting. + std::string model_name; + + // Whether we should broadcast over mDNS/DNS-SD. + bool enable_discovery = true; + }; + + explicit CastService(Configuration config); ~CastService() final; private: using LazyDeletedDiscoveryService = SerialDeletePtr<discovery::DnsSdService>; using LazyDeletedDiscoveryPublisher = - SerialDeletePtr<discovery::DnsSdServicePublisher<ServiceInfo>>; + SerialDeletePtr<discovery::DnsSdServicePublisher<ReceiverInfo>>; // discovery::ReportingClient overrides. void OnFatalError(Error error) final; diff --git a/cast/standalone_receiver/decoder.cc b/cast/standalone_receiver/decoder.cc index 9a2324e3..92cdc901 100644 --- a/cast/standalone_receiver/decoder.cc +++ b/cast/standalone_receiver/decoder.cc @@ -4,6 +4,8 @@ #include "cast/standalone_receiver/decoder.h" +#include <libavcodec/version.h> + #include <algorithm> #include <sstream> #include <thread> @@ -44,7 +46,14 @@ absl::Span<uint8_t> Decoder::Buffer::GetSpan() { Decoder::Client::Client() = default; Decoder::Client::~Client() = default; -Decoder::Decoder(const std::string& codec_name) : codec_name_(codec_name) {} +Decoder::Decoder(const std::string& codec_name) : codec_name_(codec_name) { +#if LIBAVCODEC_VERSION_MAJOR < 59 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + avcodec_register_all(); +#pragma GCC diagnostic pop +#endif // LIBAVCODEC_VERSION_MAJOR < 59 +} Decoder::~Decoder() = default; diff --git a/cast/standalone_receiver/decoder.h b/cast/standalone_receiver/decoder.h index 1d4d0791..30e56553 100644 --- a/cast/standalone_receiver/decoder.h +++ b/cast/standalone_receiver/decoder.h @@ -38,7 +38,6 @@ class Decoder { // Interface for receiving decoded frames and/or errors. class Client { public: - virtual ~Client(); virtual void OnFrameDecoded(FrameId frame_id, const AVFrame& frame) = 0; virtual void OnDecodeError(FrameId frame_id, std::string message) = 0; @@ -46,6 +45,7 @@ class Decoder { protected: Client(); + virtual ~Client(); }; // |codec_name| should be the codec_name field from an OFFER message. diff --git a/cast/standalone_receiver/install_demo_deps_debian.sh b/cast/standalone_receiver/install_demo_deps_debian.sh deleted file mode 100755 index c082455c..00000000 --- a/cast/standalone_receiver/install_demo_deps_debian.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env sh - -# Installs dependencies necessary for libSDL and libAVcodec on Debian systems. - -sudo apt-get install libsdl2-2.0 libsdl2-dev libavcodec libavcodec-dev \ - libavformat libavformat-dev libavutil libavutil-dev \ - libswresample libswresample-dev
\ No newline at end of file diff --git a/cast/standalone_receiver/install_demo_deps_raspian.sh b/cast/standalone_receiver/install_demo_deps_raspian.sh deleted file mode 100755 index 91acaaa6..00000000 --- a/cast/standalone_receiver/install_demo_deps_raspian.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env sh - -# Installs dependencies necessary for libSDL and libAVcodec on -# Raspberry PI units running Raspian. - -sudo apt-get install libavcodec58=7:4.1.4* libavcodec-dev=7:4.1.4* \ - libsdl2-2.0-0=2.0.9* libsdl2-dev=2.0.9* \ - libavformat-dev=7:4.1.4* diff --git a/cast/standalone_receiver/main.cc b/cast/standalone_receiver/main.cc index 9e305c8b..ac001f58 100644 --- a/cast/standalone_receiver/main.cc +++ b/cast/standalone_receiver/main.cc @@ -58,15 +58,20 @@ options: private key and certificate can then be used as values for the -p and -s flags. - -f, --friendly-name: Friendly name to be used for device discovery. + -f, --friendly-name: Friendly name to be used for receiver discovery. - -m, --model-name: Model name to be used for device discovery. + -m, --model-name: Model name to be used for receiver discovery. -t, --tracing: Enable performance tracing logging. -v, --verbose: Enable verbose logging. -h, --help: Show this help message. + + -x, --disable-discovery: Disable discovery, useful for platforms like Mac OS + where our implementation is incompatible with + the native Bonjour service. + )"; std::cerr << StringPrintf(kTemplate, argv0); @@ -95,30 +100,22 @@ InterfaceInfo GetInterfaceInfoFromName(const char* name) { return interface_info; } -void RunCastService(TaskRunnerImpl* task_runner, - const InterfaceInfo& interface, - GeneratedCredentials creds, - const std::string& friendly_name, - const std::string& model_name, - bool discovery_enabled) { +void RunCastService(TaskRunnerImpl* runner, CastService::Configuration config) { std::unique_ptr<CastService> service; - task_runner->PostTask([&] { - service = std::make_unique<CastService>(task_runner, interface, - std::move(creds), friendly_name, - model_name, discovery_enabled); - }); + runner->PostTask( + [&] { service = std::make_unique<CastService>(std::move(config)); }); OSP_LOG_INFO << "CastService is running. CTRL-C (SIGINT), or send a " "SIGTERM to exit."; - task_runner->RunUntilSignaled(); + runner->RunUntilSignaled(); // Spin the TaskRunner to execute destruction/shutdown tasks. OSP_LOG_INFO << "Shutting down..."; - task_runner->PostTask([&] { + runner->PostTask([&] { service.reset(); - task_runner->RequestStopSoon(); + runner->RequestStopSoon(); }); - task_runner->RunUntilStopped(); + runner->RunUntilStopped(); OSP_LOG_INFO << "Bye!"; } @@ -132,6 +129,14 @@ int RunStandaloneReceiver(int argc, char* argv[]) { return 1; #endif +#if !defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) + OSP_LOG_INFO + << "Note: compiled without external libs. The dummy player will " + "be linked and no video decoding will occur. If this is not desired, " + "install the required external libraries. For more information, see: " + "[external_libraries.md](../../build/config/external_libraries.md)."; +#endif + // A note about modifying command line arguments: consider uniformity // between all Open Screen executables. If it is a platform feature // being exposed, consider if it applies to the standalone receiver, @@ -152,7 +157,7 @@ int RunStandaloneReceiver(int argc, char* argv[]) { {nullptr, 0, nullptr, 0}}; bool is_verbose = false; - bool discovery_enabled = true; + bool enable_discovery = true; std::string private_key_path; std::string developer_certificate_path; std::string friendly_name = "Cast Standalone Receiver"; @@ -160,7 +165,7 @@ int RunStandaloneReceiver(int argc, char* argv[]) { bool should_generate_credentials = false; std::unique_ptr<TextTraceLoggingPlatform> trace_logger; int ch = -1; - while ((ch = getopt_long(argc, argv, "p:d:f:m:gtvhx", kArgumentOptions, + while ((ch = getopt_long(argc, argv, "p:d:f:m:grtvhx", kArgumentOptions, nullptr)) != -1) { switch (ch) { case 'p': @@ -185,7 +190,7 @@ int RunStandaloneReceiver(int argc, char* argv[]) { is_verbose = true; break; case 'x': - discovery_enabled = false; + enable_discovery = false; break; case 'h': LogUsage(argv[0]); @@ -211,29 +216,22 @@ int RunStandaloneReceiver(int argc, char* argv[]) { OSP_CHECK(interface_name && strlen(interface_name) > 0) << "No interface name provided."; - std::string device_id = + std::string receiver_id = absl::StrCat("Standalone Receiver on ", interface_name); ErrorOr<GeneratedCredentials> creds = GenerateCredentials( - device_id, private_key_path, developer_certificate_path); + receiver_id, private_key_path, developer_certificate_path); OSP_CHECK(creds.is_value()) << creds.error(); const InterfaceInfo interface = GetInterfaceInfoFromName(interface_name); OSP_CHECK(interface.GetIpAddressV4() || interface.GetIpAddressV6()); - if (std::all_of(interface.hardware_address.begin(), - interface.hardware_address.end(), - [](int e) { return e == 0; })) { - OSP_LOG_WARN - << "Hardware address is empty. Either you are on a loopback device " - "or getting the network interface information failed somehow. " - "Discovery publishing will be disabled."; - discovery_enabled = false; - } auto* const task_runner = new TaskRunnerImpl(&Clock::now); PlatformClientPosix::Create(milliseconds(50), std::unique_ptr<TaskRunnerImpl>(task_runner)); - RunCastService(task_runner, interface, std::move(creds.value()), - friendly_name, model_name, discovery_enabled); + RunCastService(task_runner, + CastService::Configuration{ + task_runner, interface, std::move(creds.value()), + friendly_name, model_name, enable_discovery}); PlatformClientPosix::ShutDown(); return 0; diff --git a/cast/standalone_receiver/mirroring_application.cc b/cast/standalone_receiver/mirroring_application.cc index a04c401a..683fad5d 100644 --- a/cast/standalone_receiver/mirroring_application.cc +++ b/cast/standalone_receiver/mirroring_application.cc @@ -4,7 +4,10 @@ #include "cast/standalone_receiver/mirroring_application.h" +#include <utility> + #include "cast/common/public/message_port.h" +#include "cast/streaming/constants.h" #include "cast/streaming/environment.h" #include "cast/streaming/message_fields.h" #include "cast/streaming/receiver_session.h" @@ -14,9 +17,6 @@ namespace openscreen { namespace cast { -const char kMirroringAppId[] = "0F5096E8"; -const char kMirroringAudioOnlyAppId[] = "85CDB22F"; - const char kMirroringDisplayName[] = "Chrome Mirroring"; const char kRemotingRpcNamespace[] = "urn:x-cast:com.google.cast.remoting"; @@ -55,9 +55,15 @@ bool MirroringApplication::Launch(const std::string& app_id, IPEndpoint{interface_address_, kDefaultCastStreamingPort}); controller_ = std::make_unique<StreamingPlaybackController>(task_runner_, this); - current_session_ = std::make_unique<ReceiverSession>( - controller_.get(), environment_.get(), message_port, - ReceiverSession::Preferences{}); + + ReceiverSession::Preferences preferences; + preferences.video_codecs.insert(preferences.video_codecs.end(), + {VideoCodec::kVp9, VideoCodec::kAv1}); + preferences.remoting = + std::make_unique<ReceiverSession::RemotingPreferences>(); + current_session_ = + std::make_unique<ReceiverSession>(controller_.get(), environment_.get(), + message_port, std::move(preferences)); return true; } diff --git a/cast/standalone_receiver/sdl_glue.cc b/cast/standalone_receiver/sdl_glue.cc index 7c2c94da..c4619f09 100644 --- a/cast/standalone_receiver/sdl_glue.cc +++ b/cast/standalone_receiver/sdl_glue.cc @@ -4,6 +4,8 @@ #include "cast/standalone_receiver/sdl_glue.h" +#include <utility> + #include "platform/api/task_runner.h" #include "platform/api/time.h" #include "util/osp_logging.h" @@ -21,6 +23,11 @@ SDLEventLoopProcessor::SDLEventLoopProcessor( SDLEventLoopProcessor::~SDLEventLoopProcessor() = default; +void SDLEventLoopProcessor::RegisterForKeyboardEvent( + SDLEventLoopProcessor::KeyboardEventCallback cb) { + keyboard_callbacks_.push_back(std::move(cb)); +} + void SDLEventLoopProcessor::ProcessPendingEvents() { // Process all pending events. SDL_Event event; @@ -30,6 +37,10 @@ void SDLEventLoopProcessor::ProcessPendingEvents() { if (quit_callback_) { quit_callback_(); } + } else if (event.type == SDL_KEYUP) { + for (auto& cb : keyboard_callbacks_) { + cb(event.key); + } } } diff --git a/cast/standalone_receiver/sdl_glue.h b/cast/standalone_receiver/sdl_glue.h index 59a3a020..7e136074 100644 --- a/cast/standalone_receiver/sdl_glue.h +++ b/cast/standalone_receiver/sdl_glue.h @@ -7,14 +7,16 @@ #include <stdint.h> -#include <functional> -#include <memory> - #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wimplicit-fallthrough" #include <SDL2/SDL.h> #pragma GCC diagnostic pop +#include <functional> +#include <memory> +#include <utility> +#include <vector> + #include "util/alarm.h" namespace openscreen { @@ -66,11 +68,15 @@ class SDLEventLoopProcessor { std::function<void()> quit_callback); ~SDLEventLoopProcessor(); + using KeyboardEventCallback = std::function<void(const SDL_KeyboardEvent&)>; + void RegisterForKeyboardEvent(KeyboardEventCallback cb); + private: void ProcessPendingEvents(); Alarm alarm_; std::function<void()> quit_callback_; + std::vector<KeyboardEventCallback> keyboard_callbacks_; }; } // namespace cast diff --git a/cast/standalone_receiver/sdl_video_player.cc b/cast/standalone_receiver/sdl_video_player.cc index 999545de..a1b89177 100644 --- a/cast/standalone_receiver/sdl_video_player.cc +++ b/cast/standalone_receiver/sdl_video_player.cc @@ -8,6 +8,7 @@ #include <utility> #include "cast/standalone_receiver/avcodec_glue.h" +#include "util/enum_name_table.h" #include "util/osp_logging.h" #include "util/trace_logging.h" @@ -18,6 +19,13 @@ namespace { constexpr char kVideoMediaType[] = "video"; } // namespace +constexpr EnumNameTable<VideoCodec, 6> kFfmpegCodecDescriptors{ + {{"h264", VideoCodec::kH264}, + {"vp8", VideoCodec::kVp8}, + {"hevc", VideoCodec::kHevc}, + {"vp9", VideoCodec::kVp9}, + {"libaom-av1", VideoCodec::kAv1}}}; + SDLVideoPlayer::SDLVideoPlayer(ClockNowFunctionPtr now_function, TaskRunner* task_runner, Receiver* receiver, @@ -27,7 +35,7 @@ SDLVideoPlayer::SDLVideoPlayer(ClockNowFunctionPtr now_function, : SDLPlayerBase(now_function, task_runner, receiver, - CodecToString(codec), + GetEnumName(kFfmpegCodecDescriptors, codec).value(), std::move(error_callback), kVideoMediaType), renderer_(renderer) { diff --git a/cast/standalone_receiver/simple_remoting_receiver.cc b/cast/standalone_receiver/simple_remoting_receiver.cc new file mode 100644 index 00000000..c22f3271 --- /dev/null +++ b/cast/standalone_receiver/simple_remoting_receiver.cc @@ -0,0 +1,118 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_receiver/simple_remoting_receiver.h" + +#include <utility> + +#include "cast/streaming/message_fields.h" +#include "cast/streaming/remoting.pb.h" + +namespace openscreen { +namespace cast { + +namespace { + +VideoCodec ParseProtoCodec(VideoDecoderConfig::Codec value) { + switch (value) { + case VideoDecoderConfig_Codec_kCodecHEVC: + return VideoCodec::kHevc; + + case VideoDecoderConfig_Codec_kCodecH264: + return VideoCodec::kH264; + + case VideoDecoderConfig_Codec_kCodecVP8: + return VideoCodec::kVp8; + + case VideoDecoderConfig_Codec_kCodecVP9: + return VideoCodec::kVp9; + + case VideoDecoderConfig_Codec_kCodecAV1: + return VideoCodec::kAv1; + + default: + return VideoCodec::kNotSpecified; + } +} + +AudioCodec ParseProtoCodec(AudioDecoderConfig::Codec value) { + switch (value) { + case AudioDecoderConfig_Codec_kCodecAAC: + return AudioCodec::kAac; + + case AudioDecoderConfig_Codec_kCodecOpus: + return AudioCodec::kOpus; + + default: + return AudioCodec::kNotSpecified; + } +} + +} // namespace + +SimpleRemotingReceiver::SimpleRemotingReceiver(RpcMessenger* messenger) + : messenger_(messenger) { + messenger_->RegisterMessageReceiverCallback( + RpcMessenger::kFirstHandle, [this](std::unique_ptr<RpcMessage> message) { + this->OnInitializeCallbackMessage(std::move(message)); + }); +} + +SimpleRemotingReceiver::~SimpleRemotingReceiver() { + messenger_->UnregisterMessageReceiverCallback(RpcMessenger::kFirstHandle); +} + +void SimpleRemotingReceiver::SendInitializeMessage( + SimpleRemotingReceiver::InitializeCallback initialize_cb) { + initialize_cb_ = std::move(initialize_cb); + + OSP_DVLOG + << "Indicating to the sender we are ready for remoting initialization."; + openscreen::cast::RpcMessage rpc; + rpc.set_handle(RpcMessenger::kAcquireRendererHandle); + rpc.set_proc(openscreen::cast::RpcMessage::RPC_DS_INITIALIZE); + + // The initialize message contains the handle to be used for sending the + // initialization callback message. + rpc.set_integer_value(RpcMessenger::kFirstHandle); + messenger_->SendMessageToRemote(rpc); +} + +void SimpleRemotingReceiver::SendPlaybackRateMessage(double playback_rate) { + openscreen::cast::RpcMessage rpc; + rpc.set_handle(RpcMessenger::kAcquireRendererHandle); + rpc.set_proc(openscreen::cast::RpcMessage::RPC_R_SETPLAYBACKRATE); + rpc.set_double_value(playback_rate); + messenger_->SendMessageToRemote(rpc); +} + +void SimpleRemotingReceiver::OnInitializeCallbackMessage( + std::unique_ptr<RpcMessage> message) { + OSP_DCHECK(message->proc() == RpcMessage::RPC_DS_INITIALIZE_CALLBACK); + if (!initialize_cb_) { + OSP_DLOG_INFO << "Received an initialization callback message but no " + "callback was set."; + return; + } + + const DemuxerStreamInitializeCallback& callback_message = + message->demuxerstream_initializecb_rpc(); + const auto audio_codec = + callback_message.has_audio_decoder_config() + ? ParseProtoCodec(callback_message.audio_decoder_config().codec()) + : AudioCodec::kNotSpecified; + const auto video_codec = + callback_message.has_video_decoder_config() + ? ParseProtoCodec(callback_message.video_decoder_config().codec()) + : VideoCodec::kNotSpecified; + + OSP_DLOG_INFO << "Initializing remoting with audio codec " + << CodecToString(audio_codec) << " and video codec " + << CodecToString(video_codec); + initialize_cb_(audio_codec, video_codec); + initialize_cb_ = nullptr; +} + +} // namespace cast +} // namespace openscreen diff --git a/cast/standalone_receiver/simple_remoting_receiver.h b/cast/standalone_receiver/simple_remoting_receiver.h new file mode 100644 index 00000000..8e672574 --- /dev/null +++ b/cast/standalone_receiver/simple_remoting_receiver.h @@ -0,0 +1,56 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_RECEIVER_SIMPLE_REMOTING_RECEIVER_H_ +#define CAST_STANDALONE_RECEIVER_SIMPLE_REMOTING_RECEIVER_H_ + +#include <functional> +#include <memory> + +#include "cast/streaming/constants.h" +#include "cast/streaming/rpc_messenger.h" + +namespace openscreen { +namespace cast { + +// This class behaves like a pared-down version of Chrome's DemuxerStreamAdapter +// (see +// https://source.chromium.org/chromium/chromium/src/+/main:/media/remoting/demuxer_stream_adapter.h +// ). Instead of providing a full adapter implementation, it just provides a +// callback register that can be used to notify a component when the +// RemotingProvider sends an initialization message with audio and video codec +// information. +// +// Due to the sheer complexity of remoting, we don't have a fully functional +// implementation of remoting in the standalone_* components, instead Chrome is +// the reference implementation and we have these simple classes to exercise +// the public APIs. +class SimpleRemotingReceiver { + public: + explicit SimpleRemotingReceiver(RpcMessenger* messenger); + ~SimpleRemotingReceiver(); + + // The flow here closely mirrors the remoting.proto. The standalone receiver + // indicates it is ready for initialization by calling + // |SendInitializeMessage|, then this class sends an initialize message to the + // sender. The sender then replies with an initialization message containing + // configurations, which is passed to |initialize_cb|. + using InitializeCallback = std::function<void(AudioCodec, VideoCodec)>; + void SendInitializeMessage(InitializeCallback initialize_cb); + + // The speed at which the content is decoded is synchronized with the + // playback rate. Pausing is a special case with a playback rate of 0.0. + void SendPlaybackRateMessage(double playback_rate); + + private: + void OnInitializeCallbackMessage(std::unique_ptr<RpcMessage> message); + + RpcMessenger* messenger_; + InitializeCallback initialize_cb_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_RECEIVER_SIMPLE_REMOTING_RECEIVER_H_ diff --git a/cast/standalone_receiver/streaming_playback_controller.cc b/cast/standalone_receiver/streaming_playback_controller.cc index f9196ae5..5f6412a4 100644 --- a/cast/standalone_receiver/streaming_playback_controller.cc +++ b/cast/standalone_receiver/streaming_playback_controller.cc @@ -7,11 +7,11 @@ #include <string> #if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) -#include "cast/standalone_receiver/sdl_audio_player.h" -#include "cast/standalone_receiver/sdl_glue.h" -#include "cast/standalone_receiver/sdl_video_player.h" +#include "cast/standalone_receiver/sdl_audio_player.h" // nogncheck +#include "cast/standalone_receiver/sdl_glue.h" // nogncheck +#include "cast/standalone_receiver/sdl_video_player.h" // nogncheck #else -#include "cast/standalone_receiver/dummy_player.h" +#include "cast/standalone_receiver/dummy_player.h" // nogncheck #endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) #include "util/trace_logging.h" @@ -19,6 +19,8 @@ namespace openscreen { namespace cast { +StreamingPlaybackController::Client::~Client() = default; + #if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) StreamingPlaybackController::StreamingPlaybackController( TaskRunner* task_runner, @@ -42,6 +44,11 @@ StreamingPlaybackController::StreamingPlaybackController( OSP_CHECK(window_) << "Failed to create SDL window: " << SDL_GetError(); renderer_ = MakeUniqueSDLRenderer(window_.get(), -1, 0); OSP_CHECK(renderer_) << "Failed to create SDL renderer: " << SDL_GetError(); + + sdl_event_loop_.RegisterForKeyboardEvent( + [this](const SDL_KeyboardEvent& event) { + this->HandleKeyboardEvent(event); + }); } #else StreamingPlaybackController::StreamingPlaybackController( @@ -53,11 +60,49 @@ StreamingPlaybackController::StreamingPlaybackController( } #endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) -void StreamingPlaybackController::OnMirroringNegotiated( +void StreamingPlaybackController::OnNegotiated( const ReceiverSession* session, ReceiverSession::ConfiguredReceivers receivers) { TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneReceiver); + Initialize(receivers); +} + +void StreamingPlaybackController::OnRemotingNegotiated( + const ReceiverSession* session, + ReceiverSession::RemotingNegotiation negotiation) { + remoting_receiver_ = + std::make_unique<SimpleRemotingReceiver>(negotiation.messenger); + remoting_receiver_->SendInitializeMessage( + [this, receivers = negotiation.receivers](AudioCodec audio_codec, + VideoCodec video_codec) { + // The configurations in |negotiation| do not have the actual codecs, + // only REMOTE_AUDIO and REMOTE_VIDEO. Once we receive the + // initialization callback method, we can override with the actual + // codecs here. + auto mutable_receivers = receivers; + mutable_receivers.audio_config.codec = audio_codec; + mutable_receivers.video_config.codec = video_codec; + Initialize(mutable_receivers); + }); +} + +void StreamingPlaybackController::OnReceiversDestroying( + const ReceiverSession* session, + ReceiversDestroyingReason reason) { + OSP_LOG_INFO << "Receivers are currently destroying, resetting SDL players."; + audio_player_.reset(); + video_player_.reset(); +} + +void StreamingPlaybackController::OnError(const ReceiverSession* session, + Error error) { + client_->OnPlaybackError(this, error); +} + +void StreamingPlaybackController::Initialize( + ReceiverSession::ConfiguredReceivers receivers) { #if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) + OSP_LOG_INFO << "Successfully negotiated a session, creating SDL players."; if (receivers.audio_receiver) { audio_player_ = std::make_unique<SDLAudioPlayer>( &Clock::now, task_runner_, receivers.audio_receiver, @@ -83,17 +128,24 @@ void StreamingPlaybackController::OnMirroringNegotiated( #endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) } -void StreamingPlaybackController::OnReceiversDestroying( - const ReceiverSession* session, - ReceiversDestroyingReason reason) { - audio_player_.reset(); - video_player_.reset(); -} +#if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) +void StreamingPlaybackController::HandleKeyboardEvent( + const SDL_KeyboardEvent& event) { + // We only handle keyboard events if we are remoting. + if (!remoting_receiver_) { + return; + } -void StreamingPlaybackController::OnError(const ReceiverSession* session, - Error error) { - client_->OnPlaybackError(this, error); + switch (event.keysym.sym) { + // See codes here: https://wiki.libsdl.org/SDL_Scancode + case SDLK_KP_SPACE: // fallthrough, "Keypad Space" + case SDLK_SPACE: // "Space" + is_playing_ = !is_playing_; + remoting_receiver_->SendPlaybackRateMessage(is_playing_ ? 1.0 : 0.0); + break; + } } +#endif } // namespace cast } // namespace openscreen diff --git a/cast/standalone_receiver/streaming_playback_controller.h b/cast/standalone_receiver/streaming_playback_controller.h index 1e81ed5b..109b8adb 100644 --- a/cast/standalone_receiver/streaming_playback_controller.h +++ b/cast/standalone_receiver/streaming_playback_controller.h @@ -7,15 +7,16 @@ #include <memory> +#include "cast/standalone_receiver/simple_remoting_receiver.h" #include "cast/streaming/receiver_session.h" #include "platform/impl/task_runner.h" #if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) -#include "cast/standalone_receiver/sdl_audio_player.h" -#include "cast/standalone_receiver/sdl_glue.h" -#include "cast/standalone_receiver/sdl_video_player.h" +#include "cast/standalone_receiver/sdl_audio_player.h" // nogncheck +#include "cast/standalone_receiver/sdl_glue.h" // nogncheck +#include "cast/standalone_receiver/sdl_video_player.h" // nogncheck #else -#include "cast/standalone_receiver/dummy_player.h" +#include "cast/standalone_receiver/dummy_player.h" // nogncheck #endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) namespace openscreen { @@ -27,41 +28,51 @@ class StreamingPlaybackController final : public ReceiverSession::Client { public: virtual void OnPlaybackError(StreamingPlaybackController* controller, Error error) = 0; + + protected: + virtual ~Client(); }; StreamingPlaybackController(TaskRunner* task_runner, StreamingPlaybackController::Client* client); // ReceiverSession::Client overrides. - void OnMirroringNegotiated( + void OnNegotiated(const ReceiverSession* session, + ReceiverSession::ConfiguredReceivers receivers) override; + void OnRemotingNegotiated( const ReceiverSession* session, - ReceiverSession::ConfiguredReceivers receivers) override; - + ReceiverSession::RemotingNegotiation negotiation) override; void OnReceiversDestroying(const ReceiverSession* session, ReceiversDestroyingReason reason) override; - void OnError(const ReceiverSession* session, Error error) override; private: TaskRunner* const task_runner_; StreamingPlaybackController::Client* client_; + void Initialize(ReceiverSession::ConfiguredReceivers receivers); + #if defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) + void HandleKeyboardEvent(const SDL_KeyboardEvent& event); + // NOTE: member ordering is important, since the sub systems must be // first-constructed, last-destroyed. Make sure any new SDL related // members are added below the sub systems. const ScopedSDLSubSystem<SDL_INIT_AUDIO> sdl_audio_sub_system_; const ScopedSDLSubSystem<SDL_INIT_VIDEO> sdl_video_sub_system_; - const SDLEventLoopProcessor sdl_event_loop_; + SDLEventLoopProcessor sdl_event_loop_; SDLWindowUniquePtr window_; SDLRendererUniquePtr renderer_; std::unique_ptr<SDLAudioPlayer> audio_player_; std::unique_ptr<SDLVideoPlayer> video_player_; + double is_playing_ = true; #else std::unique_ptr<DummyPlayer> audio_player_; std::unique_ptr<DummyPlayer> video_player_; #endif // defined(CAST_STANDALONE_RECEIVER_HAVE_EXTERNAL_LIBS) + + std::unique_ptr<SimpleRemotingReceiver> remoting_receiver_; }; } // namespace cast diff --git a/cast/standalone_sender/BUILD.gn b/cast/standalone_sender/BUILD.gn index a59d2449..a65c8bbf 100644 --- a/cast/standalone_sender/BUILD.gn +++ b/cast/standalone_sender/BUILD.gn @@ -10,19 +10,26 @@ import("//build_overrides/build.gni") # application. if (!build_with_chromium) { declare_args() { - have_libs = have_ffmpeg && have_libopus && have_libvpx + have_external_libs = have_ffmpeg && have_libopus && have_libvpx } config("standalone_external_libs") { defines = [] - if (have_libs) { + if (have_external_libs) { defines += [ "CAST_STANDALONE_SENDER_HAVE_EXTERNAL_LIBS" ] } + if (have_libaom) { + defines += [ "CAST_STANDALONE_SENDER_HAVE_LIBAOM" ] + } } executable("cast_sender") { deps = [ + "../../discovery:dnssd", + "../../discovery:public", "../../platform", + "../../platform:standalone_impl", + "../../third_party/aomedia", "../../third_party/jsoncpp", "../../util", "../common:public", @@ -36,8 +43,9 @@ if (!build_with_chromium) { include_dirs = [] lib_dirs = [] libs = [] - if (have_ffmpeg && have_libopus && have_libvpx) { + if (have_external_libs) { sources += [ + "connection_settings.h", "ffmpeg_glue.cc", "ffmpeg_glue.h", "looping_file_cast_agent.cc", @@ -46,17 +54,37 @@ if (!build_with_chromium) { "looping_file_sender.h", "receiver_chooser.cc", "receiver_chooser.h", + "remoting_sender.cc", + "remoting_sender.h", "simulated_capturer.cc", "simulated_capturer.h", + "streaming_encoder_util.cc", + "streaming_encoder_util.h", "streaming_opus_encoder.cc", "streaming_opus_encoder.h", - "streaming_vp8_encoder.cc", - "streaming_vp8_encoder.h", + "streaming_video_encoder.cc", + "streaming_video_encoder.h", + "streaming_vpx_encoder.cc", + "streaming_vpx_encoder.h", ] + include_dirs += ffmpeg_include_dirs + libopus_include_dirs + libvpx_include_dirs lib_dirs += ffmpeg_lib_dirs + libopus_lib_dirs + libvpx_lib_dirs libs += ffmpeg_libs + libopus_libs + libvpx_libs + + # LibAOM support currently recommends building from source, so is included + # separately here. + if (have_libaom) { + sources += [ + "streaming_av1_encoder.cc", + "streaming_av1_encoder.h", + ] + + include_dirs += libaom_include_dirs + lib_dirs += libaom_lib_dirs + libs += libaom_libs + } } configs += [ "../common:certificate_config" ] diff --git a/cast/standalone_sender/connection_settings.h b/cast/standalone_sender/connection_settings.h new file mode 100644 index 00000000..4c7e4849 --- /dev/null +++ b/cast/standalone_sender/connection_settings.h @@ -0,0 +1,52 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_SENDER_CONNECTION_SETTINGS_H_ +#define CAST_STANDALONE_SENDER_CONNECTION_SETTINGS_H_ + +#include <string> + +#include "cast/streaming/constants.h" +#include "platform/base/interface_info.h" + +namespace openscreen { +namespace cast { + +// The connection settings for a given standalone sender instance. These fields +// are used throughout the standalone sender component to initialize state from +// the command line parameters. +struct ConnectionSettings { + // The endpoint of the receiver we wish to connect to. + IPEndpoint receiver_endpoint; + + // The path to the file that we want to play. + std::string path_to_file; + + // The maximum bitrate. Default value means a reasonable default will be + // selected. + int max_bitrate = 0; + + // Whether the stream should include video, or just be audio only. + bool should_include_video = true; + + // Whether we should use the hacky RTP stream IDs for legacy android + // receivers, or if we should use the proper values. For more information, + // see https://issuetracker.google.com/184438154. + bool use_android_rtp_hack = true; + + // Whether we should use remoting for the video, instead of the default of + // mirroring. + bool use_remoting = false; + + // Whether we should loop the video when it is completed. + bool should_loop_video = true; + + // The codec to use for encoding negotiated video streams. + VideoCodec codec; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_SENDER_CONNECTION_SETTINGS_H_ diff --git a/cast/standalone_sender/ffmpeg_glue.cc b/cast/standalone_sender/ffmpeg_glue.cc index a6645886..7f476582 100644 --- a/cast/standalone_sender/ffmpeg_glue.cc +++ b/cast/standalone_sender/ffmpeg_glue.cc @@ -4,6 +4,8 @@ #include "cast/standalone_sender/ffmpeg_glue.h" +#include <libavcodec/version.h> + #include "util/osp_logging.h" namespace openscreen { @@ -12,6 +14,13 @@ namespace internal { AVFormatContext* CreateAVFormatContextForFile(const char* path) { AVFormatContext* format_context = nullptr; +#if LIBAVCODEC_VERSION_MAJOR < 59 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + av_register_all(); +#pragma GCC diagnostic pop +#endif // LIBAVCODEC_VERSION_MAJOR < 59 + int result = avformat_open_input(&format_context, path, nullptr, nullptr); if (result < 0) { OSP_LOG_ERROR << "Cannot open " << path << ": " << av_err2str(result); diff --git a/cast/standalone_sender/looping_file_cast_agent.cc b/cast/standalone_sender/looping_file_cast_agent.cc index 9d4558ad..0e17ecab 100644 --- a/cast/standalone_sender/looping_file_cast_agent.cc +++ b/cast/standalone_sender/looping_file_cast_agent.cc @@ -15,6 +15,7 @@ #include "cast/streaming/offer_messages.h" #include "json/value.h" #include "platform/api/tls_connection_factory.h" +#include "util/json/json_helpers.h" #include "util/stringprintf.h" #include "util/trace_logging.h" @@ -24,45 +25,6 @@ namespace { using DeviceMediaPolicy = SenderSocketFactory::DeviceMediaPolicy; -// TODO(miu): These string constants appear in a few places and should be -// de-duped to a common location. -constexpr char kMirroringAppId[] = "0F5096E8"; -constexpr char kMirroringAudioOnlyAppId[] = "85CDB22F"; - -// Parses the given string as a JSON object. If the parse fails, an empty object -// is returned. -// -// TODO(miu): De-dupe this code (same as in cast/receiver/application_agent.cc)! -Json::Value ParseAsObject(absl::string_view value) { - ErrorOr<Json::Value> parsed = json::Parse(value); - if (parsed.is_value() && parsed.value().isObject()) { - return std::move(parsed.value()); - } - return Json::Value(Json::objectValue); -} - -// Returns true if the 'type' field in |object| has the given |type|. -// -// TODO(miu): De-dupe this code (same as in cast/receiver/application_agent.cc)! -bool HasType(const Json::Value& object, CastMessageType type) { - OSP_DCHECK(object.isObject()); - const Json::Value& value = - object.get(kMessageKeyType, Json::Value::nullSingleton()); - return value.isString() && value.asString() == CastMessageTypeToString(type); -} - -// Returns the string found in object[field] if possible; otherwise, returns -// |fallback|. The fallback string is returned if |object| is not an object or -// the |field| key does not reference a string within the object. -std::string ExtractStringFieldValue(const Json::Value& object, - const char* field, - std::string fallback = {}) { - if (object.isObject() && object[field].isString()) { - return object[field].asString(); - } - return fallback; -} - } // namespace LoopingFileCastAgent::LoopingFileCastAgent(TaskRunner* task_runner, @@ -161,18 +123,29 @@ void LoopingFileCastAgent::OnMessage(VirtualConnectionRouter* router, if (message.namespace_() == kReceiverNamespace && message_port_.GetSocketId() == ToCastSocketId(socket)) { - const Json::Value payload = ParseAsObject(message.payload_utf8()); - if (HasType(payload, CastMessageType::kReceiverStatus)) { - HandleReceiverStatus(payload); - } else if (HasType(payload, CastMessageType::kLaunchError)) { + const ErrorOr<Json::Value> payload = json::Parse(message.payload_utf8()); + if (payload.is_error()) { + OSP_LOG_ERROR << "Failed to parse message: " << payload.error(); + } + + if (HasType(payload.value(), CastMessageType::kReceiverStatus)) { + HandleReceiverStatus(payload.value()); + } else if (HasType(payload.value(), CastMessageType::kLaunchError)) { + std::string reason; + if (!json::TryParseString(payload.value()[kMessageKeyReason], &reason)) { + reason = "UNKNOWN"; + } OSP_LOG_ERROR << "Failed to launch the Cast Mirroring App on the Receiver! Reason: " - << ExtractStringFieldValue(payload, kMessageKeyReason, "UNKNOWN"); + << reason; Shutdown(); - } else if (HasType(payload, CastMessageType::kInvalidRequest)) { + } else if (HasType(payload.value(), CastMessageType::kInvalidRequest)) { + std::string reason; + if (!json::TryParseString(payload.value()[kMessageKeyReason], &reason)) { + reason = "UNKNOWN"; + } OSP_LOG_ERROR << "Cast Receiver thinks our request is invalid: " - << ExtractStringFieldValue(payload, kMessageKeyReason, - "UNKNOWN"); + << reason; } } } @@ -191,9 +164,9 @@ void LoopingFileCastAgent::HandleReceiverStatus(const Json::Value& status) { ? status[kMessageKeyStatus][kMessageKeyApplications][0] : Json::Value(); - const std::string& running_app_id = - ExtractStringFieldValue(details, kMessageKeyAppId); - if (running_app_id != GetMirroringAppId()) { + std::string running_app_id; + if (!json::TryParseString(details[kMessageKeyAppId], &running_app_id) || + running_app_id != GetMirroringAppId()) { // The mirroring app is not running. If it was just stopped, Shutdown() will // tear everything down. If it has been stopped already, Shutdown() is a // no-op. @@ -201,9 +174,9 @@ void LoopingFileCastAgent::HandleReceiverStatus(const Json::Value& status) { return; } - const std::string& session_id = - ExtractStringFieldValue(details, kMessageKeySessionId); - if (session_id.empty()) { + std::string session_id; + if (!json::TryParseString(details[kMessageKeySessionId], &session_id) || + session_id.empty()) { OSP_LOG_ERROR << "Cannot continue: Cast Receiver did not provide a session ID for " "the Mirroring App running on it."; @@ -229,9 +202,10 @@ void LoopingFileCastAgent::HandleReceiverStatus(const Json::Value& status) { return; } - const std::string& message_destination_id = - ExtractStringFieldValue(details, kMessageKeyTransportId); - if (message_destination_id.empty()) { + std::string message_destination_id; + if (!json::TryParseString(details[kMessageKeyTransportId], + &message_destination_id) || + message_destination_id.empty()) { OSP_LOG_ERROR << "Cannot continue: Cast Receiver did not provide a transport ID for " "routing messages to the Mirroring App running on it."; @@ -268,34 +242,51 @@ void LoopingFileCastAgent::OnRemoteMessagingOpened(bool success) { void LoopingFileCastAgent::CreateAndStartSession() { TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneSender); + OSP_DCHECK(remote_connection_.has_value()); environment_ = std::make_unique<Environment>(&Clock::now, task_runner_, IPEndpoint{}); - OSP_DCHECK(remote_connection_.has_value()); - current_session_ = std::make_unique<SenderSession>( - connection_settings_->receiver_endpoint.address, this, environment_.get(), - &message_port_, remote_connection_->local_id, - remote_connection_->peer_id); + + SenderSession::Configuration config{ + connection_settings_->receiver_endpoint.address, + this, + environment_.get(), + &message_port_, + remote_connection_->local_id, + remote_connection_->peer_id, + connection_settings_->use_android_rtp_hack}; + current_session_ = std::make_unique<SenderSession>(std::move(config)); OSP_DCHECK(!message_port_.client_sender_id().empty()); AudioCaptureConfig audio_config; // Opus does best at 192kbps, so we cap that here. audio_config.bit_rate = 192 * 1000; - VideoCaptureConfig video_config; - // The video config is allowed to use whatever is left over after audio. - video_config.max_bit_rate = - connection_settings_->max_bitrate - audio_config.bit_rate; + VideoCaptureConfig video_config = { + .codec = connection_settings_->codec, + // The video config is allowed to use whatever is left over after audio. + .max_bit_rate = + connection_settings_->max_bitrate - audio_config.bit_rate}; // Use default display resolution of 1080P. - video_config.resolutions.emplace_back(DisplayResolution{}); + video_config.resolutions.emplace_back(Resolution{1920, 1080}); OSP_VLOG << "Starting session negotiation."; - const Error negotiation_error = - current_session_->NegotiateMirroring({audio_config}, {video_config}); + Error negotiation_error; + if (connection_settings_->use_remoting) { + remoting_sender_ = std::make_unique<RemotingSender>( + current_session_->rpc_messenger(), AudioCodec::kOpus, + connection_settings_->codec, this); + + negotiation_error = + current_session_->NegotiateRemoting(audio_config, video_config); + } else { + negotiation_error = + current_session_->Negotiate({audio_config}, {video_config}); + } if (!negotiation_error.ok()) { OSP_LOG_ERROR << "Failed to negotiate a session: " << negotiation_error; } } -void LoopingFileCastAgent::OnMirroringNegotiated( +void LoopingFileCastAgent::OnNegotiated( const SenderSession* session, SenderSession::ConfiguredSenders senders, capture_recommendations::Recommendations capture_recommendations) { @@ -305,8 +296,24 @@ void LoopingFileCastAgent::OnMirroringNegotiated( } file_sender_ = std::make_unique<LoopingFileSender>( - environment_.get(), connection_settings_->path_to_file.c_str(), session, - std::move(senders), connection_settings_->max_bitrate); + environment_.get(), connection_settings_.value(), session, + std::move(senders), [this]() { shutdown_callback_(); }); +} + +void LoopingFileCastAgent::OnRemotingNegotiated( + const SenderSession* session, + SenderSession::RemotingNegotiation negotiation) { + if (negotiation.senders.audio_sender == nullptr && + negotiation.senders.video_sender == nullptr) { + OSP_LOG_ERROR << "Missing both audio and video, so exiting..."; + return; + } + + current_negotiation_ = + std::make_unique<SenderSession::RemotingNegotiation>(negotiation); + if (is_ready_for_remoting_) { + StartRemotingSenders(); + } } void LoopingFileCastAgent::OnError(const SenderSession* session, Error error) { @@ -314,6 +321,27 @@ void LoopingFileCastAgent::OnError(const SenderSession* session, Error error) { Shutdown(); } +void LoopingFileCastAgent::OnReady() { + is_ready_for_remoting_ = true; + if (current_negotiation_) { + StartRemotingSenders(); + } +} + +void LoopingFileCastAgent::OnPlaybackRateChange(double rate) { + file_sender_->SetPlaybackRate(rate); +} + +void LoopingFileCastAgent::StartRemotingSenders() { + OSP_DCHECK(current_negotiation_); + file_sender_ = std::make_unique<LoopingFileSender>( + environment_.get(), connection_settings_.value(), current_session_.get(), + std::move(current_negotiation_->senders), + [this]() { shutdown_callback_(); }); + current_negotiation_.reset(); + is_ready_for_remoting_ = false; +} + void LoopingFileCastAgent::Shutdown() { TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneSender); diff --git a/cast/standalone_sender/looping_file_cast_agent.h b/cast/standalone_sender/looping_file_cast_agent.h index cc2d5869..3ec2e8fa 100644 --- a/cast/standalone_sender/looping_file_cast_agent.h +++ b/cast/standalone_sender/looping_file_cast_agent.h @@ -19,7 +19,9 @@ #include "cast/common/channel/virtual_connection_router.h" #include "cast/common/public/cast_socket.h" #include "cast/sender/public/sender_socket_factory.h" +#include "cast/standalone_sender/connection_settings.h" #include "cast/standalone_sender/looping_file_sender.h" +#include "cast/standalone_sender/remoting_sender.h" #include "cast/streaming/environment.h" #include "cast/streaming/sender_session.h" #include "platform/api/scoped_wake_lock.h" @@ -68,7 +70,8 @@ class LoopingFileCastAgent final public VirtualConnectionRouter::SocketErrorHandler, public ConnectionNamespaceHandler::VirtualConnectionPolicy, public CastMessageHandler, - public SenderSession::Client { + public SenderSession::Client, + public RemotingSender::Client { public: using ShutdownCallback = std::function<void()>; @@ -78,25 +81,6 @@ class LoopingFileCastAgent final ShutdownCallback shutdown_callback); ~LoopingFileCastAgent(); - struct ConnectionSettings { - // The endpoint of the receiver we wish to connect to. - IPEndpoint receiver_endpoint; - - // The path to the file that we want to play. - std::string path_to_file; - - // The maximum bitrate. Default value means a reasonable default will be - // selected. - int max_bitrate = 0; - - // Whether the stream should include video, or just be audio only. - bool should_include_video = true; - - // Whether we should use the hacky RTP stream IDs for legacy android - // receivers, or if we should use the proper values. - bool use_android_rtp_hack = true; - }; - // Connect to a Cast Receiver, and start the workflow to establish a // mirroring/streaming session. Destroy the LoopingFileCastAgent to shutdown // and disconnect. @@ -124,6 +108,10 @@ class LoopingFileCastAgent final CastSocket* socket, ::cast::channel::CastMessage message) override; + // RemotingSender::Client overrides. + void OnReady() override; + void OnPlaybackRateChange(double rate) override; + // Returns the Cast application ID for either A/V mirroring or audio-only // mirroring, as configured by the ConnectionSettings. const char* GetMirroringAppId() const; @@ -143,12 +131,20 @@ class LoopingFileCastAgent final void CreateAndStartSession(); // SenderSession::Client overrides. - void OnMirroringNegotiated(const SenderSession* session, - SenderSession::ConfiguredSenders senders, - capture_recommendations::Recommendations - capture_recommendations) override; + void OnNegotiated(const SenderSession* session, + SenderSession::ConfiguredSenders senders, + capture_recommendations::Recommendations + capture_recommendations) override; + void OnRemotingNegotiated( + const SenderSession* session, + SenderSession::RemotingNegotiation negotiation) override; void OnError(const SenderSession* session, Error error) override; + // Starts the remoting sender. This may occur when remoting is "ready" if the + // session is already negotiated, or upon session negotiation if the receiver + // is already ready. + void StartRemotingSenders(); + // Helper for stopping the current session, and/or unwinding a remote // connection request (pre-session). This ensures LoopingFileCastAgent is in a // terminal shutdown state. @@ -183,6 +179,17 @@ class LoopingFileCastAgent final std::unique_ptr<Environment> environment_; std::unique_ptr<SenderSession> current_session_; std::unique_ptr<LoopingFileSender> file_sender_; + + // Remoting specific member variables. + std::unique_ptr<RemotingSender> remoting_sender_; + + // Set when remoting is successfully negotiated. However, remoting streams + // won't start until |is_ready_for_remoting_| is true. + std::unique_ptr<SenderSession::RemotingNegotiation> current_negotiation_; + + // Set to true when the remoting receiver is ready. However, remoting streams + // won't start until remoting is successfully negotiated. + bool is_ready_for_remoting_ = false; }; } // namespace cast diff --git a/cast/standalone_sender/looping_file_sender.cc b/cast/standalone_sender/looping_file_sender.cc index 9fae8439..4362add6 100644 --- a/cast/standalone_sender/looping_file_sender.cc +++ b/cast/standalone_sender/looping_file_sender.cc @@ -4,36 +4,46 @@ #include "cast/standalone_sender/looping_file_sender.h" +#include <utility> + +#if defined(CAST_STANDALONE_SENDER_HAVE_LIBAOM) +#include "cast/standalone_sender/streaming_av1_encoder.h" +#endif +#include "cast/standalone_sender/streaming_vpx_encoder.h" +#include "util/osp_logging.h" #include "util/trace_logging.h" namespace openscreen { namespace cast { LoopingFileSender::LoopingFileSender(Environment* environment, - const char* path, + ConnectionSettings settings, const SenderSession* session, SenderSession::ConfiguredSenders senders, - int max_bitrate) + ShutdownCallback shutdown_callback) : env_(environment), - path_(path), + settings_(std::move(settings)), session_(session), - max_bitrate_(max_bitrate), + shutdown_callback_(std::move(shutdown_callback)), audio_encoder_(senders.audio_sender->config().channels, StreamingOpusEncoder::kDefaultCastAudioFramesPerSecond, senders.audio_sender), - video_encoder_(StreamingVp8Encoder::Parameters{}, - env_->task_runner(), - senders.video_sender), + video_encoder_(CreateVideoEncoder( + StreamingVideoEncoder::Parameters{.codec = settings.codec}, + env_->task_runner(), + senders.video_sender)), next_task_(env_->now_function(), env_->task_runner()), console_update_task_(env_->now_function(), env_->task_runner()) { // Opus and Vp8 are the default values for the config, and if these are set // to a different value that means we offered a codec that we do not // support, which is a developer error. OSP_CHECK(senders.audio_config.codec == AudioCodec::kOpus); - OSP_CHECK(senders.video_config.codec == VideoCodec::kVp8); + OSP_CHECK(senders.video_config.codec == VideoCodec::kVp8 || + senders.video_config.codec == VideoCodec::kVp9 || + senders.video_config.codec == VideoCodec::kAv1); OSP_LOG_INFO << "Max allowed media bitrate (audio + video) will be " - << max_bitrate_; - bandwidth_being_utilized_ = max_bitrate_ / 2; + << settings_.max_bitrate; + bandwidth_being_utilized_ = settings_.max_bitrate / 2; UpdateEncoderBitrates(); next_task_.Schedule([this] { SendFileAgain(); }, Alarm::kImmediately); @@ -41,14 +51,19 @@ LoopingFileSender::LoopingFileSender(Environment* environment, LoopingFileSender::~LoopingFileSender() = default; +void LoopingFileSender::SetPlaybackRate(double rate) { + video_capturer_->SetPlaybackRate(rate); + audio_capturer_->SetPlaybackRate(rate); +} + void LoopingFileSender::UpdateEncoderBitrates() { if (bandwidth_being_utilized_ >= kHighBandwidthThreshold) { audio_encoder_.UseHighQuality(); } else { audio_encoder_.UseStandardQuality(); } - video_encoder_.SetTargetBitrate(bandwidth_being_utilized_ - - audio_encoder_.GetBitrate()); + video_encoder_->SetTargetBitrate(bandwidth_being_utilized_ - + audio_encoder_.GetBitrate()); } void LoopingFileSender::ControlForNetworkCongestion() { @@ -72,7 +87,7 @@ void LoopingFileSender::ControlForNetworkCongestion() { // Repsect the user's maximum bitrate setting. bandwidth_being_utilized_ = - std::min(bandwidth_being_utilized_, max_bitrate_); + std::min(bandwidth_being_utilized_, settings_.max_bitrate); UpdateEncoderBitrates(); } else { @@ -84,16 +99,18 @@ void LoopingFileSender::ControlForNetworkCongestion() { } void LoopingFileSender::SendFileAgain() { - OSP_LOG_INFO << "Sending " << path_ << " (starts in one second)..."; + OSP_LOG_INFO << "Sending " << settings_.path_to_file + << " (starts in one second)..."; TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneSender); OSP_DCHECK_EQ(num_capturers_running_, 0); num_capturers_running_ = 2; capture_start_time_ = latest_frame_time_ = env_->now() + seconds(1); - audio_capturer_.emplace(env_, path_, audio_encoder_.num_channels(), - audio_encoder_.sample_rate(), capture_start_time_, - this); - video_capturer_.emplace(env_, path_, capture_start_time_, this); + audio_capturer_.emplace( + env_, settings_.path_to_file.c_str(), audio_encoder_.num_channels(), + audio_encoder_.sample_rate(), capture_start_time_, this); + video_capturer_.emplace(env_, settings_.path_to_file.c_str(), + capture_start_time_, this); next_task_.ScheduleFromNow([this] { ControlForNetworkCongestion(); }, kCongestionCheckInterval); @@ -113,7 +130,7 @@ void LoopingFileSender::OnVideoFrame(const AVFrame& av_frame, Clock::time_point capture_time) { TRACE_DEFAULT_SCOPED(TraceCategory::kStandaloneSender); latest_frame_time_ = std::max(capture_time, latest_frame_time_); - StreamingVp8Encoder::VideoFrame frame{}; + StreamingVideoEncoder::VideoFrame frame{}; frame.width = av_frame.width - av_frame.crop_left - av_frame.crop_right; frame.height = av_frame.height - av_frame.crop_top - av_frame.crop_bottom; frame.yuv_planes[0] = av_frame.data[0] + av_frame.crop_left + @@ -125,9 +142,9 @@ void LoopingFileSender::OnVideoFrame(const AVFrame& av_frame, for (int i = 0; i < 3; ++i) { frame.yuv_strides[i] = av_frame.linesize[i]; } - // TODO(miu): Add performance metrics visual overlay (based on Stats + // TODO(jophba): Add performance metrics visual overlay (based on Stats // callback). - video_encoder_.EncodeAndSend(frame, capture_time, {}); + video_encoder_->EncodeAndSend(frame, capture_time, {}); } void LoopingFileSender::UpdateStatusOnConsole() { @@ -156,7 +173,14 @@ void LoopingFileSender::OnEndOfFile(SimulatedCapturer* capturer) { --num_capturers_running_; if (num_capturers_running_ == 0) { console_update_task_.Cancel(); - next_task_.Schedule([this] { SendFileAgain(); }, Alarm::kImmediately); + + if (settings_.should_loop_video) { + OSP_DLOG_INFO << "Starting the media stream over again."; + next_task_.Schedule([this] { SendFileAgain(); }, Alarm::kImmediately); + } else { + OSP_DLOG_INFO << "Video complete. Exiting..."; + shutdown_callback_(); + } } } @@ -183,5 +207,28 @@ const char* LoopingFileSender::ToTrackName(SimulatedCapturer* capturer) const { return which; } +std::unique_ptr<StreamingVideoEncoder> LoopingFileSender::CreateVideoEncoder( + const StreamingVideoEncoder::Parameters& params, + TaskRunner* task_runner, + Sender* sender) { + switch (params.codec) { + case VideoCodec::kVp8: + case VideoCodec::kVp9: + return std::make_unique<StreamingVpxEncoder>(params, task_runner, sender); + case VideoCodec::kAv1: +#if defined(CAST_STANDALONE_SENDER_HAVE_LIBAOM) + return std::make_unique<StreamingAv1Encoder>(params, task_runner, sender); +#else + OSP_LOG_FATAL << "AV1 codec selected, but could not be used because " + "LibAOM not installed."; +#endif + default: + // Since we only support VP8, VP9, and AV1, any other codec value here + // should be due only to developer error. + OSP_LOG_ERROR << "Unsupported codec " << CodecToString(params.codec); + OSP_NOTREACHED(); + } +} + } // namespace cast } // namespace openscreen diff --git a/cast/standalone_sender/looping_file_sender.h b/cast/standalone_sender/looping_file_sender.h index e55a4a7e..75508e8e 100644 --- a/cast/standalone_sender/looping_file_sender.h +++ b/cast/standalone_sender/looping_file_sender.h @@ -6,12 +6,14 @@ #define CAST_STANDALONE_SENDER_LOOPING_FILE_SENDER_H_ #include <algorithm> +#include <memory> #include <string> +#include "cast/standalone_sender/connection_settings.h" #include "cast/standalone_sender/constants.h" #include "cast/standalone_sender/simulated_capturer.h" #include "cast/standalone_sender/streaming_opus_encoder.h" -#include "cast/standalone_sender/streaming_vp8_encoder.h" +#include "cast/standalone_sender/streaming_video_encoder.h" #include "cast/streaming/sender_session.h" namespace openscreen { @@ -22,14 +24,18 @@ namespace cast { class LoopingFileSender final : public SimulatedAudioCapturer::Client, public SimulatedVideoCapturer::Client { public: + using ShutdownCallback = std::function<void()>; + LoopingFileSender(Environment* environment, - const char* path, + ConnectionSettings settings, const SenderSession* session, SenderSession::ConfiguredSenders senders, - int max_bitrate); + ShutdownCallback shutdown_callback); ~LoopingFileSender() final; + void SetPlaybackRate(double rate); + private: void UpdateEncoderBitrates(); void ControlForNetworkCongestion(); @@ -46,31 +52,36 @@ class LoopingFileSender final : public SimulatedAudioCapturer::Client, void UpdateStatusOnConsole(); - // SimulatedCapturer overrides. + // SimulatedCapturer::Client overrides. void OnEndOfFile(SimulatedCapturer* capturer) final; void OnError(SimulatedCapturer* capturer, std::string message) final; const char* ToTrackName(SimulatedCapturer* capturer) const; + std::unique_ptr<StreamingVideoEncoder> CreateVideoEncoder( + const StreamingVideoEncoder::Parameters& params, + TaskRunner* task_runner, + Sender* sender); + // Holds the required injected dependencies (clock, task runner) used for Cast // Streaming, and owns the UDP socket over which all communications occur with // the remote's Receivers. Environment* const env_; - // The path to the media file to stream over and over. - const char* const path_; + // The connection settings used for this session. + const ConnectionSettings settings_; // Session to query for bandwidth information. const SenderSession* session_; - // User provided maximum bitrate (from command line argument). - const int max_bitrate_; + // Callback for tearing down the sender process. + ShutdownCallback shutdown_callback_; int bandwidth_estimate_ = 0; int bandwidth_being_utilized_; StreamingOpusEncoder audio_encoder_; - StreamingVp8Encoder video_encoder_; + std::unique_ptr<StreamingVideoEncoder> video_encoder_; int num_capturers_running_ = 0; Clock::time_point capture_start_time_{}; diff --git a/cast/standalone_sender/main.cc b/cast/standalone_sender/main.cc index 75b50553..71923d7e 100644 --- a/cast/standalone_sender/main.cc +++ b/cast/standalone_sender/main.cc @@ -23,6 +23,7 @@ #include "platform/api/time.h" #include "platform/base/error.h" #include "platform/base/ip_address.h" +#include "platform/impl/network_interface.h" #include "platform/impl/platform_client_posix.h" #include "platform/impl/task_runner.h" #include "platform/impl/text_trace_logging_platform.h" @@ -54,6 +55,10 @@ usage: %s <options> addr[:port] media_file Specifies the maximum bits per second for the media streams. Default if not set: %d + + -n, --no-looping + Disable looping the passed in video after it finishes playing. + )" #if defined(CAST_ALLOW_DEVELOPER_CERTIFICATE) R"( @@ -68,13 +73,18 @@ usage: %s <options> addr[:port] media_file R"( -a, --android-hack: Use the wrong RTP payload types, for compatibility with older Android - TV receivers. + TV receivers. See https://crbug.com/631828. + + -r, --remoting: Enable remoting content instead of mirroring. -t, --tracing: Enable performance tracing logging. -v, --verbose: Enable verbose logging. -h, --help: Show this help message. + + -c, --codec: Specifies the video codec to be used. Can be one of: + vp8, vp9, av1. Defaults to vp8 if not specified. )"; std::cerr << StringPrintf(kTemplate, argv0, argv0, kDefaultCastPort, @@ -107,23 +117,29 @@ int StandaloneSenderMain(int argc, char* argv[]) { // standalone sender, osp demo, and test_main argument options. const struct option kArgumentOptions[] = { {"max-bitrate", required_argument, nullptr, 'm'}, + {"no-looping", no_argument, nullptr, 'n'}, #if defined(CAST_ALLOW_DEVELOPER_CERTIFICATE) {"developer-certificate", required_argument, nullptr, 'd'}, #endif {"android-hack", no_argument, nullptr, 'a'}, + {"remoting", no_argument, nullptr, 'r'}, {"tracing", no_argument, nullptr, 't'}, {"verbose", no_argument, nullptr, 'v'}, {"help", no_argument, nullptr, 'h'}, + {"codec", required_argument, nullptr, 'c'}, {nullptr, 0, nullptr, 0} }; - bool is_verbose = false; + int max_bitrate = kDefaultMaxBitrate; + bool should_loop_video = true; std::string developer_certificate_path; bool use_android_rtp_hack = false; - int max_bitrate = kDefaultMaxBitrate; + bool use_remoting = false; + bool is_verbose = false; + VideoCodec codec = VideoCodec::kVp8; std::unique_ptr<TextTraceLoggingPlatform> trace_logger; int ch = -1; - while ((ch = getopt_long(argc, argv, "m:d:atvh", kArgumentOptions, + while ((ch = getopt_long(argc, argv, "m:nd:artvhc:", kArgumentOptions, nullptr)) != -1) { switch (ch) { case 'm': @@ -135,6 +151,9 @@ int StandaloneSenderMain(int argc, char* argv[]) { return 1; } break; + case 'n': + should_loop_video = false; + break; #if defined(CAST_ALLOW_DEVELOPER_CERTIFICATE) case 'd': developer_certificate_path = optarg; @@ -143,6 +162,9 @@ int StandaloneSenderMain(int argc, char* argv[]) { case 'a': use_android_rtp_hack = true; break; + case 'r': + use_remoting = true; + break; case 't': trace_logger = std::make_unique<TextTraceLoggingPlatform>(); break; @@ -152,6 +174,20 @@ int StandaloneSenderMain(int argc, char* argv[]) { case 'h': LogUsage(argv[0]); return 1; + case 'c': + auto specified_codec = StringToVideoCodec(optarg); + if (specified_codec.is_value() && + (specified_codec.value() == VideoCodec::kVp8 || + specified_codec.value() == VideoCodec::kVp9 || + specified_codec.value() == VideoCodec::kAv1)) { + codec = specified_codec.value(); + } else { + OSP_LOG_ERROR << "Invalid --codec specified: " << optarg + << " is not one of: vp8, vp9, av1."; + LogUsage(argv[0]); + return 1; + } + break; } } @@ -179,7 +215,7 @@ int StandaloneSenderMain(int argc, char* argv[]) { IPEndpoint remote_endpoint = ParseAsEndpoint(iface_or_endpoint); if (!remote_endpoint.port) { - for (const InterfaceInfo& interface : GetNetworkInterfaces()) { + for (const InterfaceInfo& interface : GetAllInterfaces()) { if (interface.name == iface_or_endpoint) { ReceiverChooser chooser(interface, task_runner, [&](IPEndpoint endpoint) { @@ -205,9 +241,15 @@ int StandaloneSenderMain(int argc, char* argv[]) { task_runner->PostTask([&] { cast_agent = new LoopingFileCastAgent( task_runner, [&] { task_runner->RequestStopSoon(); }); - cast_agent->Connect({remote_endpoint, path, max_bitrate, - true /* should_include_video */, - use_android_rtp_hack}); + + cast_agent->Connect({.receiver_endpoint = remote_endpoint, + .path_to_file = path, + .max_bitrate = max_bitrate, + .should_include_video = true, + .use_android_rtp_hack = use_android_rtp_hack, + .use_remoting = use_remoting, + .should_loop_video = should_loop_video, + .codec = codec}); }); // Run the event loop until SIGINT (e.g., CTRL-C at the console) or @@ -239,7 +281,9 @@ int main(int argc, char* argv[]) { #else OSP_LOG_ERROR << "It compiled! However, you need to configure the build to point to " - "external libraries in order to build a useful app."; + "external libraries in order to build a useful app. For more " + "information, see " + "[external_libraries.md](../../build/config/external_libraries.md)."; return 1; #endif } diff --git a/cast/standalone_sender/receiver_chooser.cc b/cast/standalone_sender/receiver_chooser.cc index 828ea8ef..8a9b209c 100644 --- a/cast/standalone_sender/receiver_chooser.cc +++ b/cast/standalone_sender/receiver_chooser.cc @@ -27,28 +27,14 @@ ReceiverChooser::ReceiverChooser(const InterfaceInfo& interface, ResultCallback result_callback) : result_callback_(std::move(result_callback)), menu_alarm_(&Clock::now, task_runner) { - using discovery::Config; - Config config; - // TODO(miu): Remove AddressFamilies from the Config in a follow-up patch. No - // client uses this to do anything other than "enabled for all address - // families," and so it doesn't need to be configurable. - Config::NetworkInfo::AddressFamilies families = - Config::NetworkInfo::kNoAddressFamily; - if (interface.GetIpAddressV4()) { - families |= Config::NetworkInfo::kUseIpV4; - } - if (interface.GetIpAddressV6()) { - families |= Config::NetworkInfo::kUseIpV6; - } - config.network_info.push_back({interface, families}); - config.enable_publication = false; - config.enable_querying = true; - service_ = - discovery::CreateDnsSdService(task_runner, this, std::move(config)); - - watcher_ = std::make_unique<discovery::DnsSdServiceWatcher<ServiceInfo>>( - service_.get(), kCastV2ServiceId, DnsSdInstanceEndpointToServiceInfo, - [this](std::vector<std::reference_wrapper<const ServiceInfo>> all) { + discovery::Config config{.network_info = {interface}, + .enable_publication = false, + .enable_querying = true}; + discovery::CreateDnsSdService(task_runner, this, std::move(config)); + + watcher_ = std::make_unique<discovery::DnsSdServiceWatcher<ReceiverInfo>>( + service_.get(), kCastV2ServiceId, DnsSdInstanceEndpointToReceiverInfo, + [this](std::vector<std::reference_wrapper<const ReceiverInfo>> all) { OnDnsWatcherUpdate(std::move(all)); }); @@ -68,15 +54,15 @@ void ReceiverChooser::OnRecoverableError(Error error) { } void ReceiverChooser::OnDnsWatcherUpdate( - std::vector<std::reference_wrapper<const ServiceInfo>> all) { + std::vector<std::reference_wrapper<const ReceiverInfo>> all) { bool added_some = false; - for (const ServiceInfo& info : all) { + for (const ReceiverInfo& info : all) { if (!info.IsValid() || (!info.v4_address && !info.v6_address)) { continue; } const std::string& instance_id = info.GetInstanceId(); if (std::any_of(discovered_receivers_.begin(), discovered_receivers_.end(), - [&](const ServiceInfo& known) { + [&](const ReceiverInfo& known) { return known.GetInstanceId() == instance_id; })) { continue; @@ -101,7 +87,7 @@ void ReceiverChooser::PrintMenuAndHandleChoice() { std::cout << '\n'; for (size_t i = 0; i < discovered_receivers_.size(); ++i) { - const ServiceInfo& info = discovered_receivers_[i]; + const ReceiverInfo& info = discovered_receivers_[i]; std::cout << '[' << i << "]: " << info.friendly_name << " @ "; if (info.v6_address) { std::cout << info.v6_address; @@ -118,7 +104,7 @@ void ReceiverChooser::PrintMenuAndHandleChoice() { const auto callback_on_stack = std::move(result_callback_); if (menu_choice >= 0 && menu_choice < static_cast<int>(discovered_receivers_.size())) { - const ServiceInfo& choice = discovered_receivers_[menu_choice]; + const ReceiverInfo& choice = discovered_receivers_[menu_choice]; if (choice.v6_address) { callback_on_stack(IPEndpoint{choice.v6_address, choice.port}); } else { diff --git a/cast/standalone_sender/receiver_chooser.h b/cast/standalone_sender/receiver_chooser.h index a2fd398f..1c7fc607 100644 --- a/cast/standalone_sender/receiver_chooser.h +++ b/cast/standalone_sender/receiver_chooser.h @@ -9,7 +9,7 @@ #include <memory> #include <vector> -#include "cast/common/public/service_info.h" +#include "cast/common/public/receiver_info.h" #include "discovery/common/reporting_client.h" #include "discovery/public/dns_sd_service_factory.h" #include "discovery/public/dns_sd_service_watcher.h" @@ -40,10 +40,10 @@ class ReceiverChooser final : public discovery::ReportingClient { void OnFatalError(Error error) final; void OnRecoverableError(Error error) final; - // Called from the DnsWatcher with |all| ServiceInfos any time there is a + // Called from the DnsWatcher with |all| ReceiverInfos any time there is a // change in the set of discovered devices. void OnDnsWatcherUpdate( - std::vector<std::reference_wrapper<const ServiceInfo>> all); + std::vector<std::reference_wrapper<const ReceiverInfo>> all); // Called from |menu_alarm_| when it is a good time for the user to choose // from the discovered-so-far set of Cast Receivers. @@ -51,8 +51,8 @@ class ReceiverChooser final : public discovery::ReportingClient { ResultCallback result_callback_; SerialDeletePtr<discovery::DnsSdService> service_; - std::unique_ptr<discovery::DnsSdServiceWatcher<ServiceInfo>> watcher_; - std::vector<ServiceInfo> discovered_receivers_; + std::unique_ptr<discovery::DnsSdServiceWatcher<ReceiverInfo>> watcher_; + std::vector<ReceiverInfo> discovered_receivers_; Alarm menu_alarm_; // After there is another Cast Receiver discovered, ready to show to the user diff --git a/cast/standalone_sender/remoting_sender.cc b/cast/standalone_sender/remoting_sender.cc new file mode 100644 index 00000000..e28c9ae1 --- /dev/null +++ b/cast/standalone_sender/remoting_sender.cc @@ -0,0 +1,113 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_sender/remoting_sender.h" + +#include <utility> + +#include "cast/streaming/message_fields.h" + +namespace openscreen { +namespace cast { + +namespace { + +VideoDecoderConfig::Codec ToProtoCodec(VideoCodec value) { + switch (value) { + case VideoCodec::kHevc: + return VideoDecoderConfig_Codec_kCodecHEVC; + case VideoCodec::kH264: + return VideoDecoderConfig_Codec_kCodecH264; + case VideoCodec::kVp8: + return VideoDecoderConfig_Codec_kCodecVP8; + case VideoCodec::kVp9: + return VideoDecoderConfig_Codec_kCodecVP9; + case VideoCodec::kAv1: + return VideoDecoderConfig_Codec_kCodecAV1; + default: + return VideoDecoderConfig_Codec_kUnknownVideoCodec; + } +} + +AudioDecoderConfig::Codec ToProtoCodec(AudioCodec value) { + switch (value) { + case AudioCodec::kAac: + return AudioDecoderConfig_Codec_kCodecAAC; + case AudioCodec::kOpus: + return AudioDecoderConfig_Codec_kCodecOpus; + default: + return AudioDecoderConfig_Codec_kUnknownAudioCodec; + } +} + +} // namespace + +RemotingSender::Client::~Client() = default; + +RemotingSender::RemotingSender(RpcMessenger* messenger, + AudioCodec audio_codec, + VideoCodec video_codec, + Client* client) + : messenger_(messenger), + audio_codec_(audio_codec), + video_codec_(video_codec), + client_(client) { + OSP_DCHECK(client_); + messenger_->RegisterMessageReceiverCallback( + RpcMessenger::kAcquireRendererHandle, + [this](std::unique_ptr<RpcMessage> message) { + OSP_DCHECK(message); + this->OnMessage(*message); + }); +} + +RemotingSender::~RemotingSender() { + messenger_->UnregisterMessageReceiverCallback( + RpcMessenger::kAcquireRendererHandle); +} + +void RemotingSender::OnMessage(const RpcMessage& message) { + if (!message.has_proc()) { + return; + } + if (message.proc() == RpcMessage_RpcProc_RPC_DS_INITIALIZE) { + OSP_VLOG << "Received initialize message"; + OnInitializeMessage(message); + } else if (message.proc() == RpcMessage_RpcProc_RPC_R_SETPLAYBACKRATE) { + OSP_VLOG << "Received playback rate message: " << message.double_value(); + OnPlaybackRateMessage(message); + } +} + +void RemotingSender::OnInitializeMessage(const RpcMessage& message) { + receiver_handle_ = message.integer_value(); + + RpcMessage callback_message; + callback_message.set_handle(receiver_handle_); + callback_message.set_proc(RpcMessage::RPC_DS_INITIALIZE_CALLBACK); + + auto* callback_body = + callback_message.mutable_demuxerstream_initializecb_rpc(); + + // In Chrome, separate calls are used for the audio and video configs, but + // for simplicity's sake we combine them here. + callback_body->mutable_audio_decoder_config()->set_codec( + ToProtoCodec(audio_codec_)); + callback_body->mutable_video_decoder_config()->set_codec( + ToProtoCodec(video_codec_)); + + OSP_DLOG_INFO << "Initializing receiver handle " << receiver_handle_ + << " with audio codec " << CodecToString(audio_codec_) + << " and video codec " << CodecToString(video_codec_); + messenger_->SendMessageToRemote(callback_message); + + client_->OnReady(); +} + +void RemotingSender::OnPlaybackRateMessage(const RpcMessage& message) { + client_->OnPlaybackRateChange(message.double_value()); +} + +} // namespace cast +} // namespace openscreen diff --git a/cast/standalone_sender/remoting_sender.h b/cast/standalone_sender/remoting_sender.h new file mode 100644 index 00000000..7d09dc69 --- /dev/null +++ b/cast/standalone_sender/remoting_sender.h @@ -0,0 +1,75 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_SENDER_REMOTING_SENDER_H_ +#define CAST_STANDALONE_SENDER_REMOTING_SENDER_H_ + +#include <memory> + +#include "cast/streaming/constants.h" +#include "cast/streaming/rpc_messenger.h" + +namespace openscreen { +namespace cast { + +// This class behaves like a pared-down version of Chrome's StreamProvider (see +// https://source.chromium.org/chromium/chromium/src/+/main:media/remoting/stream_provider.h +// ). Instead of fully managing a media::DemuxerStream however, it just provides +// an RPC initialization routine that notifies the standalone receiver's +// SimpleRemotingReceiver instance (if configured) that initialization has been +// complete and what codecs were selected. +// +// Due to the sheer complexity of remoting, we don't have a fully functional +// implementation of remoting in the standalone_* components, instead Chrome is +// the reference implementation and we have these simple classes to exercise +// the public APIs. +class RemotingSender { + public: + // The remoting sender expects a valid client to handle received messages. + class Client { + public: + virtual ~Client(); + + // Executed when we receive the initialize message from the receiver. + virtual void OnReady() = 0; + + // Executed when we receive a playback rate message from the receiver. + virtual void OnPlaybackRateChange(double rate) = 0; + }; + + RemotingSender(RpcMessenger* messenger, + AudioCodec audio_codec, + VideoCodec video_codec, + Client* client); + ~RemotingSender(); + + private: + // Helper for parsing any received RPC messages. + void OnMessage(const RpcMessage& message); + void OnInitializeMessage(const RpcMessage& message); + void OnPlaybackRateMessage(const RpcMessage& message); + + // The messenger is the only caller of OnInitializeMessage, so there are no + // lifetime concerns. However, if this class outlives |messenger_|, it will + // no longer receive initialization messages. + RpcMessenger* messenger_; + + // Unlike in Chrome, here we should know the video and audio codecs before any + // of the remoting code gets set up, and for simplicity's sake we can only + // populate the AudioDecoderConfig and VideoDecoderConfig objects with the + // codecs and use the rest of the fields as-is from the OFFER/ANSWER exchange. + const AudioCodec audio_codec_; + const VideoCodec video_codec_; + + Client* client_; + + // The initialization message from the receiver contains the handle the + // callback should go to. + RpcMessenger::Handle receiver_handle_ = RpcMessenger::kInvalidHandle; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_SENDER_REMOTING_SENDER_H_ diff --git a/cast/standalone_sender/simulated_capturer.cc b/cast/standalone_sender/simulated_capturer.cc index 87313010..713caa24 100644 --- a/cast/standalone_sender/simulated_capturer.cc +++ b/cast/standalone_sender/simulated_capturer.cc @@ -85,6 +85,14 @@ SimulatedCapturer::SimulatedCapturer(Environment* environment, SimulatedCapturer::~SimulatedCapturer() = default; +void SimulatedCapturer::SetPlaybackRate(double rate) { + playback_rate_is_non_zero_ = rate > 0; + if (playback_rate_is_non_zero_) { + // Restart playback now that playback rate is nonzero. + StartDecodingNextFrame(); + } +} + void SimulatedCapturer::SetAdditionalDecoderParameters( AVCodecContext* decoder_context) {} @@ -119,6 +127,9 @@ Clock::duration SimulatedCapturer::ToApproximateClockDuration( } void SimulatedCapturer::StartDecodingNextFrame() { + if (!playback_rate_is_non_zero_) { + return; + } const int read_frame_result = av_read_frame(format_context_.get(), packet_.get()); if (read_frame_result < 0) { diff --git a/cast/standalone_sender/simulated_capturer.h b/cast/standalone_sender/simulated_capturer.h index 8d32085a..61738e1f 100644 --- a/cast/standalone_sender/simulated_capturer.h +++ b/cast/standalone_sender/simulated_capturer.h @@ -40,6 +40,8 @@ class SimulatedCapturer { virtual ~Observer(); }; + void SetPlaybackRate(double rate); + protected: SimulatedCapturer(Environment* environment, const char* path, @@ -103,6 +105,10 @@ class SimulatedCapturer { // Used to schedule the next task to execute and when it should execute. There // is only ever one task scheduled/running at any time. Alarm next_task_; + + // Used to determine playback rate. Currently, we only support "playing" + // at 1x speed, or "pausing" at 0x speed. + bool playback_rate_is_non_zero_ = true; }; // Emits the primary audio stream from a file. diff --git a/cast/standalone_sender/streaming_av1_encoder.cc b/cast/standalone_sender/streaming_av1_encoder.cc new file mode 100644 index 00000000..c39332e5 --- /dev/null +++ b/cast/standalone_sender/streaming_av1_encoder.cc @@ -0,0 +1,425 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_sender/streaming_av1_encoder.h" + +#include <aom/aomcx.h> + +#include <chrono> +#include <cmath> +#include <utility> + +#include "cast/standalone_sender/streaming_encoder_util.h" +#include "cast/streaming/encoded_frame.h" +#include "cast/streaming/environment.h" +#include "cast/streaming/sender.h" +#include "util/chrono_helpers.h" +#include "util/osp_logging.h" +#include "util/saturate_cast.h" + +namespace openscreen { +namespace cast { + +// TODO(issuetracker.google.com/issues/155336511): Fix the declarations and then +// remove this: +using openscreen::operator<<; // For std::chrono::duration pretty-printing. + +namespace { + +constexpr int kBytesPerKilobyte = 1024; + +// Lower and upper bounds to the frame duration passed to aom_codec_encode(), to +// ensure sanity. Note that the upper-bound is especially important in cases +// where the video paused for some lengthy amount of time. +constexpr Clock::duration kMinFrameDuration = milliseconds(1); +constexpr Clock::duration kMaxFrameDuration = milliseconds(125); + +// Highest/lowest allowed encoding speed set to the encoder. +constexpr int kHighestEncodingSpeed = 9; +constexpr int kLowestEncodingSpeed = 0; + +} // namespace + +StreamingAv1Encoder::StreamingAv1Encoder(const Parameters& params, + TaskRunner* task_runner, + Sender* sender) + : StreamingVideoEncoder(params, task_runner, sender) { + ideal_speed_setting_ = kHighestEncodingSpeed; + encode_thread_ = std::thread([this] { ProcessWorkUnitsUntilTimeToQuit(); }); + + OSP_DCHECK(params_.codec == VideoCodec::kAv1); + const auto result = aom_codec_enc_config_default(aom_codec_av1_cx(), &config_, + AOM_USAGE_REALTIME); + OSP_CHECK_EQ(result, AOM_CODEC_OK); + + // This is set to non-zero in ConfigureForNewFrameSize() later, to flag that + // the encoder has been initialized. + config_.g_threads = 0; + + // Set the timebase to match that of openscreen::Clock::duration. + config_.g_timebase.num = Clock::duration::period::num; + config_.g_timebase.den = Clock::duration::period::den; + + // |g_pass| and |g_lag_in_frames| must be "one pass" and zero, respectively, + // because of the way the libaom API is used. + config_.g_pass = AOM_RC_ONE_PASS; + config_.g_lag_in_frames = 0; + + // Rate control settings. + config_.rc_dropframe_thresh = 0; // The encoder may not drop any frames. + config_.rc_resize_mode = 0; + config_.rc_end_usage = AOM_CBR; + config_.rc_target_bitrate = target_bitrate_ / kBytesPerKilobyte; + config_.rc_min_quantizer = params_.min_quantizer; + config_.rc_max_quantizer = params_.max_quantizer; + + // The reasons for the values chosen here (rc_*shoot_pct and rc_buf_*_sz) are + // lost in history. They were brought-over from the legacy Chrome Cast + // Streaming Sender implemenation. + config_.rc_undershoot_pct = 100; + config_.rc_overshoot_pct = 15; + config_.rc_buf_initial_sz = 500; + config_.rc_buf_optimal_sz = 600; + config_.rc_buf_sz = 1000; + + config_.kf_mode = AOM_KF_DISABLED; +} + +StreamingAv1Encoder::~StreamingAv1Encoder() { + { + std::unique_lock<std::mutex> lock(mutex_); + target_bitrate_ = 0; + cv_.notify_one(); + } + encode_thread_.join(); +} + +int StreamingAv1Encoder::GetTargetBitrate() const { + // Note: No need to lock the |mutex_| since this method should be called on + // the same thread as SetTargetBitrate(). + return target_bitrate_; +} + +void StreamingAv1Encoder::SetTargetBitrate(int new_bitrate) { + // Ensure that, when bps is converted to kbps downstream, that the encoder + // bitrate will not be zero. + new_bitrate = std::max(new_bitrate, kBytesPerKilobyte); + + std::unique_lock<std::mutex> lock(mutex_); + // Only assign the new target bitrate if |target_bitrate_| has not yet been + // used to signal the |encode_thread_| to end. + if (target_bitrate_ > 0) { + target_bitrate_ = new_bitrate; + } +} + +void StreamingAv1Encoder::EncodeAndSend( + const VideoFrame& frame, + Clock::time_point reference_time, + std::function<void(Stats)> stats_callback) { + WorkUnit work_unit; + + // TODO(jophba): The |VideoFrame| struct should provide the media timestamp, + // instead of this code inferring it from the reference timestamps, since: 1) + // the video capturer's clock may tick at a different rate than the system + // clock; and 2) to reduce jitter. + if (start_time_ == Clock::time_point::min()) { + start_time_ = reference_time; + work_unit.rtp_timestamp = RtpTimeTicks(); + } else { + work_unit.rtp_timestamp = RtpTimeTicks::FromTimeSinceOrigin( + reference_time - start_time_, sender_->rtp_timebase()); + if (work_unit.rtp_timestamp <= last_enqueued_rtp_timestamp_) { + OSP_LOG_WARN << "VIDEO[" << sender_->ssrc() + << "] Dropping: RTP timestamp is not monotonically " + "increasing from last frame."; + return; + } + } + if (sender_->GetInFlightMediaDuration(work_unit.rtp_timestamp) > + sender_->GetMaxInFlightMediaDuration()) { + OSP_LOG_WARN << "VIDEO[" << sender_->ssrc() + << "] Dropping: In-flight media duration would be too high."; + return; + } + + Clock::duration frame_duration = frame.duration; + if (frame_duration <= Clock::duration::zero()) { + // The caller did not provide the frame duration in |frame|. + if (reference_time == start_time_) { + // Use the max for the first frame so libaom will spend extra effort on + // its quality. + frame_duration = kMaxFrameDuration; + } else { + // Use the actual amount of time between the current and previous frame as + // a prediction for the next frame's duration. + frame_duration = + (work_unit.rtp_timestamp - last_enqueued_rtp_timestamp_) + .ToDuration<Clock::duration>(sender_->rtp_timebase()); + } + } + work_unit.duration = + std::max(std::min(frame_duration, kMaxFrameDuration), kMinFrameDuration); + + last_enqueued_rtp_timestamp_ = work_unit.rtp_timestamp; + + work_unit.image = CloneAsAv1Image(frame); + work_unit.reference_time = reference_time; + work_unit.stats_callback = std::move(stats_callback); + const bool force_key_frame = sender_->NeedsKeyFrame(); + { + std::unique_lock<std::mutex> lock(mutex_); + needs_key_frame_ |= force_key_frame; + encode_queue_.push(std::move(work_unit)); + cv_.notify_one(); + } +} + +void StreamingAv1Encoder::DestroyEncoder() { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + if (is_encoder_initialized()) { + aom_codec_destroy(&encoder_); + // Flag that the encoder is not initialized. See header comments for + // is_encoder_initialized(). + config_.g_threads = 0; + } +} + +void StreamingAv1Encoder::ProcessWorkUnitsUntilTimeToQuit() { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + for (;;) { + WorkUnitWithResults work_unit{}; + bool force_key_frame; + int target_bitrate; + { + std::unique_lock<std::mutex> lock(mutex_); + if (target_bitrate_ <= 0) { + break; // Time to end this thread. + } + if (encode_queue_.empty()) { + cv_.wait(lock); + if (encode_queue_.empty()) { + continue; + } + } + static_cast<WorkUnit&>(work_unit) = std::move(encode_queue_.front()); + encode_queue_.pop(); + force_key_frame = needs_key_frame_; + needs_key_frame_ = false; + target_bitrate = target_bitrate_; + } + + // Clock::now() is being called directly, instead of using a + // dependency-injected "now function," since actual wall time is being + // measured. + const Clock::time_point encode_start_time = Clock::now(); + PrepareEncoder(work_unit.image->d_w, work_unit.image->d_h, target_bitrate); + EncodeFrame(force_key_frame, work_unit); + ComputeFrameEncodeStats(Clock::now() - encode_start_time, target_bitrate, + work_unit); + UpdateSpeedSettingForNextFrame(work_unit.stats); + + main_task_runner_->PostTask( + [this, results = std::move(work_unit)]() mutable { + SendEncodedFrame(std::move(results)); + }); + } + + DestroyEncoder(); +} + +void StreamingAv1Encoder::PrepareEncoder(int width, + int height, + int target_bitrate) { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + const int target_kbps = target_bitrate / kBytesPerKilobyte; + + // Translate the |ideal_speed_setting_| into the AOME_SET_CPUUSED setting and + // the minimum quantizer to use. + int speed; + int min_quantizer; + if (ideal_speed_setting_ > kHighestEncodingSpeed) { + speed = kHighestEncodingSpeed; + const double remainder = ideal_speed_setting_ - speed; + min_quantizer = rounded_saturate_cast<int>( + remainder / kEquivalentEncodingSpeedStepPerQuantizerStep + + params_.min_quantizer); + min_quantizer = std::min(min_quantizer, params_.max_cpu_saver_quantizer); + } else { + speed = std::max(rounded_saturate_cast<int>(ideal_speed_setting_), + kLowestEncodingSpeed); + min_quantizer = params_.min_quantizer; + } + + if (static_cast<int>(config_.g_w) != width || + static_cast<int>(config_.g_h) != height) { + DestroyEncoder(); + } + + if (!is_encoder_initialized()) { + config_.g_threads = params_.num_encode_threads; + config_.g_w = width; + config_.g_h = height; + config_.rc_target_bitrate = target_kbps; + config_.rc_min_quantizer = min_quantizer; + + encoder_ = {}; + const aom_codec_flags_t flags = 0; + + const auto init_result = + aom_codec_enc_init(&encoder_, aom_codec_av1_cx(), &config_, flags); + OSP_CHECK_EQ(init_result, AOM_CODEC_OK); + + // Raise the threshold for considering macroblocks as static. The default is + // zero, so this setting makes the encoder less sensitive to motion. This + // lowers the probability of needing to utilize more CPU to search for + // motion vectors. + const auto ctl_result = + aom_codec_control(&encoder_, AOME_SET_STATIC_THRESHOLD, 1); + OSP_CHECK_EQ(ctl_result, AOM_CODEC_OK); + + // Ensure the speed will be set (below). + current_speed_setting_ = ~speed; + } else if (static_cast<int>(config_.rc_target_bitrate) != target_kbps || + static_cast<int>(config_.rc_min_quantizer) != min_quantizer) { + config_.rc_target_bitrate = target_kbps; + config_.rc_min_quantizer = min_quantizer; + const auto update_config_result = + aom_codec_enc_config_set(&encoder_, &config_); + OSP_CHECK_EQ(update_config_result, AOM_CODEC_OK); + } + + if (current_speed_setting_ != speed) { + const auto ctl_result = + aom_codec_control(&encoder_, AOME_SET_CPUUSED, speed); + OSP_CHECK_EQ(ctl_result, AOM_CODEC_OK); + current_speed_setting_ = speed; + } +} + +void StreamingAv1Encoder::EncodeFrame(bool force_key_frame, + WorkUnitWithResults& work_unit) { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + // The presentation timestamp argument here is fixed to zero to force the + // encoder to base its single-frame bandwidth calculations entirely on + // |frame_duration| and the target bitrate setting. + const aom_codec_pts_t pts = 0; + const aom_enc_frame_flags_t flags = force_key_frame ? AOM_EFLAG_FORCE_KF : 0; + const auto encode_result = aom_codec_encode( + &encoder_, work_unit.image.get(), pts, work_unit.duration.count(), flags); + OSP_CHECK_EQ(encode_result, AOM_CODEC_OK); + + const aom_codec_cx_pkt_t* pkt; + for (aom_codec_iter_t iter = nullptr;;) { + pkt = aom_codec_get_cx_data(&encoder_, &iter); + // aom_codec_get_cx_data() returns null once the "iteration" is complete. + // However, that point should never be reached because a + // AOM_CODEC_CX_FRAME_PKT must be encountered before that. + OSP_CHECK(pkt); + if (pkt->kind == AOM_CODEC_CX_FRAME_PKT) { + break; + } + } + + // A copy of the payload data is being made here. That's okay since it has to + // be copied at some point anyway, to be passed back to the main thread. + auto* const begin = static_cast<const uint8_t*>(pkt->data.frame.buf); + auto* const end = begin + pkt->data.frame.sz; + work_unit.payload.assign(begin, end); + work_unit.is_key_frame = !!(pkt->data.frame.flags & AOM_FRAME_IS_KEY); +} + +void StreamingAv1Encoder::ComputeFrameEncodeStats( + Clock::duration encode_wall_time, + int target_bitrate, + WorkUnitWithResults& work_unit) { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + Stats& stats = work_unit.stats; + + // Note: stats.frame_id is set later, in SendEncodedFrame(). + stats.rtp_timestamp = work_unit.rtp_timestamp; + stats.encode_wall_time = encode_wall_time; + stats.frame_duration = work_unit.duration; + stats.encoded_size = work_unit.payload.size(); + + constexpr double kBytesPerBit = 1.0 / CHAR_BIT; + constexpr double kSecondsPerClockTick = + 1.0 / Clock::to_duration(seconds(1)).count(); + const double target_bytes_per_clock_tick = + target_bitrate * (kBytesPerBit * kSecondsPerClockTick); + stats.target_size = target_bytes_per_clock_tick * work_unit.duration.count(); + + // The quantizer the encoder used. This is the result of the AV1 encoder + // taking a guess at what quantizer value would produce an encoded frame size + // as close to the target as possible. + const auto get_quantizer_result = aom_codec_control( + &encoder_, AOME_GET_LAST_QUANTIZER_64, &stats.quantizer); + OSP_CHECK_EQ(get_quantizer_result, AOM_CODEC_OK); + + // Now that the frame has been encoded and the number of bytes is known, the + // perfect quantizer value (i.e., the one that should have been used) can be + // determined. + stats.perfect_quantizer = stats.quantizer * stats.space_utilization(); +} + +void StreamingAv1Encoder::SendEncodedFrame(WorkUnitWithResults results) { + OSP_DCHECK(main_task_runner_->IsRunningOnTaskRunner()); + + EncodedFrame frame; + frame.frame_id = sender_->GetNextFrameId(); + if (results.is_key_frame) { + frame.dependency = EncodedFrame::KEY_FRAME; + frame.referenced_frame_id = frame.frame_id; + } else { + frame.dependency = EncodedFrame::DEPENDS_ON_ANOTHER; + frame.referenced_frame_id = frame.frame_id - 1; + } + frame.rtp_timestamp = results.rtp_timestamp; + frame.reference_time = results.reference_time; + frame.data = absl::Span<uint8_t>(results.payload); + + if (sender_->EnqueueFrame(frame) != Sender::OK) { + // Since the frame will not be sent, the encoder's frame dependency chain + // has been broken. Force a key frame for the next frame. + std::unique_lock<std::mutex> lock(mutex_); + needs_key_frame_ = true; + } + + if (results.stats_callback) { + results.stats.frame_id = frame.frame_id; + results.stats_callback(results.stats); + } +} + +// static +StreamingAv1Encoder::Av1ImageUniquePtr StreamingAv1Encoder::CloneAsAv1Image( + const VideoFrame& frame) { + OSP_DCHECK_GE(frame.width, 0); + OSP_DCHECK_GE(frame.height, 0); + OSP_DCHECK_GE(frame.yuv_strides[0], 0); + OSP_DCHECK_GE(frame.yuv_strides[1], 0); + OSP_DCHECK_GE(frame.yuv_strides[2], 0); + + constexpr int kAlignment = 32; + Av1ImageUniquePtr image(aom_img_alloc(nullptr, AOM_IMG_FMT_I420, frame.width, + frame.height, kAlignment)); + OSP_CHECK(image); + + CopyPlane(frame.yuv_planes[0], frame.yuv_strides[0], frame.height, + image->planes[AOM_PLANE_Y], image->stride[AOM_PLANE_Y]); + CopyPlane(frame.yuv_planes[1], frame.yuv_strides[1], (frame.height + 1) / 2, + image->planes[AOM_PLANE_U], image->stride[AOM_PLANE_U]); + CopyPlane(frame.yuv_planes[2], frame.yuv_strides[2], (frame.height + 1) / 2, + image->planes[AOM_PLANE_V], image->stride[AOM_PLANE_V]); + + return image; +} + +} // namespace cast +} // namespace openscreen diff --git a/cast/standalone_sender/streaming_av1_encoder.h b/cast/standalone_sender/streaming_av1_encoder.h new file mode 100644 index 00000000..c40ab019 --- /dev/null +++ b/cast/standalone_sender/streaming_av1_encoder.h @@ -0,0 +1,169 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_SENDER_STREAMING_AV1_ENCODER_H_ +#define CAST_STANDALONE_SENDER_STREAMING_AV1_ENCODER_H_ + +#include <aom/aom_encoder.h> +#include <aom/aom_image.h> + +#include <algorithm> +#include <condition_variable> // NOLINT +#include <functional> +#include <memory> +#include <mutex> +#include <queue> +#include <thread> +#include <vector> + +#include "absl/base/thread_annotations.h" +#include "cast/standalone_sender/streaming_video_encoder.h" +#include "cast/streaming/constants.h" +#include "cast/streaming/frame_id.h" +#include "cast/streaming/rtp_time.h" +#include "platform/api/task_runner.h" +#include "platform/api/time.h" + +namespace openscreen { + +class TaskRunner; + +namespace cast { + +class Sender; + +// Uses libaom to encode AV1 video and streams it to a Sender. Includes +// extensive logic for fine-tuning the encoder parameters in real-time, to +// provide the best quality results given external, uncontrollable factors: +// CPU/network availability, and the complexity of the video frame content. +// +// Internally, a separate encode thread is created and used to prevent blocking +// the main thread while frames are being encoded. All public API methods are +// assumed to be called on the same sequence/thread as the main TaskRunner +// (injected via the constructor). +// +// Usage: +// +// 1. EncodeAndSend() is used to queue-up video frames for encoding and sending, +// which will be done on a best-effort basis. +// +// 2. The client is expected to call SetTargetBitrate() frequently based on its +// own bandwidth estimates and congestion control logic. In addition, a client +// may provide a callback for each frame's encode statistics, which can be used +// to further optimize the user experience. For example, the stats can be used +// as a signal to reduce the data volume (i.e., resolution and/or frame rate) +// coming from the video capture source. +class StreamingAv1Encoder : public StreamingVideoEncoder { + public: + StreamingAv1Encoder(const Parameters& params, + TaskRunner* task_runner, + Sender* sender); + + ~StreamingAv1Encoder(); + + int GetTargetBitrate() const override; + void SetTargetBitrate(int new_bitrate) override; + void EncodeAndSend(const VideoFrame& frame, + Clock::time_point reference_time, + std::function<void(Stats)> stats_callback) override; + + private: + // Syntactic convenience to wrap the aom_image_t alloc/free API in a smart + // pointer. + struct Av1ImageDeleter { + void operator()(aom_image_t* ptr) const { aom_img_free(ptr); } + }; + using Av1ImageUniquePtr = std::unique_ptr<aom_image_t, Av1ImageDeleter>; + + // Represents the state of one frame encode. This is created in + // EncodeAndSend(), and passed to the encode thread via the |encode_queue_|. + struct WorkUnit { + Av1ImageUniquePtr image; + Clock::duration duration; + Clock::time_point reference_time; + RtpTimeTicks rtp_timestamp; + std::function<void(Stats)> stats_callback; + }; + + // Same as WorkUnit, but with additional fields to carry the encode results. + struct WorkUnitWithResults : public WorkUnit { + std::vector<uint8_t> payload; + bool is_key_frame = false; + Stats stats; + }; + + bool is_encoder_initialized() const { return config_.g_threads != 0; } + + // Destroys the AV1 encoder context if it has been initialized. + void DestroyEncoder(); + + // The procedure for the |encode_thread_| that loops, processing work units + // from the |encode_queue_| by calling Encode() until it's time to end the + // thread. + void ProcessWorkUnitsUntilTimeToQuit(); + + // If the |encoder_| is live, attempt reconfiguration to allow it to encode + // frames at a new frame size or target bitrate. If reconfiguration is not + // possible, destroy the existing instance and re-create a new |encoder_| + // instance. + void PrepareEncoder(int width, int height, int target_bitrate); + + // Wraps the complex libaom aom_codec_encode() call using inputs from + // |work_unit| and populating results there. + void EncodeFrame(bool force_key_frame, WorkUnitWithResults& work_unit); + + // Computes and populates |work_unit.stats| after the last call to + // EncodeFrame(). + void ComputeFrameEncodeStats(Clock::duration encode_wall_time, + int target_bitrate, + WorkUnitWithResults& work_unit); + + // Assembles and enqueues an EncodedFrame with the Sender on the main thread. + void SendEncodedFrame(WorkUnitWithResults results); + + // Allocates a aom_image_t and copies the content from |frame| to it. + static Av1ImageUniquePtr CloneAsAv1Image(const VideoFrame& frame); + + // The reference time of the first frame passed to EncodeAndSend(). + Clock::time_point start_time_ = Clock::time_point::min(); + + // The RTP timestamp of the last frame that was pushed into the + // |encode_queue_| by EncodeAndSend(). This is used to check whether + // timestamps are monotonically increasing. + RtpTimeTicks last_enqueued_rtp_timestamp_; + + // Guards a few members shared by both the main and encode threads. + std::mutex mutex_; + + // Used by the encode thread to sleep until more work is available. + std::condition_variable cv_ ABSL_GUARDED_BY(mutex_); + + // These encode parameters not passed in the WorkUnit struct because it is + // desirable for them to be applied as soon as possible, with the very next + // WorkUnit popped from the |encode_queue_| on the encode thread, and not to + // wait until some later WorkUnit is processed. + bool needs_key_frame_ ABSL_GUARDED_BY(mutex_) = true; + int target_bitrate_ ABSL_GUARDED_BY(mutex_) = 2 << 20; // Default: 2 Mbps. + + // The queue of frame encodes. The size of this queue is implicitly bounded by + // EncodeAndSend(), where it checks for the total in-flight media duration and + // maybe drops a frame. + std::queue<WorkUnit> encode_queue_ ABSL_GUARDED_BY(mutex_); + + // Current AV1 encoder configuration. Most of the fields are unchanging, and + // are populated in the ctor; but thereafter, only the encode thread accesses + // this struct. + // + // The speed setting is controlled via a separate libaom API (see members + // below). + aom_codec_enc_cfg_t config_{}; + + // libaom AV1 encoder instance. Only the encode thread accesses this. + aom_codec_ctx_t encoder_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_SENDER_STREAMING_AV1_ENCODER_H_ diff --git a/cast/standalone_sender/streaming_encoder_util.cc b/cast/standalone_sender/streaming_encoder_util.cc new file mode 100644 index 00000000..9ead2bd9 --- /dev/null +++ b/cast/standalone_sender/streaming_encoder_util.cc @@ -0,0 +1,30 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_sender/streaming_encoder_util.h" + +#include <string.h> + +#include <algorithm> + +namespace openscreen { +namespace cast { +void CopyPlane(const uint8_t* src, + int src_stride, + int num_rows, + uint8_t* dst, + int dst_stride) { + if (src_stride == dst_stride) { + memcpy(dst, src, src_stride * num_rows); + return; + } + const int bytes_per_row = std::min(src_stride, dst_stride); + while (--num_rows >= 0) { + memcpy(dst, src, bytes_per_row); + dst += dst_stride; + src += src_stride; + } +} +} // namespace cast +} // namespace openscreen diff --git a/cast/standalone_sender/streaming_encoder_util.h b/cast/standalone_sender/streaming_encoder_util.h new file mode 100644 index 00000000..d4d00b42 --- /dev/null +++ b/cast/standalone_sender/streaming_encoder_util.h @@ -0,0 +1,20 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_SENDER_STREAMING_ENCODER_UTIL_H_ +#define CAST_STANDALONE_SENDER_STREAMING_ENCODER_UTIL_H_ + +#include <stdint.h> + +namespace openscreen { +namespace cast { +void CopyPlane(const uint8_t* src, + int src_stride, + int num_rows, + uint8_t* dst, + int dst_stride); +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_SENDER_STREAMING_ENCODER_UTIL_H_ diff --git a/cast/standalone_sender/streaming_video_encoder.cc b/cast/standalone_sender/streaming_video_encoder.cc new file mode 100644 index 00000000..0e15ab2c --- /dev/null +++ b/cast/standalone_sender/streaming_video_encoder.cc @@ -0,0 +1,57 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/standalone_sender/streaming_video_encoder.h" + +#include "util/chrono_helpers.h" + +namespace openscreen { +namespace cast { + +StreamingVideoEncoder::StreamingVideoEncoder(const Parameters& params, + TaskRunner* task_runner, + Sender* sender) + : params_(params), main_task_runner_(task_runner), sender_(sender) { + OSP_DCHECK_LE(1, params_.num_encode_threads); + OSP_DCHECK_LE(kMinQuantizer, params_.min_quantizer); + OSP_DCHECK_LE(params_.min_quantizer, params_.max_cpu_saver_quantizer); + OSP_DCHECK_LE(params_.max_cpu_saver_quantizer, params_.max_quantizer); + OSP_DCHECK_LE(params_.max_quantizer, kMaxQuantizer); + OSP_DCHECK_LT(0.0, params_.max_time_utilization); + OSP_DCHECK_LE(params_.max_time_utilization, 1.0); + OSP_DCHECK(main_task_runner_); + OSP_DCHECK(sender_); +} + +StreamingVideoEncoder::~StreamingVideoEncoder() {} + +void StreamingVideoEncoder::UpdateSpeedSettingForNextFrame(const Stats& stats) { + OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); + + // Combine the speed setting that was used to encode the last frame, and the + // quantizer the encoder chose into a single speed metric. + const double speed = current_speed_setting_ + + kEquivalentEncodingSpeedStepPerQuantizerStep * + std::max(0, stats.quantizer - params_.min_quantizer); + + // Like |Stats::perfect_quantizer|, this computes a "hindsight" speed setting + // for the last frame, one that may have potentially allowed for a + // better-quality quantizer choice by the encoder, while also keeping CPU + // utilization within budget. + const double perfect_speed = + speed * stats.time_utilization() / params_.max_time_utilization; + + // Update the ideal speed setting, to be used for the next frame. An + // exponentially-decaying weighted average is used here to smooth-out noise. + // The weight is based on the duration of the frame that was encoded. + constexpr Clock::duration kDecayHalfLife = milliseconds(120); + const double ticks = stats.frame_duration.count(); + const double weight = ticks / (ticks + kDecayHalfLife.count()); + ideal_speed_setting_ = + weight * perfect_speed + (1.0 - weight) * ideal_speed_setting_; + OSP_DCHECK(std::isfinite(ideal_speed_setting_)); +} + +} // namespace cast +} // namespace openscreen diff --git a/cast/standalone_sender/streaming_video_encoder.h b/cast/standalone_sender/streaming_video_encoder.h new file mode 100644 index 00000000..52fae9cc --- /dev/null +++ b/cast/standalone_sender/streaming_video_encoder.h @@ -0,0 +1,194 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_SENDER_STREAMING_VIDEO_ENCODER_H_ +#define CAST_STANDALONE_SENDER_STREAMING_VIDEO_ENCODER_H_ + +#include <algorithm> +#include <condition_variable> // NOLINT +#include <functional> +#include <memory> +#include <mutex> +#include <queue> +#include <thread> +#include <vector> + +#include "absl/base/thread_annotations.h" +#include "cast/streaming/constants.h" +#include "cast/streaming/frame_id.h" +#include "cast/streaming/rtp_time.h" +#include "platform/api/task_runner.h" +#include "platform/api/time.h" + +namespace openscreen { + +class TaskRunner; + +namespace cast { + +class Sender; + +class StreamingVideoEncoder { + public: + // Configurable parameters passed to the StreamingVpxEncoder constructor. + struct Parameters { + // Number of threads to parallelize frame encoding. This should be set based + // on the number of CPU cores available for encoding, but no more than 8. + int num_encode_threads = + std::min(std::max<int>(std::thread::hardware_concurrency(), 1), 8); + + // Best-quality quantizer (lower is better quality). Range: [0,63] + int min_quantizer = 4; + + // Worst-quality quantizer (lower is better quality). Range: [0,63] + int max_quantizer = kMaxQuantizer; + + // Worst-quality quantizer to use when the CPU is extremely constrained. + // Range: [min_quantizer,max_quantizer] + int max_cpu_saver_quantizer = 25; + + // Maximum amount of wall-time a frame's encode can take, relative to the + // frame's duration, before the CPU-saver logic is activated. The default + // (70%) is appropriate for systems with four or more cores, but should be + // reduced (e.g., 50%) for systems with fewer than three cores. + // + // Example: For 30 FPS (continuous) video, the frame duration is ~33.3ms, + // and a value of 0.5 here would mean that the CPU-saver logic starts + // sacrificing quality when frame encodes start taking longer than ~16.7ms. + double max_time_utilization = 0.7; + + // Determines which codec (VP8, VP9, or AV1) is to be used for encoding. + // Defaults to VP8. + VideoCodec codec = VideoCodec::kVp8; + }; + + // Represents an input VideoFrame, passed to EncodeAndSend(). + struct VideoFrame { + // Image width and height. + int width = 0; + int height = 0; + + // I420 format image pointers and row strides (the number of bytes between + // the start of successive rows). The pointers only need to remain valid + // until the EncodeAndSend() call returns. + const uint8_t* yuv_planes[3] = {}; + int yuv_strides[3] = {}; + + // How long this frame will be held before the next frame will be displayed, + // or zero if unknown. The frame duration is passed to the video codec, + // affecting a number of important behaviors, including: per-frame + // bandwidth, CPU time spent encoding, temporal quality trade-offs, and + // key/golden/alt-ref frame generation intervals. + Clock::duration duration; + }; + + // Performance statistics for a single frame's encode. + // + // For full details on how to use these stats in an end-to-end system, see: + // https://www.chromium.org/developers/design-documents/ + // auto-throttled-screen-capture-and-mirroring + // and https://source.chromium.org/chromium/chromium/src/+/master: + // media/cast/sender/performance_metrics_overlay.h + struct Stats { + // The Cast Streaming ID that was assigned to the frame. + FrameId frame_id; + + // The RTP timestamp of the frame. + RtpTimeTicks rtp_timestamp; + + // How long the frame took to encode. This is wall time, not CPU time or + // some other load metric. + Clock::duration encode_wall_time; + + // The frame's predicted duration; or, the actual duration if it was + // provided in the VideoFrame. + Clock::duration frame_duration; + + // The encoded frame's size in bytes. + int encoded_size = 0; + + // The average size of an encoded frame in bytes, having this + // |frame_duration| and current target bitrate. + double target_size = 0.0; + + // The actual quantizer the video encoder used, in the range [0,63]. + int quantizer = 0; + + // The "hindsight" quantizer value that would have produced the best quality + // encoding of the frame at the current target bitrate. The nominal range is + // [0.0,63.0]. If it is larger than 63.0, then it was impossible to + // encode the frame within the current target bitrate (e.g., too much + // "entropy" in the image, or too low a target bitrate). + double perfect_quantizer = 0.0; + + // Utilization feedback metrics. The nominal range for each of these is + // [0.0,1.0] where 1.0 means "the entire budget available for the frame was + // exhausted." Going above 1.0 is okay for one or a few frames, since it's + // the average over many frames that matters before the system is considered + // "redlining." + // + // The max of these three provides an overall utilization control signal. + // The usual approach is for upstream control logic to increase/decrease the + // data volume (e.g., video resolution and/or frame rate) to maintain a good + // target point. + double time_utilization() const { + return static_cast<double>(encode_wall_time.count()) / + frame_duration.count(); + } + double space_utilization() const { return encoded_size / target_size; } + double entropy_utilization() const { + return perfect_quantizer / kMaxQuantizer; + } + }; + + virtual ~StreamingVideoEncoder(); + + // Get/Set the target bitrate. This may be changed at any time, as frequently + // as desired, and it will take effect internally as soon as possible. + virtual int GetTargetBitrate() const = 0; + virtual void SetTargetBitrate(int new_bitrate) = 0; + + // Encode |frame| using the video encoder, assemble an EncodedFrame, and + // enqueue into the Sender. The frame may be dropped if too many frames are + // in-flight. If provided, the |stats_callback| is run after the frame is + // enqueued in the Sender (via the main TaskRunner). + virtual void EncodeAndSend(const VideoFrame& frame, + Clock::time_point reference_time, + std::function<void(Stats)> stats_callback) = 0; + + static constexpr int kMinQuantizer = 0; + static constexpr int kMaxQuantizer = 63; + + protected: + StreamingVideoEncoder(const Parameters& params, + TaskRunner* task_runner, + Sender* sender); + + // This is the equivalent change in encoding speed per one quantizer step. + static constexpr double kEquivalentEncodingSpeedStepPerQuantizerStep = + 1 / 20.0; + + // Updates the |ideal_speed_setting_|, to take effect with the next frame + // encode, based on the given performance |stats|. + void UpdateSpeedSettingForNextFrame(const Stats& stats); + + const Parameters params_; + TaskRunner* const main_task_runner_; + Sender* const sender_; + + // These represent the magnitude of the AV1 speed setting, where larger values + // (i.e., faster speed) request less CPU usage but will provide lower video + // quality. Only the encode thread accesses these. + double ideal_speed_setting_; // A time-weighted average, from measurements. + int current_speed_setting_; // Current |encoder_| speed setting. + + // This member should be last in the class since the thread should not start + // until all above members have been initialized by the constructor. + std::thread encode_thread_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_SENDER_STREAMING_VIDEO_ENCODER_H_ diff --git a/cast/standalone_sender/streaming_vp8_encoder.h b/cast/standalone_sender/streaming_vp8_encoder.h deleted file mode 100644 index c5d52248..00000000 --- a/cast/standalone_sender/streaming_vp8_encoder.h +++ /dev/null @@ -1,302 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef CAST_STANDALONE_SENDER_STREAMING_VP8_ENCODER_H_ -#define CAST_STANDALONE_SENDER_STREAMING_VP8_ENCODER_H_ - -#include <vpx/vpx_encoder.h> -#include <vpx/vpx_image.h> - -#include <algorithm> -#include <condition_variable> // NOLINT -#include <functional> -#include <memory> -#include <mutex> -#include <queue> -#include <thread> -#include <vector> - -#include "absl/base/thread_annotations.h" -#include "cast/streaming/frame_id.h" -#include "cast/streaming/rtp_time.h" -#include "platform/api/task_runner.h" -#include "platform/api/time.h" - -namespace openscreen { - -class TaskRunner; - -namespace cast { - -class Sender; - -// Uses libvpx to encode VP8 video and streams it to a Sender. Includes -// extensive logic for fine-tuning the encoder parameters in real-time, to -// provide the best quality results given external, uncontrollable factors: -// CPU/network availability, and the complexity of the video frame content. -// -// Internally, a separate encode thread is created and used to prevent blocking -// the main thread while frames are being encoded. All public API methods are -// assumed to be called on the same sequence/thread as the main TaskRunner -// (injected via the constructor). -// -// Usage: -// -// 1. EncodeAndSend() is used to queue-up video frames for encoding and sending, -// which will be done on a best-effort basis. -// -// 2. The client is expected to call SetTargetBitrate() frequently based on its -// own bandwidth estimates and congestion control logic. In addition, a client -// may provide a callback for each frame's encode statistics, which can be used -// to further optimize the user experience. For example, the stats can be used -// as a signal to reduce the data volume (i.e., resolution and/or frame rate) -// coming from the video capture source. -class StreamingVp8Encoder { - public: - // Configurable parameters passed to the StreamingVp8Encoder constructor. - struct Parameters { - // Number of threads to parallelize frame encoding. This should be set based - // on the number of CPU cores available for encoding, but no more than 8. - int num_encode_threads = - std::min(std::max<int>(std::thread::hardware_concurrency(), 1), 8); - - // Best-quality quantizer (lower is better quality). Range: [0,63] - int min_quantizer = 4; - - // Worst-quality quantizer (lower is better quality). Range: [0,63] - int max_quantizer = 63; - - // Worst-quality quantizer to use when the CPU is extremely constrained. - // Range: [min_quantizer,max_quantizer] - int max_cpu_saver_quantizer = 25; - - // Maximum amount of wall-time a frame's encode can take, relative to the - // frame's duration, before the CPU-saver logic is activated. The default - // (70%) is appropriate for systems with four or more cores, but should be - // reduced (e.g., 50%) for systems with fewer than three cores. - // - // Example: For 30 FPS (continuous) video, the frame duration is ~33.3ms, - // and a value of 0.5 here would mean that the CPU-saver logic starts - // sacrificing quality when frame encodes start taking longer than ~16.7ms. - double max_time_utilization = 0.7; - }; - - // Represents an input VideoFrame, passed to EncodeAndSend(). - struct VideoFrame { - // Image width and height. - int width; - int height; - - // I420 format image pointers and row strides (the number of bytes between - // the start of successive rows). The pointers only need to remain valid - // until the EncodeAndSend() call returns. - const uint8_t* yuv_planes[3]; - int yuv_strides[3]; - - // How long this frame will be held before the next frame will be displayed, - // or zero if unknown. The frame duration is passed to the VP8 codec, - // affecting a number of important behaviors, including: per-frame - // bandwidth, CPU time spent encoding, temporal quality trade-offs, and - // key/golden/alt-ref frame generation intervals. - Clock::duration duration; - }; - - // Performance statistics for a single frame's encode. - // - // For full details on how to use these stats in an end-to-end system, see: - // https://www.chromium.org/developers/design-documents/ - // auto-throttled-screen-capture-and-mirroring - // and https://source.chromium.org/chromium/chromium/src/+/master: - // media/cast/sender/performance_metrics_overlay.h - struct Stats { - // The Cast Streaming ID that was assigned to the frame. - FrameId frame_id; - - // The RTP timestamp of the frame. - RtpTimeTicks rtp_timestamp; - - // How long the frame took to encode. This is wall time, not CPU time or - // some other load metric. - Clock::duration encode_wall_time; - - // The frame's predicted duration; or, the actual duration if it was - // provided in the VideoFrame. - Clock::duration frame_duration; - - // The encoded frame's size in bytes. - int encoded_size; - - // The average size of an encoded frame in bytes, having this - // |frame_duration| and current target bitrate. - double target_size; - - // The actual quantizer the VP8 encoder used, in the range [0,63]. - int quantizer; - - // The "hindsight" quantizer value that would have produced the best quality - // encoding of the frame at the current target bitrate. The nominal range is - // [0.0,63.0]. If it is larger than 63.0, then it was impossible for VP8 to - // encode the frame within the current target bitrate (e.g., too much - // "entropy" in the image, or too low a target bitrate). - double perfect_quantizer; - - // Utilization feedback metrics. The nominal range for each of these is - // [0.0,1.0] where 1.0 means "the entire budget available for the frame was - // exhausted." Going above 1.0 is okay for one or a few frames, since it's - // the average over many frames that matters before the system is considered - // "redlining." - // - // The max of these three provides an overall utilization control signal. - // The usual approach is for upstream control logic to increase/decrease the - // data volume (e.g., video resolution and/or frame rate) to maintain a good - // target point. - double time_utilization() const { - return static_cast<double>(encode_wall_time.count()) / - frame_duration.count(); - } - double space_utilization() const { return encoded_size / target_size; } - double entropy_utilization() const { - return perfect_quantizer / kMaxQuantizer; - } - }; - - StreamingVp8Encoder(const Parameters& params, - TaskRunner* task_runner, - Sender* sender); - - ~StreamingVp8Encoder(); - - // Get/Set the target bitrate. This may be changed at any time, as frequently - // as desired, and it will take effect internally as soon as possible. - int GetTargetBitrate() const; - void SetTargetBitrate(int new_bitrate); - - // Encode |frame| using the VP8 encoder, assemble an EncodedFrame, and enqueue - // into the Sender. The frame may be dropped if too many frames are in-flight. - // If provided, the |stats_callback| is run after the frame is enqueued in the - // Sender (via the main TaskRunner). - void EncodeAndSend(const VideoFrame& frame, - Clock::time_point reference_time, - std::function<void(Stats)> stats_callback); - - static constexpr int kMinQuantizer = 0; - static constexpr int kMaxQuantizer = 63; - - private: - // Syntactic convenience to wrap the vpx_image_t alloc/free API in a smart - // pointer. - struct VpxImageDeleter { - void operator()(vpx_image_t* ptr) const { vpx_img_free(ptr); } - }; - using VpxImageUniquePtr = std::unique_ptr<vpx_image_t, VpxImageDeleter>; - - // Represents the state of one frame encode. This is created in - // EncodeAndSend(), and passed to the encode thread via the |encode_queue_|. - struct WorkUnit { - VpxImageUniquePtr image; - Clock::duration duration; - Clock::time_point reference_time; - RtpTimeTicks rtp_timestamp; - std::function<void(Stats)> stats_callback; - }; - - // Same as WorkUnit, but with additional fields to carry the encode results. - struct WorkUnitWithResults : public WorkUnit { - std::vector<uint8_t> payload; - bool is_key_frame; - Stats stats; - }; - - bool is_encoder_initialized() const { return config_.g_threads != 0; } - - // Destroys the VP8 encoder context if it has been initialized. - void DestroyEncoder(); - - // The procedure for the |encode_thread_| that loops, processing work units - // from the |encode_queue_| by calling Encode() until it's time to end the - // thread. - void ProcessWorkUnitsUntilTimeToQuit(); - - // If the |encoder_| is live, attempt reconfiguration to allow it to encode - // frames at a new frame size, target bitrate, or "CPU encoding speed." If - // reconfiguration is not possible, destroy the existing instance and - // re-create a new |encoder_| instance. - void PrepareEncoder(int width, int height, int target_bitrate); - - // Wraps the complex libvpx vpx_codec_encode() call using inputs from - // |work_unit| and populating results there. - void EncodeFrame(bool force_key_frame, WorkUnitWithResults* work_unit); - - // Computes and populates |work_unit.stats| after the last call to - // EncodeFrame(). - void ComputeFrameEncodeStats(Clock::duration encode_wall_time, - int target_bitrate, - WorkUnitWithResults* work_unit); - - // Updates the |ideal_speed_setting_|, to take effect with the next frame - // encode, based on the given performance |stats|. - void UpdateSpeedSettingForNextFrame(const Stats& stats); - - // Assembles and enqueues an EncodedFrame with the Sender on the main thread. - void SendEncodedFrame(WorkUnitWithResults results); - - // Allocates a vpx_image_t and copies the content from |frame| to it. - static VpxImageUniquePtr CloneAsVpxImage(const VideoFrame& frame); - - const Parameters params_; - TaskRunner* const main_task_runner_; - Sender* const sender_; - - // The reference time of the first frame passed to EncodeAndSend(). - Clock::time_point start_time_ = Clock::time_point::min(); - - // The RTP timestamp of the last frame that was pushed into the - // |encode_queue_| by EncodeAndSend(). This is used to check whether - // timestamps are monotonically increasing. - RtpTimeTicks last_enqueued_rtp_timestamp_; - - // Guards a few members shared by both the main and encode threads. - std::mutex mutex_; - - // Used by the encode thread to sleep until more work is available. - std::condition_variable cv_ ABSL_GUARDED_BY(mutex_); - - // These encode parameters not passed in the WorkUnit struct because it is - // desirable for them to be applied as soon as possible, with the very next - // WorkUnit popped from the |encode_queue_| on the encode thread, and not to - // wait until some later WorkUnit is processed. - bool needs_key_frame_ ABSL_GUARDED_BY(mutex_) = true; - int target_bitrate_ ABSL_GUARDED_BY(mutex_) = 2 << 20; // Default: 2 Mbps. - - // The queue of frame encodes. The size of this queue is implicitly bounded by - // EncodeAndSend(), where it checks for the total in-flight media duration and - // maybe drops a frame. - std::queue<WorkUnit> encode_queue_ ABSL_GUARDED_BY(mutex_); - - // Current VP8 encoder configuration. Most of the fields are unchanging, and - // are populated in the ctor; but thereafter, only the encode thread accesses - // this struct. - // - // The speed setting is controlled via a separate libvpx API (see members - // below). - vpx_codec_enc_cfg_t config_{}; - - // These represent the magnitude of the VP8 speed setting, where larger values - // (i.e., faster speed) request less CPU usage but will provide lower video - // quality. Only the encode thread accesses these. - double ideal_speed_setting_; // A time-weighted average, from measurements. - int current_speed_setting_; // Current |encoder_| speed setting. - - // libvpx VP8 encoder instance. Only the encode thread accesses this. - vpx_codec_ctx_t encoder_; - - // This member should be last in the class since the thread should not start - // until all above members have been initialized by the constructor. - std::thread encode_thread_; -}; - -} // namespace cast -} // namespace openscreen - -#endif // CAST_STANDALONE_SENDER_STREAMING_VP8_ENCODER_H_ diff --git a/cast/standalone_sender/streaming_vp8_encoder.cc b/cast/standalone_sender/streaming_vpx_encoder.cc index 8b8e18dc..1b10f92b 100644 --- a/cast/standalone_sender/streaming_vp8_encoder.cc +++ b/cast/standalone_sender/streaming_vpx_encoder.cc @@ -2,16 +2,15 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "cast/standalone_sender/streaming_vp8_encoder.h" +#include "cast/standalone_sender/streaming_vpx_encoder.h" -#include <stdint.h> -#include <string.h> #include <vpx/vp8cx.h> #include <chrono> #include <cmath> #include <utility> +#include "cast/standalone_sender/streaming_encoder_util.h" #include "cast/streaming/encoded_frame.h" #include "cast/streaming/environment.h" #include "cast/streaming/sender.h" @@ -22,8 +21,8 @@ namespace openscreen { namespace cast { -// TODO(https://crbug.com/openscreen/123): Fix the declarations and then remove -// this: +// TODO(issuetracker.google.com/issues/155336511): Fix the declarations and then +// remove this: using openscreen::operator<<; // For std::chrono::duration pretty-printing. namespace { @@ -44,31 +43,24 @@ constexpr Clock::duration kMaxFrameDuration = milliseconds(125); constexpr int kHighestEncodingSpeed = 12; constexpr int kLowestEncodingSpeed = 6; -// This is the equivalent change in encoding speed per one quantizer step. -constexpr double kEquivalentEncodingSpeedStepPerQuantizerStep = 1 / 20.0; - } // namespace -StreamingVp8Encoder::StreamingVp8Encoder(const Parameters& params, +StreamingVpxEncoder::StreamingVpxEncoder(const Parameters& params, TaskRunner* task_runner, Sender* sender) - : params_(params), - main_task_runner_(task_runner), - sender_(sender), - ideal_speed_setting_(kHighestEncodingSpeed), - encode_thread_([this] { ProcessWorkUnitsUntilTimeToQuit(); }) { - OSP_DCHECK_LE(1, params_.num_encode_threads); - OSP_DCHECK_LE(kMinQuantizer, params_.min_quantizer); - OSP_DCHECK_LE(params_.min_quantizer, params_.max_cpu_saver_quantizer); - OSP_DCHECK_LE(params_.max_cpu_saver_quantizer, params_.max_quantizer); - OSP_DCHECK_LE(params_.max_quantizer, kMaxQuantizer); - OSP_DCHECK_LT(0.0, params_.max_time_utilization); - OSP_DCHECK_LE(params_.max_time_utilization, 1.0); - OSP_DCHECK(main_task_runner_); - OSP_DCHECK(sender_); - - const auto result = - vpx_codec_enc_config_default(vpx_codec_vp8_cx(), &config_, 0); + : StreamingVideoEncoder(params, task_runner, sender) { + ideal_speed_setting_ = kHighestEncodingSpeed; + encode_thread_ = std::thread([this] { ProcessWorkUnitsUntilTimeToQuit(); }); + + vpx_codec_iface_t* ctx; + if (params_.codec == VideoCodec::kVp9) { + ctx = vpx_codec_vp9_cx(); + } else { + OSP_DCHECK(params_.codec == VideoCodec::kVp8); + ctx = vpx_codec_vp8_cx(); + } + + const auto result = vpx_codec_enc_config_default(ctx, &config_, 0); OSP_CHECK_EQ(result, VPX_CODEC_OK); // This is set to non-zero in ConfigureForNewFrameSize() later, to flag that @@ -104,7 +96,7 @@ StreamingVp8Encoder::StreamingVp8Encoder(const Parameters& params, config_.kf_mode = VPX_KF_DISABLED; } -StreamingVp8Encoder::~StreamingVp8Encoder() { +StreamingVpxEncoder::~StreamingVpxEncoder() { { std::unique_lock<std::mutex> lock(mutex_); target_bitrate_ = 0; @@ -113,13 +105,13 @@ StreamingVp8Encoder::~StreamingVp8Encoder() { encode_thread_.join(); } -int StreamingVp8Encoder::GetTargetBitrate() const { +int StreamingVpxEncoder::GetTargetBitrate() const { // Note: No need to lock the |mutex_| since this method should be called on // the same thread as SetTargetBitrate(). return target_bitrate_; } -void StreamingVp8Encoder::SetTargetBitrate(int new_bitrate) { +void StreamingVpxEncoder::SetTargetBitrate(int new_bitrate) { // Ensure that, when bps is converted to kbps downstream, that the encoder // bitrate will not be zero. new_bitrate = std::max(new_bitrate, kBytesPerKilobyte); @@ -132,13 +124,13 @@ void StreamingVp8Encoder::SetTargetBitrate(int new_bitrate) { } } -void StreamingVp8Encoder::EncodeAndSend( +void StreamingVpxEncoder::EncodeAndSend( const VideoFrame& frame, Clock::time_point reference_time, std::function<void(Stats)> stats_callback) { WorkUnit work_unit; - // TODO(miu): The |VideoFrame| struct should provide the media timestamp, + // TODO(jophba): The |VideoFrame| struct should provide the media timestamp, // instead of this code inferring it from the reference timestamps, since: 1) // the video capturer's clock may tick at a different rate than the system // clock; and 2) to reduce jitter. @@ -194,7 +186,7 @@ void StreamingVp8Encoder::EncodeAndSend( } } -void StreamingVp8Encoder::DestroyEncoder() { +void StreamingVpxEncoder::DestroyEncoder() { OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); if (is_encoder_initialized()) { @@ -205,7 +197,7 @@ void StreamingVp8Encoder::DestroyEncoder() { } } -void StreamingVp8Encoder::ProcessWorkUnitsUntilTimeToQuit() { +void StreamingVpxEncoder::ProcessWorkUnitsUntilTimeToQuit() { OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); for (;;) { @@ -235,9 +227,9 @@ void StreamingVp8Encoder::ProcessWorkUnitsUntilTimeToQuit() { // measured. const Clock::time_point encode_start_time = Clock::now(); PrepareEncoder(work_unit.image->d_w, work_unit.image->d_h, target_bitrate); - EncodeFrame(force_key_frame, &work_unit); + EncodeFrame(force_key_frame, work_unit); ComputeFrameEncodeStats(Clock::now() - encode_start_time, target_bitrate, - &work_unit); + work_unit); UpdateSpeedSettingForNextFrame(work_unit.stats); main_task_runner_->PostTask( @@ -249,7 +241,7 @@ void StreamingVp8Encoder::ProcessWorkUnitsUntilTimeToQuit() { DestroyEncoder(); } -void StreamingVp8Encoder::PrepareEncoder(int width, +void StreamingVpxEncoder::PrepareEncoder(int width, int height, int target_bitrate) { OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); @@ -287,8 +279,17 @@ void StreamingVp8Encoder::PrepareEncoder(int width, encoder_ = {}; const vpx_codec_flags_t flags = 0; + + vpx_codec_iface_t* ctx; + if (params_.codec == VideoCodec::kVp9) { + ctx = vpx_codec_vp9_cx(); + } else { + OSP_DCHECK(params_.codec == VideoCodec::kVp8); + ctx = vpx_codec_vp8_cx(); + } + const auto init_result = - vpx_codec_enc_init(&encoder_, vpx_codec_vp8_cx(), &config_, flags); + vpx_codec_enc_init(&encoder_, ctx, &config_, flags); OSP_CHECK_EQ(init_result, VPX_CODEC_OK); // Raise the threshold for considering macroblocks as static. The default is @@ -311,7 +312,7 @@ void StreamingVp8Encoder::PrepareEncoder(int width, } if (current_speed_setting_ != speed) { - // Pass the |speed| as a negative value to turn off VP8's automatic speed + // Pass the |speed| as a negative value to turn off VP8/9's automatic speed // selection logic and force the exact setting. const auto ctl_result = vpx_codec_control(&encoder_, VP8E_SET_CPUUSED, -speed); @@ -320,8 +321,8 @@ void StreamingVp8Encoder::PrepareEncoder(int width, } } -void StreamingVp8Encoder::EncodeFrame(bool force_key_frame, - WorkUnitWithResults* work_unit) { +void StreamingVpxEncoder::EncodeFrame(bool force_key_frame, + WorkUnitWithResults& work_unit) { OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); // The presentation timestamp argument here is fixed to zero to force the @@ -330,8 +331,8 @@ void StreamingVp8Encoder::EncodeFrame(bool force_key_frame, const vpx_codec_pts_t pts = 0; const vpx_enc_frame_flags_t flags = force_key_frame ? VPX_EFLAG_FORCE_KF : 0; const auto encode_result = - vpx_codec_encode(&encoder_, work_unit->image.get(), pts, - work_unit->duration.count(), flags, VPX_DL_REALTIME); + vpx_codec_encode(&encoder_, work_unit.image.get(), pts, + work_unit.duration.count(), flags, VPX_DL_REALTIME); OSP_CHECK_EQ(encode_result, VPX_CODEC_OK); const vpx_codec_cx_pkt_t* pkt; @@ -350,32 +351,32 @@ void StreamingVp8Encoder::EncodeFrame(bool force_key_frame, // be copied at some point anyway, to be passed back to the main thread. auto* const begin = static_cast<const uint8_t*>(pkt->data.frame.buf); auto* const end = begin + pkt->data.frame.sz; - work_unit->payload.assign(begin, end); - work_unit->is_key_frame = !!(pkt->data.frame.flags & VPX_FRAME_IS_KEY); + work_unit.payload.assign(begin, end); + work_unit.is_key_frame = !!(pkt->data.frame.flags & VPX_FRAME_IS_KEY); } -void StreamingVp8Encoder::ComputeFrameEncodeStats( +void StreamingVpxEncoder::ComputeFrameEncodeStats( Clock::duration encode_wall_time, int target_bitrate, - WorkUnitWithResults* work_unit) { + WorkUnitWithResults& work_unit) { OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); - Stats& stats = work_unit->stats; + Stats& stats = work_unit.stats; // Note: stats.frame_id is set later, in SendEncodedFrame(). - stats.rtp_timestamp = work_unit->rtp_timestamp; + stats.rtp_timestamp = work_unit.rtp_timestamp; stats.encode_wall_time = encode_wall_time; - stats.frame_duration = work_unit->duration; - stats.encoded_size = work_unit->payload.size(); + stats.frame_duration = work_unit.duration; + stats.encoded_size = work_unit.payload.size(); constexpr double kBytesPerBit = 1.0 / CHAR_BIT; constexpr double kSecondsPerClockTick = 1.0 / Clock::to_duration(seconds(1)).count(); const double target_bytes_per_clock_tick = target_bitrate * (kBytesPerBit * kSecondsPerClockTick); - stats.target_size = target_bytes_per_clock_tick * work_unit->duration.count(); + stats.target_size = target_bytes_per_clock_tick * work_unit.duration.count(); - // The quantizer the encoder used. This is the result of the VP8 encoder + // The quantizer the encoder used. This is the result of the VP8/9 encoder // taking a guess at what quantizer value would produce an encoded frame size // as close to the target as possible. const auto get_quantizer_result = vpx_codec_control( @@ -388,34 +389,7 @@ void StreamingVp8Encoder::ComputeFrameEncodeStats( stats.perfect_quantizer = stats.quantizer * stats.space_utilization(); } -void StreamingVp8Encoder::UpdateSpeedSettingForNextFrame(const Stats& stats) { - OSP_DCHECK_EQ(std::this_thread::get_id(), encode_thread_.get_id()); - - // Combine the speed setting that was used to encode the last frame, and the - // quantizer the encoder chose into a single speed metric. - const double speed = current_speed_setting_ + - kEquivalentEncodingSpeedStepPerQuantizerStep * - std::max(0, stats.quantizer - params_.min_quantizer); - - // Like |Stats::perfect_quantizer|, this computes a "hindsight" speed setting - // for the last frame, one that may have potentially allowed for a - // better-quality quantizer choice by the encoder, while also keeping CPU - // utilization within budget. - const double perfect_speed = - speed * stats.time_utilization() / params_.max_time_utilization; - - // Update the ideal speed setting, to be used for the next frame. An - // exponentially-decaying weighted average is used here to smooth-out noise. - // The weight is based on the duration of the frame that was encoded. - constexpr Clock::duration kDecayHalfLife = milliseconds(120); - const double ticks = stats.frame_duration.count(); - const double weight = ticks / (ticks + kDecayHalfLife.count()); - ideal_speed_setting_ = - weight * perfect_speed + (1.0 - weight) * ideal_speed_setting_; - OSP_DCHECK(std::isfinite(ideal_speed_setting_)); -} - -void StreamingVp8Encoder::SendEncodedFrame(WorkUnitWithResults results) { +void StreamingVpxEncoder::SendEncodedFrame(WorkUnitWithResults results) { OSP_DCHECK(main_task_runner_->IsRunningOnTaskRunner()); EncodedFrame frame; @@ -444,27 +418,8 @@ void StreamingVp8Encoder::SendEncodedFrame(WorkUnitWithResults results) { } } -namespace { -void CopyPlane(const uint8_t* src, - int src_stride, - int num_rows, - uint8_t* dst, - int dst_stride) { - if (src_stride == dst_stride) { - memcpy(dst, src, src_stride * num_rows); - return; - } - const int bytes_per_row = std::min(src_stride, dst_stride); - while (--num_rows >= 0) { - memcpy(dst, src, bytes_per_row); - dst += dst_stride; - src += src_stride; - } -} -} // namespace - // static -StreamingVp8Encoder::VpxImageUniquePtr StreamingVp8Encoder::CloneAsVpxImage( +StreamingVpxEncoder::VpxImageUniquePtr StreamingVpxEncoder::CloneAsVpxImage( const VideoFrame& frame) { OSP_DCHECK_GE(frame.width, 0); OSP_DCHECK_GE(frame.height, 0); diff --git a/cast/standalone_sender/streaming_vpx_encoder.h b/cast/standalone_sender/streaming_vpx_encoder.h new file mode 100644 index 00000000..5c99309e --- /dev/null +++ b/cast/standalone_sender/streaming_vpx_encoder.h @@ -0,0 +1,169 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STANDALONE_SENDER_STREAMING_VPX_ENCODER_H_ +#define CAST_STANDALONE_SENDER_STREAMING_VPX_ENCODER_H_ + +#include <vpx/vpx_encoder.h> +#include <vpx/vpx_image.h> + +#include <algorithm> +#include <condition_variable> // NOLINT +#include <functional> +#include <memory> +#include <mutex> +#include <queue> +#include <thread> +#include <vector> + +#include "absl/base/thread_annotations.h" +#include "cast/standalone_sender/streaming_video_encoder.h" +#include "cast/streaming/constants.h" +#include "cast/streaming/frame_id.h" +#include "cast/streaming/rtp_time.h" +#include "platform/api/task_runner.h" +#include "platform/api/time.h" + +namespace openscreen { + +class TaskRunner; + +namespace cast { + +class Sender; + +// Uses libvpx to encode VP8/9 video and streams it to a Sender. Includes +// extensive logic for fine-tuning the encoder parameters in real-time, to +// provide the best quality results given external, uncontrollable factors: +// CPU/network availability, and the complexity of the video frame content. +// +// Internally, a separate encode thread is created and used to prevent blocking +// the main thread while frames are being encoded. All public API methods are +// assumed to be called on the same sequence/thread as the main TaskRunner +// (injected via the constructor). +// +// Usage: +// +// 1. EncodeAndSend() is used to queue-up video frames for encoding and sending, +// which will be done on a best-effort basis. +// +// 2. The client is expected to call SetTargetBitrate() frequently based on its +// own bandwidth estimates and congestion control logic. In addition, a client +// may provide a callback for each frame's encode statistics, which can be used +// to further optimize the user experience. For example, the stats can be used +// as a signal to reduce the data volume (i.e., resolution and/or frame rate) +// coming from the video capture source. +class StreamingVpxEncoder : public StreamingVideoEncoder { + public: + StreamingVpxEncoder(const Parameters& params, + TaskRunner* task_runner, + Sender* sender); + + ~StreamingVpxEncoder(); + + int GetTargetBitrate() const override; + void SetTargetBitrate(int new_bitrate) override; + void EncodeAndSend(const VideoFrame& frame, + Clock::time_point reference_time, + std::function<void(Stats)> stats_callback) override; + + private: + // Syntactic convenience to wrap the vpx_image_t alloc/free API in a smart + // pointer. + struct VpxImageDeleter { + void operator()(vpx_image_t* ptr) const { vpx_img_free(ptr); } + }; + using VpxImageUniquePtr = std::unique_ptr<vpx_image_t, VpxImageDeleter>; + + // Represents the state of one frame encode. This is created in + // EncodeAndSend(), and passed to the encode thread via the |encode_queue_|. + struct WorkUnit { + VpxImageUniquePtr image; + Clock::duration duration; + Clock::time_point reference_time; + RtpTimeTicks rtp_timestamp; + std::function<void(Stats)> stats_callback; + }; + + // Same as WorkUnit, but with additional fields to carry the encode results. + struct WorkUnitWithResults : public WorkUnit { + std::vector<uint8_t> payload; + bool is_key_frame = false; + Stats stats; + }; + + bool is_encoder_initialized() const { return config_.g_threads != 0; } + + // Destroys the VP8 encoder context if it has been initialized. + void DestroyEncoder(); + + // The procedure for the |encode_thread_| that loops, processing work units + // from the |encode_queue_| by calling Encode() until it's time to end the + // thread. + void ProcessWorkUnitsUntilTimeToQuit(); + + // If the |encoder_| is live, attempt reconfiguration to allow it to encode + // frames at a new frame size or target bitrate. If reconfiguration is not + // possible, destroy the existing instance and re-create a new |encoder_| + // instance. + void PrepareEncoder(int width, int height, int target_bitrate); + + // Wraps the complex libvpx vpx_codec_encode() call using inputs from + // |work_unit| and populating results there. + void EncodeFrame(bool force_key_frame, WorkUnitWithResults& work_unit); + + // Computes and populates |work_unit.stats| after the last call to + // EncodeFrame(). + void ComputeFrameEncodeStats(Clock::duration encode_wall_time, + int target_bitrate, + WorkUnitWithResults& work_unit); + + // Assembles and enqueues an EncodedFrame with the Sender on the main thread. + void SendEncodedFrame(WorkUnitWithResults results); + + // Allocates a vpx_image_t and copies the content from |frame| to it. + static VpxImageUniquePtr CloneAsVpxImage(const VideoFrame& frame); + + // The reference time of the first frame passed to EncodeAndSend(). + Clock::time_point start_time_ = Clock::time_point::min(); + + // The RTP timestamp of the last frame that was pushed into the + // |encode_queue_| by EncodeAndSend(). This is used to check whether + // timestamps are monotonically increasing. + RtpTimeTicks last_enqueued_rtp_timestamp_; + + // Guards a few members shared by both the main and encode threads. + std::mutex mutex_; + + // Used by the encode thread to sleep until more work is available. + std::condition_variable cv_ ABSL_GUARDED_BY(mutex_); + + // These encode parameters not passed in the WorkUnit struct because it is + // desirable for them to be applied as soon as possible, with the very next + // WorkUnit popped from the |encode_queue_| on the encode thread, and not to + // wait until some later WorkUnit is processed. + bool needs_key_frame_ ABSL_GUARDED_BY(mutex_) = true; + int target_bitrate_ ABSL_GUARDED_BY(mutex_) = 2 << 20; // Default: 2 Mbps. + + // The queue of frame encodes. The size of this queue is implicitly bounded by + // EncodeAndSend(), where it checks for the total in-flight media duration and + // maybe drops a frame. + std::queue<WorkUnit> encode_queue_ ABSL_GUARDED_BY(mutex_); + + // Current VP8 encoder configuration. Most of the fields are unchanging, and + // are populated in the ctor; but thereafter, only the encode thread accesses + // this struct. + // + // The speed setting is controlled via a separate libvpx API (see members + // below). + vpx_codec_enc_cfg_t config_{}; + + // libvpx VP8/9 encoder instance. Only the encode thread accesses this. + vpx_codec_ctx_t encoder_; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STANDALONE_SENDER_STREAMING_VPX_ENCODER_H_ diff --git a/cast/streaming/BUILD.gn b/cast/streaming/BUILD.gn index 04424cf4..985cd1e2 100644 --- a/cast/streaming/BUILD.gn +++ b/cast/streaming/BUILD.gn @@ -11,16 +11,37 @@ fuzzable_proto_library("remoting_proto") { sources = [ "remoting.proto" ] } +source_set("streaming_configs") { + sources = [ + "capture_configs.h", + "constants.h", + "message_fields.cc", + "message_fields.h", + "resolution.cc", + "resolution.h", + ] + + public_configs = [ "../../build:openscreen_include_dirs" ] + + public_deps = [ + "../../third_party/abseil", + "../../third_party/jsoncpp", + ] + + deps = [ + "../../platform:base", + "../../util:base", + ] +} + source_set("common") { sources = [ "answer_messages.cc", "answer_messages.h", - "capture_configs.h", "capture_recommendations.cc", "capture_recommendations.h", "clock_drift_smoother.cc", "clock_drift_smoother.h", - "constants.h", "encoded_frame.cc", "encoded_frame.h", "environment.cc", @@ -30,8 +51,6 @@ source_set("common") { "frame_crypto.h", "frame_id.cc", "frame_id.h", - "message_fields.cc", - "message_fields.h", "ntp_time.cc", "ntp_time.h", "offer_messages.cc", @@ -40,8 +59,8 @@ source_set("common") { "packet_util.h", "receiver_message.cc", "receiver_message.h", - "rpc_broker.cc", - "rpc_broker.h", + "rpc_messenger.cc", + "rpc_messenger.h", "rtcp_common.cc", "rtcp_common.h", "rtcp_session.cc", @@ -54,8 +73,8 @@ source_set("common") { "sender_message.h", "session_config.cc", "session_config.h", - "session_messager.cc", - "session_messager.h", + "session_messenger.cc", + "session_messenger.h", "ssrc.cc", "ssrc.h", ] @@ -64,6 +83,7 @@ source_set("common") { public_deps = [ ":remoting_proto", + ":streaming_configs", "../../third_party/abseil", "../../third_party/boringssl", "../common:channel", @@ -91,6 +111,8 @@ source_set("receiver") { "packet_receive_stats_tracker.h", "receiver.cc", "receiver.h", + "receiver_base.cc", + "receiver_base.h", "receiver_packet_router.cc", "receiver_packet_router.h", "receiver_session.cc", @@ -112,6 +134,7 @@ source_set("sender") { "bandwidth_estimator.h", "compound_rtcp_parser.cc", "compound_rtcp_parser.h", + "remoting_capabilities.h", "rtp_packetizer.cc", "rtp_packetizer.h", "sender.cc", @@ -170,7 +193,7 @@ source_set("unittests") { "packet_util_unittest.cc", "receiver_session_unittest.cc", "receiver_unittest.cc", - "rpc_broker_unittest.cc", + "rpc_messenger_unittest.cc", "rtcp_common_unittest.cc", "rtp_packet_parser_unittest.cc", "rtp_packetizer_unittest.cc", @@ -179,7 +202,7 @@ source_set("unittests") { "sender_report_unittest.cc", "sender_session_unittest.cc", "sender_unittest.cc", - "session_messager_unittest.cc", + "session_messenger_unittest.cc", "ssrc_unittest.cc", ] diff --git a/cast/streaming/answer_messages.cc b/cast/streaming/answer_messages.cc index 906e8901..20af542f 100644 --- a/cast/streaming/answer_messages.cc +++ b/cast/streaming/answer_messages.cc @@ -9,9 +9,9 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "platform/base/error.h" +#include "util/enum_name_table.h" #include "util/json/json_helpers.h" #include "util/osp_logging.h" - namespace openscreen { namespace cast { @@ -47,7 +47,7 @@ static constexpr char kMaxPixelsPerSecond[] = "maxPixelsPerSecond"; // Minimum dimensions. If omitted, the sender will assume a reasonable minimum // with the same aspect ratio as maxDimensions, as close to 320*180 as possible. // Should reflect the true operational minimum. -static constexpr char kMinDimensions[] = "minDimensions"; +static constexpr char kMinResolution[] = "minResolution"; // Maximum dimensions, not necessarily ideal dimensions. static constexpr char kMaxDimensions[] = "maxDimensions"; @@ -57,15 +57,6 @@ static constexpr char kMaxSampleRate[] = "maxSampleRate"; // Maximum number of audio channels (1 is mono, 2 is stereo, etc.). static constexpr char kMaxChannels[] = "maxChannels"; -/// Dimension properties. -// Width in pixels. -static constexpr char kWidth[] = "width"; -// Height in pixels. -static constexpr char kHeight[] = "height"; -// Frame rate as a rational decimal number or fraction. -// E.g. 30 and "3000/1001" are both valid representations. -static constexpr char kFrameRate[] = "frameRate"; - /// Display description properties // If this optional field is included in the ANSWER message, the receiver is // attached to a fixed display that has the given dimensions and frame rate @@ -115,41 +106,33 @@ static constexpr char kReceiverRtcpEventLog[] = "receiverRtcpEventLog"; // OPtional array of numbers specifying the indexes of streams that will use // DSCP values specified in the OFFER message for RTCP packets. static constexpr char kReceiverRtcpDscp[] = "receiverRtcpDscp"; -// True if receiver can report wifi status. -static constexpr char kReceiverGetStatus[] = "receiverGetStatus"; // If this optional field is present the receiver supports the specific // RTP extensions (such as adaptive playout delay). static constexpr char kRtpExtensions[] = "rtpExtensions"; +EnumNameTable<AspectRatioConstraint, 2> kAspectRatioConstraintNames{ + {{kScalingReceiver, AspectRatioConstraint::kVariable}, + {kScalingSender, AspectRatioConstraint::kFixed}}}; + Json::Value AspectRatioConstraintToJson(AspectRatioConstraint aspect_ratio) { - switch (aspect_ratio) { - case AspectRatioConstraint::kVariable: - return Json::Value(kScalingReceiver); - case AspectRatioConstraint::kFixed: - default: - return Json::Value(kScalingSender); - } + return Json::Value(GetEnumName(kAspectRatioConstraintNames, aspect_ratio) + .value(kScalingSender)); } -bool AspectRatioConstraintParseAndValidate(const Json::Value& value, - AspectRatioConstraint* out) { - // the aspect ratio constraint is an optional field. - if (!value) { - return true; - } - +bool TryParseAspectRatioConstraint(const Json::Value& value, + AspectRatioConstraint* out) { std::string aspect_ratio; - if (!json::ParseAndValidateString(value, &aspect_ratio)) { + if (!json::TryParseString(value, &aspect_ratio)) { return false; } - if (aspect_ratio == kScalingReceiver) { - *out = AspectRatioConstraint::kVariable; - return true; - } else if (aspect_ratio == kScalingSender) { - *out = AspectRatioConstraint::kFixed; - return true; + + ErrorOr<AspectRatioConstraint> constraint = + GetEnum(kAspectRatioConstraintNames, aspect_ratio); + if (constraint.is_error()) { + return false; } - return false; + *out = constraint.value(); + return true; } template <typename T> @@ -171,7 +154,7 @@ bool ParseOptional(const Json::Value& value, absl::optional<T>* out) { return true; } T tentative_out; - if (!T::ParseAndValidate(value, &tentative_out)) { + if (!T::TryParse(value, &tentative_out)) { return false; } *out = tentative_out; @@ -181,9 +164,9 @@ bool ParseOptional(const Json::Value& value, absl::optional<T>* out) { } // namespace // static -bool AspectRatio::ParseAndValidate(const Json::Value& value, AspectRatio* out) { +bool AspectRatio::TryParse(const Json::Value& value, AspectRatio* out) { std::string parsed_value; - if (!json::ParseAndValidateString(value, &parsed_value)) { + if (!json::TryParseString(value, &parsed_value)) { return false; } @@ -205,21 +188,20 @@ bool AspectRatio::IsValid() const { } // static -bool AudioConstraints::ParseAndValidate(const Json::Value& root, - AudioConstraints* out) { - if (!json::ParseAndValidateInt(root[kMaxSampleRate], - &(out->max_sample_rate)) || - !json::ParseAndValidateInt(root[kMaxChannels], &(out->max_channels)) || - !json::ParseAndValidateInt(root[kMaxBitRate], &(out->max_bit_rate))) { +bool AudioConstraints::TryParse(const Json::Value& root, + AudioConstraints* out) { + if (!json::TryParseInt(root[kMaxSampleRate], &(out->max_sample_rate)) || + !json::TryParseInt(root[kMaxChannels], &(out->max_channels)) || + !json::TryParseInt(root[kMaxBitRate], &(out->max_bit_rate))) { return false; } std::chrono::milliseconds max_delay; - if (json::ParseAndValidateMilliseconds(root[kMaxDelay], &max_delay)) { + if (json::TryParseMilliseconds(root[kMaxDelay], &max_delay)) { out->max_delay = max_delay; } - if (!json::ParseAndValidateInt(root[kMinBitRate], &(out->min_bit_rate))) { + if (!json::TryParseInt(root[kMinBitRate], &(out->min_bit_rate))) { out->min_bit_rate = kDefaultAudioMinBitRate; } return out->IsValid(); @@ -243,52 +225,27 @@ bool AudioConstraints::IsValid() const { max_bit_rate >= min_bit_rate; } -bool Dimensions::ParseAndValidate(const Json::Value& root, Dimensions* out) { - if (!json::ParseAndValidateInt(root[kWidth], &(out->width)) || - !json::ParseAndValidateInt(root[kHeight], &(out->height)) || - !json::ParseAndValidateSimpleFraction(root[kFrameRate], - &(out->frame_rate))) { - return false; - } - return out->IsValid(); -} - -bool Dimensions::IsValid() const { - return width > 0 && height > 0 && frame_rate.is_positive(); -} - -Json::Value Dimensions::ToJson() const { - OSP_DCHECK(IsValid()); - Json::Value root; - root[kWidth] = width; - root[kHeight] = height; - root[kFrameRate] = frame_rate.ToString(); - return root; -} - // static -bool VideoConstraints::ParseAndValidate(const Json::Value& root, - VideoConstraints* out) { - if (!Dimensions::ParseAndValidate(root[kMaxDimensions], - &(out->max_dimensions)) || - !json::ParseAndValidateInt(root[kMaxBitRate], &(out->max_bit_rate)) || - !ParseOptional<Dimensions>(root[kMinDimensions], - &(out->min_dimensions))) { +bool VideoConstraints::TryParse(const Json::Value& root, + VideoConstraints* out) { + if (!Dimensions::TryParse(root[kMaxDimensions], &(out->max_dimensions)) || + !json::TryParseInt(root[kMaxBitRate], &(out->max_bit_rate)) || + !ParseOptional<Dimensions>(root[kMinResolution], + &(out->min_resolution))) { return false; } std::chrono::milliseconds max_delay; - if (json::ParseAndValidateMilliseconds(root[kMaxDelay], &max_delay)) { + if (json::TryParseMilliseconds(root[kMaxDelay], &max_delay)) { out->max_delay = max_delay; } double max_pixels_per_second; - if (json::ParseAndValidateDouble(root[kMaxPixelsPerSecond], - &max_pixels_per_second)) { + if (json::TryParseDouble(root[kMaxPixelsPerSecond], &max_pixels_per_second)) { out->max_pixels_per_second = max_pixels_per_second; } - if (!json::ParseAndValidateInt(root[kMinBitRate], &(out->min_bit_rate))) { + if (!json::TryParseInt(root[kMinBitRate], &(out->min_bit_rate))) { out->min_bit_rate = kDefaultVideoMinBitRate; } return out->IsValid(); @@ -299,8 +256,8 @@ bool VideoConstraints::IsValid() const { max_bit_rate > min_bit_rate && (!max_delay.has_value() || max_delay->count() > 0) && max_dimensions.IsValid() && - (!min_dimensions.has_value() || min_dimensions->IsValid()) && - max_dimensions.frame_rate.numerator > 0; + (!min_resolution.has_value() || min_resolution->IsValid()) && + max_dimensions.frame_rate.numerator() > 0; } Json::Value VideoConstraints::ToJson() const { @@ -313,8 +270,8 @@ Json::Value VideoConstraints::ToJson() const { root[kMaxPixelsPerSecond] = max_pixels_per_second.value(); } - if (min_dimensions.has_value()) { - root[kMinDimensions] = min_dimensions->ToJson(); + if (min_resolution.has_value()) { + root[kMinResolution] = min_resolution->ToJson(); } if (max_delay.has_value()) { @@ -324,9 +281,9 @@ Json::Value VideoConstraints::ToJson() const { } // static -bool Constraints::ParseAndValidate(const Json::Value& root, Constraints* out) { - if (!AudioConstraints::ParseAndValidate(root[kAudio], &(out->audio)) || - !VideoConstraints::ParseAndValidate(root[kVideo], &(out->video))) { +bool Constraints::TryParse(const Json::Value& root, Constraints* out) { + if (!AudioConstraints::TryParse(root[kAudio], &(out->audio)) || + !VideoConstraints::TryParse(root[kVideo], &(out->video))) { return false; } return out->IsValid(); @@ -345,15 +302,15 @@ Json::Value Constraints::ToJson() const { } // static -bool DisplayDescription::ParseAndValidate(const Json::Value& root, - DisplayDescription* out) { +bool DisplayDescription::TryParse(const Json::Value& root, + DisplayDescription* out) { if (!ParseOptional<Dimensions>(root[kDimensions], &(out->dimensions)) || !ParseOptional<AspectRatio>(root[kAspectRatio], &(out->aspect_ratio))) { return false; } AspectRatioConstraint constraint; - if (AspectRatioConstraintParseAndValidate(root[kScaling], &constraint)) { + if (TryParseAspectRatioConstraint(root[kScaling], &constraint)) { out->aspect_ratio_constraint = absl::optional<AspectRatioConstraint>(std::move(constraint)); } else { @@ -402,28 +359,25 @@ Json::Value DisplayDescription::ToJson() const { return root; } -bool Answer::ParseAndValidate(const Json::Value& root, Answer* out) { - if (!json::ParseAndValidateInt(root[kUdpPort], &(out->udp_port)) || - !json::ParseAndValidateIntArray(root[kSendIndexes], - &(out->send_indexes)) || - !json::ParseAndValidateUintArray(root[kSsrcs], &(out->ssrcs)) || +bool Answer::ParseAndValidate(const Json::Value& value, Answer* out) { + return TryParse(value, out); +} + +bool Answer::TryParse(const Json::Value& root, Answer* out) { + if (!json::TryParseInt(root[kUdpPort], &(out->udp_port)) || + !json::TryParseIntArray(root[kSendIndexes], &(out->send_indexes)) || + !json::TryParseUintArray(root[kSsrcs], &(out->ssrcs)) || !ParseOptional<Constraints>(root[kConstraints], &(out->constraints)) || !ParseOptional<DisplayDescription>(root[kDisplay], &(out->display))) { return false; } - if (!json::ParseBool(root[kReceiverGetStatus], - &(out->supports_wifi_status_reporting))) { - out->supports_wifi_status_reporting = false; - } // These function set to empty array if not present, so we can ignore // the return value for optional values. - json::ParseAndValidateIntArray(root[kReceiverRtcpEventLog], - &(out->receiver_rtcp_event_log)); - json::ParseAndValidateIntArray(root[kReceiverRtcpDscp], - &(out->receiver_rtcp_dscp)); - json::ParseAndValidateStringArray(root[kRtpExtensions], - &(out->rtp_extensions)); + json::TryParseIntArray(root[kReceiverRtcpEventLog], + &(out->receiver_rtcp_event_log)); + json::TryParseIntArray(root[kReceiverRtcpDscp], &(out->receiver_rtcp_dscp)); + json::TryParseStringArray(root[kRtpExtensions], &(out->rtp_extensions)); return out->IsValid(); } @@ -459,7 +413,6 @@ Json::Value Answer::ToJson() const { root[kDisplay] = display->ToJson(); } root[kUdpPort] = udp_port; - root[kReceiverGetStatus] = supports_wifi_status_reporting; root[kSendIndexes] = PrimitiveVectorToJson(send_indexes); root[kSsrcs] = PrimitiveVectorToJson(ssrcs); // Some sender do not handle empty array properly, so we omit these fields diff --git a/cast/streaming/answer_messages.h b/cast/streaming/answer_messages.h index 1f62706a..7095e455 100644 --- a/cast/streaming/answer_messages.h +++ b/cast/streaming/answer_messages.h @@ -15,6 +15,7 @@ #include <vector> #include "absl/types/optional.h" +#include "cast/streaming/resolution.h" #include "cast/streaming/ssrc.h" #include "json/value.h" #include "platform/base/error.h" @@ -28,14 +29,14 @@ namespace cast { // readability of the structs provided in this file by cutting down on the // amount of obscuring boilerplate code. For each of the following struct // definitions, the following method definitions are shared: -// (1) ParseAndValidate. Shall return a boolean indicating whether the out +// (1) TryParse. Shall return a boolean indicating whether the out // parameter is in a valid state after checking bounds and restrictions. // (2) ToJson. Should return a proper JSON object. Assumes that IsValid() // has been called already, OSP_DCHECKs if not IsValid(). -// (3) IsValid. Used by both ParseAndValidate and ToJson to ensure that the +// (3) IsValid. Used by both TryParse and ToJson to ensure that the // object is in a good state. struct AudioConstraints { - static bool ParseAndValidate(const Json::Value& value, AudioConstraints* out); + static bool TryParse(const Json::Value& value, AudioConstraints* out); Json::Value ToJson() const; bool IsValid() const; @@ -46,23 +47,13 @@ struct AudioConstraints { absl::optional<std::chrono::milliseconds> max_delay = {}; }; -struct Dimensions { - static bool ParseAndValidate(const Json::Value& value, Dimensions* out); - Json::Value ToJson() const; - bool IsValid() const; - - int width = 0; - int height = 0; - SimpleFraction frame_rate; -}; - struct VideoConstraints { - static bool ParseAndValidate(const Json::Value& value, VideoConstraints* out); + static bool TryParse(const Json::Value& value, VideoConstraints* out); Json::Value ToJson() const; bool IsValid() const; absl::optional<double> max_pixels_per_second = {}; - absl::optional<Dimensions> min_dimensions = {}; + absl::optional<Dimensions> min_resolution = {}; Dimensions max_dimensions = {}; int min_bit_rate = 0; // optional int max_bit_rate = 0; @@ -70,7 +61,7 @@ struct VideoConstraints { }; struct Constraints { - static bool ParseAndValidate(const Json::Value& value, Constraints* out); + static bool TryParse(const Json::Value& value, Constraints* out); Json::Value ToJson() const; bool IsValid() const; @@ -84,7 +75,7 @@ struct Constraints { enum class AspectRatioConstraint : uint8_t { kVariable = 0, kFixed }; struct AspectRatio { - static bool ParseAndValidate(const Json::Value& value, AspectRatio* out); + static bool TryParse(const Json::Value& value, AspectRatio* out); bool IsValid() const; bool operator==(const AspectRatio& other) const { @@ -96,8 +87,7 @@ struct AspectRatio { }; struct DisplayDescription { - static bool ParseAndValidate(const Json::Value& value, - DisplayDescription* out); + static bool TryParse(const Json::Value& value, DisplayDescription* out); Json::Value ToJson() const; bool IsValid() const; @@ -109,7 +99,10 @@ struct DisplayDescription { }; struct Answer { + // TODO(jophba): DEPRECATED, remove separately. static bool ParseAndValidate(const Json::Value& value, Answer* out); + + static bool TryParse(const Json::Value& value, Answer* out); Json::Value ToJson() const; bool IsValid() const; @@ -123,7 +116,6 @@ struct Answer { absl::optional<DisplayDescription> display; std::vector<int> receiver_rtcp_event_log; std::vector<int> receiver_rtcp_dscp; - bool supports_wifi_status_reporting = false; // RTP extensions should be empty, but not null. std::vector<std::string> rtp_extensions = {}; diff --git a/cast/streaming/answer_messages_unittest.cc b/cast/streaming/answer_messages_unittest.cc index e4ec82f4..3d618828 100644 --- a/cast/streaming/answer_messages_unittest.cc +++ b/cast/streaming/answer_messages_unittest.cc @@ -37,7 +37,7 @@ constexpr char kValidAnswerJson[] = R"({ }, "video": { "maxPixelsPerSecond": 62208000, - "minDimensions": { + "minResolution": { "width": 320, "height": 180, "frameRate": 0 @@ -63,7 +63,6 @@ constexpr char kValidAnswerJson[] = R"({ }, "receiverRtcpEventLog": [0, 1], "receiverRtcpDscp": [234, 567], - "receiverGetStatus": true, "rtpExtensions": ["adaptive_playout_delay"] })"; @@ -81,34 +80,22 @@ const Answer kValidAnswer{ }, // audio VideoConstraints{ 40000.0, // max_pixels_per_second - absl::optional<Dimensions>(Dimensions{ - 320, // width - 480, // height - SimpleFraction{15000, 101} // frame_rate - }), // min_dimensions - Dimensions{ - 1920, // width - 1080, // height - SimpleFraction{288, 2} // frame_rate - }, + absl::optional<Dimensions>( + Dimensions{320, 480, SimpleFraction{15000, 101}}), + Dimensions{1920, 1080, SimpleFraction{288, 2}}, 300000, // min_bit_rate 144000000, // max_bit_rate milliseconds(3000) // max_delay } // video }), // constraints absl::optional<DisplayDescription>(DisplayDescription{ - absl::optional<Dimensions>(Dimensions{ - 640, // width - 480, // height - SimpleFraction{30, 1} // frame_rate - }), + absl::optional<Dimensions>(Dimensions{640, 480, SimpleFraction{30, 1}}), absl::optional<AspectRatio>(AspectRatio{16, 9}), // aspect_ratio absl::optional<AspectRatioConstraint>( AspectRatioConstraint::kFixed), // scaling }), std::vector<int>{7, 8, 9}, // receiver_rtcp_event_log std::vector<int>{11, 12, 13}, // receiver_rtcp_dscp - true, // receiver_get_status std::vector<std::string>{"foo", "bar"} // rtp_extensions }; @@ -137,10 +124,10 @@ void ExpectEqualsValidAnswerJson(const Answer& answer) { const VideoConstraints& video = answer.constraints->video; EXPECT_EQ(62208000, video.max_pixels_per_second); - ASSERT_TRUE(video.min_dimensions.has_value()); - EXPECT_EQ(320, video.min_dimensions->width); - EXPECT_EQ(180, video.min_dimensions->height); - EXPECT_EQ((SimpleFraction{0, 1}), video.min_dimensions->frame_rate); + ASSERT_TRUE(video.min_resolution.has_value()); + EXPECT_EQ(320, video.min_resolution->width); + EXPECT_EQ(180, video.min_resolution->height); + EXPECT_EQ((SimpleFraction{0, 1}), video.min_resolution->frame_rate); EXPECT_EQ(1920, video.max_dimensions.width); EXPECT_EQ(1080, video.max_dimensions.height); EXPECT_EQ((SimpleFraction{60, 1}), video.max_dimensions.frame_rate); @@ -160,7 +147,6 @@ void ExpectEqualsValidAnswerJson(const Answer& answer) { EXPECT_THAT(answer.receiver_rtcp_event_log, ElementsAre(0, 1)); EXPECT_THAT(answer.receiver_rtcp_dscp, ElementsAre(234, 567)); - EXPECT_TRUE(answer.supports_wifi_status_reporting); EXPECT_THAT(answer.rtp_extensions, ElementsAre("adaptive_playout_delay")); } @@ -170,7 +156,7 @@ void ExpectFailureOnParse(absl::string_view raw_json) { ASSERT_TRUE(root.is_value()); Answer answer; - EXPECT_FALSE(Answer::ParseAndValidate(std::move(root.value()), &answer)); + EXPECT_FALSE(Answer::TryParse(std::move(root.value()), &answer)); EXPECT_FALSE(answer.IsValid()); } @@ -182,7 +168,7 @@ void ExpectSuccessOnParse(absl::string_view raw_json, Answer* out = nullptr) { ASSERT_TRUE(root.is_value()); Answer answer; - ASSERT_TRUE(Answer::ParseAndValidate(std::move(root.value()), &answer)); + ASSERT_TRUE(Answer::TryParse(std::move(root.value()), &answer)); EXPECT_TRUE(answer.IsValid()); if (out) { *out = std::move(answer); @@ -223,11 +209,11 @@ TEST(AnswerMessagesTest, ProperlyPopulatedAnswerSerializesProperly) { EXPECT_EQ(video["maxBitRate"], 144000000); EXPECT_EQ(video["maxDelay"], 3000); - Json::Value min_dimensions = std::move(video["minDimensions"]); - EXPECT_EQ(min_dimensions.type(), Json::ValueType::objectValue); - EXPECT_EQ(min_dimensions["width"], 320); - EXPECT_EQ(min_dimensions["height"], 480); - EXPECT_EQ(min_dimensions["frameRate"], "15000/101"); + Json::Value min_resolution = std::move(video["minResolution"]); + EXPECT_EQ(min_resolution.type(), Json::ValueType::objectValue); + EXPECT_EQ(min_resolution["width"], 320); + EXPECT_EQ(min_resolution["height"], 480); + EXPECT_EQ(min_resolution["frameRate"], "15000/101"); Json::Value max_dimensions = std::move(video["maxDimensions"]); EXPECT_EQ(max_dimensions.type(), Json::ValueType::objectValue); @@ -258,8 +244,6 @@ TEST(AnswerMessagesTest, ProperlyPopulatedAnswerSerializesProperly) { EXPECT_EQ(receiver_rtcp_dscp[1], 12); EXPECT_EQ(receiver_rtcp_dscp[2], 13); - EXPECT_EQ(root["receiverGetStatus"], true); - Json::Value rtp_extensions = std::move(root["rtpExtensions"]); EXPECT_EQ(rtp_extensions.type(), Json::ValueType::arrayValue); EXPECT_EQ(rtp_extensions[0], "foo"); @@ -330,8 +314,7 @@ TEST(AnswerMessagesTest, SucceedsWithMissingRtpFields) { ExpectSuccessOnParse(R"({ "udpPort": 1234, "sendIndexes": [1, 3], - "ssrcs": [1233324, 2234222], - "receiverGetStatus": true + "ssrcs": [1233324, 2234222] })"); } @@ -342,37 +325,22 @@ TEST(AnswerMessagesTest, ErrorOnEmptyAnswer) { TEST(AnswerMessagesTest, ErrorOnMissingUdpPort) { ExpectFailureOnParse(R"({ "sendIndexes": [1, 3], - "ssrcs": [1233324, 2234222], - "receiverGetStatus": true + "ssrcs": [1233324, 2234222] })"); } TEST(AnswerMessagesTest, ErrorOnMissingSsrcs) { ExpectFailureOnParse(R"({ "udpPort": 1234, - "sendIndexes": [1, 3], - "receiverGetStatus": true + "sendIndexes": [1, 3] })"); } TEST(AnswerMessagesTest, ErrorOnMissingSendIndexes) { ExpectFailureOnParse(R"({ "udpPort": 1234, - "ssrcs": [1233324, 2234222], - "receiverGetStatus": true - })"); -} - -TEST(AnswerMessagesTest, AssumesNoReportingIfGetStatusFalse) { - Answer answer; - ExpectSuccessOnParse(R"({ - "udpPort": 1234, - "sendIndexes": [1, 3], "ssrcs": [1233324, 2234222] - })", - &answer); - - EXPECT_FALSE(answer.supports_wifi_status_reporting); + })"); } TEST(AnswerMessagesTest, AllowsReceiverSideScaling) { @@ -420,8 +388,7 @@ TEST(AnswerMessagesTest, AssumesMinBitRateIfOmitted) { "maxBitRate": 10000000, "maxDelay": 5000 } - }, - "receiverGetStatus": true + } })", &answer); @@ -476,8 +443,8 @@ TEST(AnswerMessagesTest, VideoConstraintsIsValid) { VideoConstraints invalid_max_pixels_per_second = kValidVideoConstraints; invalid_max_pixels_per_second.max_pixels_per_second = 0; - VideoConstraints invalid_min_dimensions = kValidVideoConstraints; - invalid_min_dimensions.min_dimensions->width = 0; + VideoConstraints invalid_min_resolution = kValidVideoConstraints; + invalid_min_resolution.min_resolution->width = 0; VideoConstraints invalid_max_dimensions = kValidVideoConstraints; invalid_max_dimensions.max_dimensions.height = 0; @@ -493,7 +460,7 @@ TEST(AnswerMessagesTest, VideoConstraintsIsValid) { EXPECT_TRUE(kValidVideoConstraints.IsValid()); EXPECT_FALSE(invalid_max_pixels_per_second.IsValid()); - EXPECT_FALSE(invalid_min_dimensions.IsValid()); + EXPECT_FALSE(invalid_min_resolution.IsValid()); EXPECT_FALSE(invalid_max_dimensions.IsValid()); EXPECT_FALSE(invalid_min_bit_rate.IsValid()); EXPECT_FALSE(invalid_max_bit_rate.IsValid()); @@ -528,7 +495,7 @@ TEST(AnswerMessagesTest, AspectRatioIsValid) { EXPECT_FALSE(kInvalidHeight.IsValid()); } -TEST(AnswerMessagesTest, AspectRatioParseAndValidate) { +TEST(AnswerMessagesTest, AspectRatioTryParse) { const Json::Value kValid = "16:9"; const Json::Value kWrongDelimiter = "16-9"; const Json::Value kTooManyFields = "16:9:3"; @@ -543,24 +510,24 @@ TEST(AnswerMessagesTest, AspectRatioParseAndValidate) { const Json::Value kZeroHeight = "16:0"; AspectRatio out; - EXPECT_TRUE(AspectRatio::ParseAndValidate(kValid, &out)); + EXPECT_TRUE(AspectRatio::TryParse(kValid, &out)); EXPECT_EQ(out.width, 16); EXPECT_EQ(out.height, 9); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kWrongDelimiter, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kTooManyFields, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kTooFewFields, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kWrongDelimiter, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kNoDelimiter, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kNegativeWidth, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kNegativeHeight, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kNegativeBoth, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kNonNumberWidth, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kNonNumberHeight, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kZeroWidth, &out)); - EXPECT_FALSE(AspectRatio::ParseAndValidate(kZeroHeight, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kWrongDelimiter, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kTooManyFields, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kTooFewFields, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kWrongDelimiter, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kNoDelimiter, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kNegativeWidth, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kNegativeHeight, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kNegativeBoth, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kNonNumberWidth, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kNonNumberHeight, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kZeroWidth, &out)); + EXPECT_FALSE(AspectRatio::TryParse(kZeroHeight, &out)); } -TEST(AnswerMessagesTest, DisplayDescriptionParseAndValidate) { +TEST(AnswerMessagesTest, DisplayDescriptionTryParse) { Json::Value valid_scaling; valid_scaling["scaling"] = "receiver"; Json::Value invalid_scaling; @@ -586,25 +553,23 @@ TEST(AnswerMessagesTest, DisplayDescriptionParseAndValidate) { aspect_ratio_and_constraint["aspectRatio"] = "4:3"; DisplayDescription out; - ASSERT_TRUE(DisplayDescription::ParseAndValidate(valid_scaling, &out)); + ASSERT_TRUE(DisplayDescription::TryParse(valid_scaling, &out)); ASSERT_TRUE(out.aspect_ratio_constraint.has_value()); EXPECT_EQ(out.aspect_ratio_constraint.value(), AspectRatioConstraint::kVariable); - EXPECT_FALSE(DisplayDescription::ParseAndValidate(invalid_scaling, &out)); - EXPECT_TRUE( - DisplayDescription::ParseAndValidate(invalid_scaling_valid_ratio, &out)); + EXPECT_FALSE(DisplayDescription::TryParse(invalid_scaling, &out)); + EXPECT_TRUE(DisplayDescription::TryParse(invalid_scaling_valid_ratio, &out)); - ASSERT_TRUE(DisplayDescription::ParseAndValidate(valid_dimensions, &out)); + ASSERT_TRUE(DisplayDescription::TryParse(valid_dimensions, &out)); ASSERT_TRUE(out.dimensions.has_value()); EXPECT_EQ(1920, out.dimensions->width); EXPECT_EQ(1080, out.dimensions->height); EXPECT_EQ((SimpleFraction{30, 1}), out.dimensions->frame_rate); - EXPECT_FALSE(DisplayDescription::ParseAndValidate(invalid_dimensions, &out)); + EXPECT_FALSE(DisplayDescription::TryParse(invalid_dimensions, &out)); - ASSERT_TRUE( - DisplayDescription::ParseAndValidate(aspect_ratio_and_constraint, &out)); + ASSERT_TRUE(DisplayDescription::TryParse(aspect_ratio_and_constraint, &out)); EXPECT_EQ(AspectRatioConstraint::kFixed, out.aspect_ratio_constraint.value()); } diff --git a/cast/streaming/capture_configs.h b/cast/streaming/capture_configs.h index fd99c17c..56b15898 100644 --- a/cast/streaming/capture_configs.h +++ b/cast/streaming/capture_configs.h @@ -9,6 +9,8 @@ #include <vector> #include "cast/streaming/constants.h" +#include "cast/streaming/resolution.h" +#include "util/simple_fraction.h" namespace openscreen { namespace cast { @@ -33,25 +35,11 @@ struct AudioCaptureConfig { // Target playout delay in milliseconds. std::chrono::milliseconds target_playout_delay = kDefaultTargetPlayoutDelay; -}; - -// Display resolution in pixels. -struct DisplayResolution { - // Width in pixels. - int width = 1920; - // Height in pixels. - int height = 1080; -}; - -// Frame rates are expressed as a rational number, and must be positive. -struct FrameRate { - // For simple cases, the frame rate may be provided by simply setting the - // number to the desired value, e.g. 30 or 60FPS. Some common frame rates like - // 23.98 FPS (for NTSC compatibility) are represented as fractions, in this - // case 24000/1001. - int numerator = kDefaultFrameRate; - int denominator = 1; + // The codec parameter for this configuration. Honors the format laid out + // in RFC 6381: https://datatracker.ietf.org/doc/html/rfc6381 + // NOTE: the "profiles" parameter is not supported in our implementation. + std::string codec_parameter; }; // A configuration set that can be used by the sender to capture video, as @@ -62,7 +50,11 @@ struct VideoCaptureConfig { VideoCodec codec = VideoCodec::kVp8; // Maximum frame rate in frames per second. - FrameRate max_frame_rate; + // For simple cases, the frame rate may be provided by simply setting the + // number to the desired value, e.g. 30 or 60FPS. Some common frame rates like + // 23.98 FPS (for NTSC compatibility) are represented as fractions, in this + // case 24000/1001. + SimpleFraction max_frame_rate{kDefaultFrameRate, 1}; // Number specifying the maximum bit rate for this stream. A value of // zero means that the maximum bit rate should be automatically selected by @@ -71,10 +63,18 @@ struct VideoCaptureConfig { // Resolutions to be offered to the receiver. At least one resolution // must be provided. - std::vector<DisplayResolution> resolutions; + std::vector<Resolution> resolutions; // Target playout delay in milliseconds. std::chrono::milliseconds target_playout_delay = kDefaultTargetPlayoutDelay; + + // The codec parameter for this configuration. Honors the format laid out + // in RFC 6381: https://datatracker.ietf.org/doc/html/rfc6381. + // VP8 and VP9 codec parameter versions are defined here: + // https://developer.mozilla.org/en-US/docs/Web/Media/Formats/codecs_parameter#webm + // https://www.webmproject.org/vp9/mp4/#codecs-parameter-string + // NOTE: the "profiles" parameter is not supported in our implementation. + std::string codec_parameter; }; } // namespace cast diff --git a/cast/streaming/capture_recommendations.cc b/cast/streaming/capture_recommendations.cc index b30b5dc1..4b3bcd16 100644 --- a/cast/streaming/capture_recommendations.cc +++ b/cast/streaming/capture_recommendations.cc @@ -15,16 +15,6 @@ namespace cast { namespace capture_recommendations { namespace { -bool DoubleEquals(double a, double b) { - // Choice of epsilon for double comparison allows for proper comparison - // for both aspect ratios and frame rates. For frame rates, it is based on the - // broadcast rate of 29.97fps, which is actually 29.976. For aspect ratios, it - // allows for a one-pixel difference at a 4K resolution, we want it to be - // relatively high to avoid false negative comparison results. - const double kEpsilon = .0001; - return std::abs(a - b) < kEpsilon; -} - void ApplyDisplay(const DisplayDescription& description, Recommendations* recommendations) { recommendations->video.supports_scaling = @@ -35,14 +25,15 @@ void ApplyDisplay(const DisplayDescription& description, // We should never exceed the display's resolution, since it will always // force scaling. if (description.dimensions) { - const double frame_rate = - static_cast<double>(description.dimensions->frame_rate); - recommendations->video.maximum = - Resolution{description.dimensions->width, - description.dimensions->height, frame_rate}; + recommendations->video.maximum = description.dimensions.value(); recommendations->video.bit_rate_limits.maximum = recommendations->video.maximum.effective_bit_rate(); - recommendations->video.minimum.set_minimum(recommendations->video.maximum); + + if (recommendations->video.maximum.width < + recommendations->video.minimum.width) { + recommendations->video.minimum = + recommendations->video.maximum.ToResolution(); + } } // If the receiver gives us an aspect ratio that doesn't match the display @@ -53,16 +44,6 @@ void ApplyDisplay(const DisplayDescription& description, if (description.aspect_ratio) { aspect_ratio = static_cast<double>(description.aspect_ratio->width) / description.aspect_ratio->height; -#if OSP_DCHECK_IS_ON() - if (description.dimensions) { - const double from_dims = - static_cast<double>(description.dimensions->width) / - description.dimensions->height; - if (!DoubleEquals(from_dims, aspect_ratio)) { - OSP_DLOG_WARN << "Received mismatched aspect ratio from the receiver."; - } - } -#endif recommendations->video.maximum.width = recommendations->video.maximum.height * aspect_ratio; } else if (description.dimensions) { @@ -75,10 +56,6 @@ void ApplyDisplay(const DisplayDescription& description, recommendations->video.minimum.height * aspect_ratio; } -Resolution ToResolution(const Dimensions& dims) { - return {dims.width, dims.height, static_cast<double>(dims.frame_rate)}; -} - void ApplyConstraints(const Constraints& constraints, Recommendations* recommendations) { // Audio has no fields in the display description, so we can safely @@ -109,17 +86,18 @@ void ApplyConstraints(const Constraints& constraints, recommendations->video.bit_rate_limits.minimum), std::min(constraints.video.max_bit_rate, recommendations->video.bit_rate_limits.maximum)}; - Resolution max = ToResolution(constraints.video.max_dimensions); - if (max <= kDefaultMinResolution) { - recommendations->video.maximum = kDefaultMinResolution; - } else if (max < recommendations->video.maximum) { - recommendations->video.maximum = std::move(max); + Dimensions dimensions = constraints.video.max_dimensions; + if (dimensions.width <= kDefaultMinResolution.width) { + recommendations->video.maximum = {kDefaultMinResolution.width, + kDefaultMinResolution.height, + kDefaultFrameRate}; + } else if (dimensions.width < recommendations->video.maximum.width) { + recommendations->video.maximum = std::move(dimensions); } - // Implicit else: maximum = kDefaultMaxResolution. - if (constraints.video.min_dimensions) { - Resolution min = ToResolution(constraints.video.min_dimensions.value()); - if (kDefaultMinResolution < min) { + if (constraints.video.min_resolution) { + const Resolution& min = constraints.video.min_resolution->ToResolution(); + if (kDefaultMinResolution.width < min.width) { recommendations->video.minimum = std::move(min); } } @@ -137,25 +115,6 @@ bool Audio::operator==(const Audio& other) const { other.max_sample_rate); } -bool Resolution::operator==(const Resolution& other) const { - return (std::tie(width, height) == std::tie(other.width, other.height)) && - DoubleEquals(frame_rate, other.frame_rate); -} - -bool Resolution::operator<(const Resolution& other) const { - return effective_bit_rate() < other.effective_bit_rate(); -} - -bool Resolution::operator<=(const Resolution& other) const { - return (*this == other) || (*this < other); -} - -void Resolution::set_minimum(const Resolution& other) { - if (other < *this) { - *this = other; - } -} - bool Video::operator==(const Video& other) const { return std::tie(bit_rate_limits, minimum, maximum, supports_scaling, max_delay, max_pixels_per_second) == diff --git a/cast/streaming/capture_recommendations.h b/cast/streaming/capture_recommendations.h index ccb2475b..603b6098 100644 --- a/cast/streaming/capture_recommendations.h +++ b/cast/streaming/capture_recommendations.h @@ -11,7 +11,7 @@ #include <tuple> #include "cast/streaming/constants.h" - +#include "cast/streaming/resolution.h" namespace openscreen { namespace cast { @@ -80,30 +80,12 @@ struct Audio { int min_sample_rate = kDefaultAudioMinSampleRate; }; -struct Resolution { - bool operator==(const Resolution& other) const; - bool operator<(const Resolution& other) const; - bool operator<=(const Resolution& other) const; - void set_minimum(const Resolution& other); - - // The effective bit rate is the predicted average bit rate based on the - // properties of the Resolution instance, and is currently just the product. - constexpr int effective_bit_rate() const { - return static_cast<int>(static_cast<double>(width * height) * frame_rate); - } - - int width; - int height; - double frame_rate; -}; - // The minimum dimensions are as close as possible to low-definition // television, factoring in the receiver's aspect ratio if provided. -constexpr Resolution kDefaultMinResolution{kMinVideoWidth, kMinVideoHeight, - kDefaultFrameRate}; +constexpr Resolution kDefaultMinResolution{kMinVideoWidth, kMinVideoHeight}; // Currently mirroring only supports 1080P. -constexpr Resolution kDefaultMaxResolution{1920, 1080, kDefaultFrameRate}; +constexpr Dimensions kDefaultMaxResolution{1920, 1080, kDefaultFrameRate}; // The mirroring spec suggests 300kbps as the absolute minimum bitrate. constexpr int kDefaultVideoMinBitRate = 300 * 1000; @@ -117,7 +99,7 @@ constexpr int kDefaultVideoMaxPixelsPerSecond = // Our default limits are merely the product of the minimum and maximum // dimensions, and are only used if the receiver fails to give better // constraint information. -constexpr BitRateLimits kDefaultVideoBitRateLimits{ +const BitRateLimits kDefaultVideoBitRateLimits{ kDefaultVideoMinBitRate, kDefaultMaxResolution.effective_bit_rate()}; // Video capture recommendations. @@ -131,7 +113,7 @@ struct Video { Resolution minimum = kDefaultMinResolution; // Represents the recommended maximum resolution. - Resolution maximum = kDefaultMaxResolution; + Dimensions maximum = kDefaultMaxResolution; // Indicates whether the receiver can scale frames from a different aspect // ratio, or if it needs to be done by the sender. Default is false, meaning diff --git a/cast/streaming/capture_recommendations_unittest.cc b/cast/streaming/capture_recommendations_unittest.cc index 4f76b9d9..872b62e3 100644 --- a/cast/streaming/capture_recommendations_unittest.cc +++ b/cast/streaming/capture_recommendations_unittest.cc @@ -6,6 +6,7 @@ #include "absl/types/optional.h" #include "cast/streaming/answer_messages.h" +#include "cast/streaming/resolution.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "util/chrono_helpers.h" @@ -15,64 +16,64 @@ namespace cast { namespace capture_recommendations { namespace { -constexpr Recommendations kDefaultRecommendations{ +const Recommendations kDefaultRecommendations{ Audio{BitRateLimits{32000, 256000}, milliseconds(400), 2, 48000, 16000}, - Video{BitRateLimits{300000, 1920 * 1080 * 30}, Resolution{320, 240, 30}, - Resolution{1920, 1080, 30}, false, milliseconds(400), + Video{BitRateLimits{300000, 1920 * 1080 * 30}, Resolution{320, 240}, + Dimensions{1920, 1080, 30}, false, milliseconds(400), 1920 * 1080 * 30 / 8}}; -constexpr DisplayDescription kEmptyDescription{}; +const DisplayDescription kEmptyDescription{}; -constexpr DisplayDescription kValidOnlyResolution{ +const DisplayDescription kValidOnlyResolution{ Dimensions{1024, 768, SimpleFraction{60, 1}}, absl::nullopt, absl::nullopt}; -constexpr DisplayDescription kValidOnlyAspectRatio{ - absl::nullopt, AspectRatio{4, 3}, absl::nullopt}; +const DisplayDescription kValidOnlyAspectRatio{absl::nullopt, AspectRatio{4, 3}, + absl::nullopt}; -constexpr DisplayDescription kValidOnlyAspectRatioSixteenNine{ +const DisplayDescription kValidOnlyAspectRatioSixteenNine{ absl::nullopt, AspectRatio{16, 9}, absl::nullopt}; -constexpr DisplayDescription kValidOnlyVariable{ - absl::nullopt, absl::nullopt, AspectRatioConstraint::kVariable}; +const DisplayDescription kValidOnlyVariable{absl::nullopt, absl::nullopt, + AspectRatioConstraint::kVariable}; -constexpr DisplayDescription kInvalidOnlyFixed{absl::nullopt, absl::nullopt, - AspectRatioConstraint::kFixed}; +const DisplayDescription kInvalidOnlyFixed{absl::nullopt, absl::nullopt, + AspectRatioConstraint::kFixed}; -constexpr DisplayDescription kValidFixedAspectRatio{ +const DisplayDescription kValidFixedAspectRatio{ absl::nullopt, AspectRatio{4, 3}, AspectRatioConstraint::kFixed}; -constexpr DisplayDescription kValidVariableAspectRatio{ +const DisplayDescription kValidVariableAspectRatio{ absl::nullopt, AspectRatio{4, 3}, AspectRatioConstraint::kVariable}; -constexpr DisplayDescription kValidFixedMissingAspectRatio{ +const DisplayDescription kValidFixedMissingAspectRatio{ Dimensions{1024, 768, SimpleFraction{60, 1}}, absl::nullopt, AspectRatioConstraint::kFixed}; -constexpr DisplayDescription kValidDisplayFhd{ +const DisplayDescription kValidDisplayFhd{ Dimensions{1920, 1080, SimpleFraction{30, 1}}, AspectRatio{16, 9}, AspectRatioConstraint::kVariable}; -constexpr DisplayDescription kValidDisplayXga{ +const DisplayDescription kValidDisplayXga{ Dimensions{1024, 768, SimpleFraction{60, 1}}, AspectRatio{4, 3}, AspectRatioConstraint::kFixed}; -constexpr DisplayDescription kValidDisplayTiny{ +const DisplayDescription kValidDisplayTiny{ Dimensions{300, 200, SimpleFraction{30, 1}}, AspectRatio{3, 2}, AspectRatioConstraint::kFixed}; -constexpr DisplayDescription kValidDisplayMismatched{ +const DisplayDescription kValidDisplayMismatched{ Dimensions{300, 200, SimpleFraction{30, 1}}, AspectRatio{3, 4}, AspectRatioConstraint::kFixed}; -constexpr Constraints kEmptyConstraints{}; +const Constraints kEmptyConstraints{}; -constexpr Constraints kValidConstraintsHighEnd{ +const Constraints kValidConstraintsHighEnd{ {96100, 5, 96000, 500000, std::chrono::seconds(6)}, {6000000, Dimensions{640, 480, SimpleFraction{30, 1}}, Dimensions{3840, 2160, SimpleFraction{144, 1}}, 600000, 6000000, std::chrono::seconds(6)}}; -constexpr Constraints kValidConstraintsLowEnd{ +const Constraints kValidConstraintsLowEnd{ {22000, 2, 24000, 50000, std::chrono::seconds(1)}, {60000, Dimensions{120, 80, SimpleFraction{10, 1}}, Dimensions{1200, 800, SimpleFraction{30, 1}}, 100000, 1000000, @@ -92,7 +93,7 @@ TEST(CaptureRecommendationsTest, EmptyDisplayDescription) { TEST(CaptureRecommendationsTest, OnlyResolution) { Recommendations expected = kDefaultRecommendations; - expected.video.maximum = Resolution{1024, 768, 60.0}; + expected.video.maximum = Dimensions{1024, 768, 60.0}; expected.video.bit_rate_limits.maximum = 47185920; Answer answer; answer.display = kValidOnlyResolution; @@ -101,8 +102,8 @@ TEST(CaptureRecommendationsTest, OnlyResolution) { TEST(CaptureRecommendationsTest, OnlyAspectRatioFourThirds) { Recommendations expected = kDefaultRecommendations; - expected.video.minimum = Resolution{320, 240, 30.0}; - expected.video.maximum = Resolution{1440, 1080, 30.0}; + expected.video.minimum = Resolution{320, 240}; + expected.video.maximum = Dimensions{1440, 1080, 30.0}; Answer answer; answer.display = kValidOnlyAspectRatio; @@ -111,8 +112,8 @@ TEST(CaptureRecommendationsTest, OnlyAspectRatioFourThirds) { TEST(CaptureRecommendationsTest, OnlyAspectRatioSixteenNine) { Recommendations expected = kDefaultRecommendations; - expected.video.minimum = Resolution{426, 240, 30.0}; - expected.video.maximum = Resolution{1920, 1080, 30.0}; + expected.video.minimum = Resolution{426, 240}; + expected.video.maximum = Dimensions{1920, 1080, 30.0}; Answer answer; answer.display = kValidOnlyAspectRatioSixteenNine; @@ -139,8 +140,8 @@ TEST(CaptureRecommendationsTest, OnlyInvalidAspectRatioConstraint) { TEST(CaptureRecommendationsTest, FixedAspectRatioConstraint) { Recommendations expected = kDefaultRecommendations; - expected.video.minimum = Resolution{320, 240, 30.0}; - expected.video.maximum = Resolution{1440, 1080, 30.0}; + expected.video.minimum = Resolution{320, 240}; + expected.video.maximum = Dimensions{1440, 1080, 30.0}; expected.video.supports_scaling = false; Answer answer; answer.display = kValidFixedAspectRatio; @@ -152,8 +153,8 @@ TEST(CaptureRecommendationsTest, FixedAspectRatioConstraint) { // frame sizes between minimum and maximum can be properly scaled. TEST(CaptureRecommendationsTest, VariableAspectRatioConstraint) { Recommendations expected = kDefaultRecommendations; - expected.video.minimum = Resolution{320, 240, 30.0}; - expected.video.maximum = Resolution{1440, 1080, 30.0}; + expected.video.minimum = Resolution{320, 240}; + expected.video.maximum = Dimensions{1440, 1080, 30.0}; expected.video.supports_scaling = true; Answer answer; answer.display = kValidVariableAspectRatio; @@ -162,8 +163,8 @@ TEST(CaptureRecommendationsTest, VariableAspectRatioConstraint) { TEST(CaptureRecommendationsTest, ResolutionWithFixedConstraint) { Recommendations expected = kDefaultRecommendations; - expected.video.minimum = Resolution{320, 240, 30.0}; - expected.video.maximum = Resolution{1024, 768, 60.0}; + expected.video.minimum = Resolution{320, 240}; + expected.video.maximum = Dimensions{1024, 768, 60.0}; expected.video.supports_scaling = false; expected.video.bit_rate_limits.maximum = 47185920; Answer answer; @@ -173,7 +174,7 @@ TEST(CaptureRecommendationsTest, ResolutionWithFixedConstraint) { TEST(CaptureRecommendationsTest, ExplicitFhdChangesMinimum) { Recommendations expected = kDefaultRecommendations; - expected.video.minimum = Resolution{426, 240, 30.0}; + expected.video.minimum = Resolution{426, 240}; expected.video.supports_scaling = true; Answer answer; answer.display = kValidDisplayFhd; @@ -182,8 +183,8 @@ TEST(CaptureRecommendationsTest, ExplicitFhdChangesMinimum) { TEST(CaptureRecommendationsTest, XgaResolution) { Recommendations expected = kDefaultRecommendations; - expected.video.minimum = Resolution{320, 240, 30.0}; - expected.video.maximum = Resolution{1024, 768, 60.0}; + expected.video.minimum = Resolution{320, 240}; + expected.video.maximum = Dimensions{1024, 768, 60.0}; expected.video.supports_scaling = false; expected.video.bit_rate_limits.maximum = 47185920; Answer answer; @@ -193,8 +194,8 @@ TEST(CaptureRecommendationsTest, XgaResolution) { TEST(CaptureRecommendationsTest, MismatchedDisplayAndAspectRatio) { Recommendations expected = kDefaultRecommendations; - expected.video.minimum = Resolution{150, 200, 30.0}; - expected.video.maximum = Resolution{150, 200, 30.0}; + expected.video.minimum = Resolution{150, 200}; + expected.video.maximum = Dimensions{150, 200, 30.0}; expected.video.supports_scaling = false; expected.video.bit_rate_limits.maximum = 300 * 200 * 30; Answer answer; @@ -204,8 +205,8 @@ TEST(CaptureRecommendationsTest, MismatchedDisplayAndAspectRatio) { TEST(CaptureRecommendationsTest, TinyDisplay) { Recommendations expected = kDefaultRecommendations; - expected.video.minimum = Resolution{300, 200, 30.0}; - expected.video.maximum = Resolution{300, 200, 30.0}; + expected.video.minimum = Resolution{300, 200}; + expected.video.maximum = Dimensions{300, 200, 30.0}; expected.video.supports_scaling = false; expected.video.bit_rate_limits.maximum = 300 * 200 * 30; Answer answer; @@ -225,8 +226,8 @@ TEST(CaptureRecommendationsTest, EmptyConstraints) { TEST(CaptureRecommendationsTest, HandlesHighEnd) { const Recommendations kExpected{ Audio{BitRateLimits{96000, 500000}, milliseconds(6000), 5, 96100, 16000}, - Video{BitRateLimits{600000, 6000000}, Resolution{640, 480, 30}, - Resolution{1920, 1080, 30}, false, milliseconds(6000), 6000000}}; + Video{BitRateLimits{600000, 6000000}, Resolution{640, 480}, + Dimensions{1920, 1080, 30}, false, milliseconds(6000), 6000000}}; Answer answer; answer.constraints = kValidConstraintsHighEnd; EXPECT_EQ(kExpected, GetRecommendations(answer)); @@ -238,8 +239,8 @@ TEST(CaptureRecommendationsTest, HandlesHighEnd) { TEST(CaptureRecommendationsTest, HandlesLowEnd) { const Recommendations kExpected{ Audio{BitRateLimits{32000, 50000}, milliseconds(1000), 2, 22000, 16000}, - Video{BitRateLimits{300000, 1000000}, Resolution{320, 240, 30}, - Resolution{1200, 800, 30}, false, milliseconds(1000), 60000}}; + Video{BitRateLimits{300000, 1000000}, Resolution{320, 240}, + Dimensions{1200, 800, 30}, false, milliseconds(1000), 60000}}; Answer answer; answer.constraints = kValidConstraintsLowEnd; EXPECT_EQ(kExpected, GetRecommendations(answer)); @@ -248,20 +249,20 @@ TEST(CaptureRecommendationsTest, HandlesLowEnd) { TEST(CaptureRecommendationsTest, HandlesTooSmallScreen) { const Recommendations kExpected{ Audio{BitRateLimits{32000, 50000}, milliseconds(1000), 2, 22000, 16000}, - Video{BitRateLimits{300000, 1000000}, Resolution{320, 240, 30}, - Resolution{320, 240, 30}, false, milliseconds(1000), 60000}}; + Video{BitRateLimits{300000, 1000000}, Resolution{320, 240}, + Dimensions{320, 240, 30}, false, milliseconds(1000), 60000}}; Answer answer; answer.constraints = kValidConstraintsLowEnd; answer.constraints->video.max_dimensions = - answer.constraints->video.min_dimensions.value(); + answer.constraints->video.min_resolution.value(); EXPECT_EQ(kExpected, GetRecommendations(answer)); } TEST(CaptureRecommendationsTest, HandlesMinimumSizeScreen) { const Recommendations kExpected{ Audio{BitRateLimits{32000, 50000}, milliseconds(1000), 2, 22000, 16000}, - Video{BitRateLimits{300000, 1000000}, Resolution{320, 240, 30}, - Resolution{320, 240, 30}, false, milliseconds(1000), 60000}}; + Video{BitRateLimits{300000, 1000000}, Resolution{320, 240}, + Dimensions{320, 240, 30}, false, milliseconds(1000), 60000}}; Answer answer; answer.constraints = kValidConstraintsLowEnd; answer.constraints->video.max_dimensions = @@ -272,11 +273,11 @@ TEST(CaptureRecommendationsTest, HandlesMinimumSizeScreen) { TEST(CaptureRecommendationsTest, UsesIntersectionOfDisplayAndConstraints) { const Recommendations kExpected{ Audio{BitRateLimits{96000, 500000}, milliseconds(6000), 5, 96100, 16000}, - Video{BitRateLimits{600000, 6000000}, Resolution{640, 480, 30}, + Video{BitRateLimits{600000, 6000000}, Resolution{640, 480}, // Max resolution should be 1080P, since that's the display // resolution. No reason to capture at 4K, even though the // receiver supports it. - Resolution{1920, 1080, 30}, true, milliseconds(6000), 6000000}}; + Dimensions{1920, 1080, 30}, true, milliseconds(6000), 6000000}}; Answer answer; answer.display = kValidDisplayFhd; answer.constraints = kValidConstraintsHighEnd; diff --git a/cast/streaming/compound_rtcp_parser.h b/cast/streaming/compound_rtcp_parser.h index c74bb3ec..a8bb2e39 100644 --- a/cast/streaming/compound_rtcp_parser.h +++ b/cast/streaming/compound_rtcp_parser.h @@ -37,7 +37,6 @@ class CompoundRtcpParser { class Client { public: Client(); - virtual ~Client(); // Called when a Receiver Reference Time Report has been parsed. virtual void OnReceiverReferenceTimeAdvanced( @@ -70,6 +69,9 @@ class CompoundRtcpParser { // kAllPacketsLost indicates that all the packets are missing for a frame. // The argument's elements are in monotonically increasing order. virtual void OnReceiverIsMissingPackets(std::vector<PacketNack> nacks); + + protected: + virtual ~Client(); }; // |session| and |client| must be non-null and must outlive the diff --git a/cast/streaming/compound_rtcp_parser_fuzzer.cc b/cast/streaming/compound_rtcp_parser_fuzzer.cc index bb3dd179..05994a3a 100644 --- a/cast/streaming/compound_rtcp_parser_fuzzer.cc +++ b/cast/streaming/compound_rtcp_parser_fuzzer.cc @@ -17,6 +17,11 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { constexpr Ssrc kSenderSsrcInSeedCorpus = 1; constexpr Ssrc kReceiverSsrcInSeedCorpus = 2; + class ClientThatIgnoresEverything : public CompoundRtcpParser::Client { + public: + ClientThatIgnoresEverything() = default; + ~ClientThatIgnoresEverything() override = default; + }; // Allocate the RtcpSession and CompoundRtcpParser statically (i.e., one-time // init) to improve the fuzzer's execution rate. This is because RtcpSession // also contains a NtpTimeConverter, which samples the system clock at @@ -26,7 +31,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { #pragma clang diagnostic ignored "-Wexit-time-destructors" static RtcpSession session(kSenderSsrcInSeedCorpus, kReceiverSsrcInSeedCorpus, openscreen::Clock::time_point{}); - static CompoundRtcpParser::Client client_that_ignores_everything; + static ClientThatIgnoresEverything client_that_ignores_everything; static CompoundRtcpParser parser(&session, &client_that_ignores_everything); #pragma clang diagnostic pop diff --git a/cast/streaming/constants.h b/cast/streaming/constants.h index 1075a817..03026620 100644 --- a/cast/streaming/constants.h +++ b/cast/streaming/constants.h @@ -17,6 +17,12 @@ namespace openscreen { namespace cast { +// Mirroring App identifier. +constexpr char kMirroringAppId[] = "0F5096E8"; + +// Mirroring App identifier for audio-only mirroring. +constexpr char kMirroringAudioOnlyAppId[] = "85CDB22F"; + // Default target playout delay. The playout delay is the window of time between // capture from the source until presentation at the receiver. constexpr std::chrono::milliseconds kDefaultTargetPlayoutDelay(400); @@ -58,6 +64,27 @@ constexpr int kMinVideoWidth = 320; // The default frame rate for capture options is 30FPS. constexpr int kDefaultFrameRate = 30; +// The mirroring spec suggests 300kbps as the absolute minimum bitrate. +constexpr int kDefaultVideoMinBitRate = 300 * 1000; + +// Default video max bitrate is based on 1080P @ 30FPS, which can be played back +// at good quality around 10mbps. +constexpr int kDefaultVideoMaxBitRate = 10 * 1000 * 1000; + +// The mirroring control protocol specifies 32kbps as the absolute minimum +// for audio. Depending on the type of audio content (narrowband, fullband, +// etc.) Opus specifically can perform very well at this bitrate. +// See: https://research.google/pubs/pub41650/ +constexpr int kDefaultAudioMinBitRate = 32 * 1000; + +// Opus generally sees little improvement above 192kbps, but some older codecs +// that we may consider supporting improve at up to 256kbps. +constexpr int kDefaultAudioMaxBitRate = 256 * 1000; + +// While generally audio should be captured at the maximum sample rate, 16kHz is +// the recommended absolute minimum. +constexpr int kDefaultAudioMinSampleRate = 16000; + // The default audio sample rate is 48kHz, slightly higher than standard // consumer audio. constexpr int kDefaultAudioSampleRate = 48000; @@ -65,12 +92,26 @@ constexpr int kDefaultAudioSampleRate = 48000; // The default audio number of channels is set to stereo. constexpr int kDefaultAudioChannels = 2; +// Default maximum delay for both audio and video. Used if the sender fails +// to provide any constraints. +constexpr std::chrono::milliseconds kDefaultMaxDelayMs(1500); + +// TODO(issuetracker.google.com/184189100): As part of updating remoting +// OFFER/ANSWER and capabilities exchange, remoting version should be updated +// to 3. +constexpr int kSupportedRemotingVersion = 2; + // Codecs known and understood by cast senders and receivers. Note: receivers // are required to implement the following codecs to be Cast V2 compliant: H264, -// VP8, AAC, Opus. Senders have to implement at least one codec for audio and -// video to start a session. -enum class AudioCodec { kAac, kOpus }; -enum class VideoCodec { kH264, kVp8, kHevc, kVp9 }; +// VP8, AAC, Opus. Senders have to implement at least one codec from this +// list for audio or video to start a session. +// |kNotSpecified| is used in remoting to indicate that the stream is being +// remoted and is not specified as part of the OFFER message (indicated as +// "REMOTE_AUDIO" or "REMOTE_VIDEO"). +enum class AudioCodec { kAac, kOpus, kNotSpecified }; +enum class VideoCodec { kH264, kVp8, kHevc, kNotSpecified, kVp9, kAv1 }; + +enum class CastMode : uint8_t { kMirroring, kRemoting }; } // namespace cast } // namespace openscreen diff --git a/cast/streaming/message_fields.cc b/cast/streaming/message_fields.cc index f199ab8d..4411c80d 100644 --- a/cast/streaming/message_fields.cc +++ b/cast/streaming/message_fields.cc @@ -14,14 +14,18 @@ namespace openscreen { namespace cast { namespace { -constexpr EnumNameTable<AudioCodec, 2> kAudioCodecNames{ - {{"aac", AudioCodec::kAac}, {"opus", AudioCodec::kOpus}}}; +constexpr EnumNameTable<AudioCodec, 3> kAudioCodecNames{ + {{"aac", AudioCodec::kAac}, + {"opus", AudioCodec::kOpus}, + {"REMOTE_AUDIO", AudioCodec::kNotSpecified}}}; -constexpr EnumNameTable<VideoCodec, 4> kVideoCodecNames{ +constexpr EnumNameTable<VideoCodec, 6> kVideoCodecNames{ {{"h264", VideoCodec::kH264}, {"vp8", VideoCodec::kVp8}, {"hevc", VideoCodec::kHevc}, - {"vp9", VideoCodec::kVp9}}}; + {"REMOTE_VIDEO", VideoCodec::kNotSpecified}, + {"vp9", VideoCodec::kVp9}, + {"av1", VideoCodec::kAv1}}}; } // namespace diff --git a/cast/streaming/message_fields.h b/cast/streaming/message_fields.h index 524a0135..2d1cb969 100644 --- a/cast/streaming/message_fields.h +++ b/cast/streaming/message_fields.h @@ -28,6 +28,7 @@ constexpr char kMessageType[] = "type"; constexpr char kMessageTypeOffer[] = "OFFER"; constexpr char kOfferMessageBody[] = "offer"; constexpr char kSequenceNumber[] = "seqNum"; +constexpr char kCodecName[] = "codecName"; /// ANSWER message fields. constexpr char kMessageTypeAnswer[] = "ANSWER"; diff --git a/cast/streaming/offer_messages.cc b/cast/streaming/offer_messages.cc index cea500cd..a162f09f 100644 --- a/cast/streaming/offer_messages.cc +++ b/cast/streaming/offer_messages.cc @@ -14,7 +14,6 @@ #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" -#include "cast/streaming/capture_recommendations.h" #include "cast/streaming/constants.h" #include "platform/base/error.h" #include "util/big_endian.h" @@ -34,36 +33,78 @@ constexpr char kAudioSourceType[] = "audio_source"; constexpr char kVideoSourceType[] = "video_source"; constexpr char kStreamType[] = "type"; -ErrorOr<RtpPayloadType> ParseRtpPayloadType(const Json::Value& parent, - const std::string& field) { - auto t = json::ParseInt(parent, field); - if (!t) { - return t.error(); +bool CodecParameterIsValid(VideoCodec codec, + const std::string& codec_parameter) { + if (codec_parameter.empty()) { + return true; + } + switch (codec) { + case VideoCodec::kVp8: + return absl::StartsWith(codec_parameter, "vp08"); + case VideoCodec::kVp9: + return absl::StartsWith(codec_parameter, "vp09"); + case VideoCodec::kAv1: + return absl::StartsWith(codec_parameter, "av01"); + case VideoCodec::kHevc: + return absl::StartsWith(codec_parameter, "hev1"); + case VideoCodec::kH264: + return absl::StartsWith(codec_parameter, "avc1"); + case VideoCodec::kNotSpecified: + return false; + } + OSP_NOTREACHED(); +} + +bool CodecParameterIsValid(AudioCodec codec, + const std::string& codec_parameter) { + if (codec_parameter.empty()) { + return true; + } + switch (codec) { + case AudioCodec::kAac: + return absl::StartsWith(codec_parameter, "mp4a."); + + // Opus doesn't use codec parameters. + case AudioCodec::kOpus: // fallthrough + case AudioCodec::kNotSpecified: + return false; } + OSP_NOTREACHED(); +} + +EnumNameTable<CastMode, 2> kCastModeNames{ + {{"mirroring", CastMode::kMirroring}, {"remoting", CastMode::kRemoting}}}; - uint8_t t_small = t.value(); - if (t_small != t.value() || !IsRtpPayloadType(t_small)) { - return Error(Error::Code::kParameterInvalid, - "Received invalid RTP Payload Type."); +bool TryParseRtpPayloadType(const Json::Value& value, RtpPayloadType* out) { + int t; + if (!json::TryParseInt(value, &t)) { + return false; } - return static_cast<RtpPayloadType>(t_small); + uint8_t t_small = t; + if (t_small != t || !IsRtpPayloadType(t_small)) { + return false; + } + + *out = static_cast<RtpPayloadType>(t_small); + return true; } -ErrorOr<int> ParseRtpTimebase(const Json::Value& parent, - const std::string& field) { - auto error_or_raw = json::ParseString(parent, field); - if (!error_or_raw) { - return error_or_raw.error(); +bool TryParseRtpTimebase(const Json::Value& value, int* out) { + std::string raw_timebase; + if (!json::TryParseString(value, &raw_timebase)) { + return false; } // The spec demands a leading 1, so this isn't really a fraction. - const auto fraction = SimpleFraction::FromString(error_or_raw.value()); + const auto fraction = SimpleFraction::FromString(raw_timebase); if (fraction.is_error() || !fraction.value().is_positive() || - fraction.value().numerator != 1) { - return json::CreateParseError("RTP timebase"); + fraction.value().numerator() != 1) { + return false; } - return fraction.value().denominator; + + *out = fraction.value().denominator(); + return true; } // For a hex byte, the conversion is 4 bits to 1 character, e.g. @@ -71,226 +112,118 @@ ErrorOr<int> ParseRtpTimebase(const Json::Value& parent, constexpr int kHexDigitsPerByte = 2; constexpr int kAesBytesSize = 16; constexpr int kAesStringLength = kAesBytesSize * kHexDigitsPerByte; -ErrorOr<std::array<uint8_t, kAesBytesSize>> ParseAesHexBytes( - const Json::Value& parent, - const std::string& field) { - auto hex_string = json::ParseString(parent, field); - if (!hex_string) { - return hex_string.error(); +bool TryParseAesHexBytes(const Json::Value& value, + std::array<uint8_t, kAesBytesSize>* out) { + std::string hex_string; + if (!json::TryParseString(value, &hex_string)) { + return false; } constexpr int kHexDigitsPerScanField = 16; constexpr int kNumScanFields = kAesStringLength / kHexDigitsPerScanField; uint64_t quads[kNumScanFields]; int chars_scanned; - if (hex_string.value().size() == kAesStringLength && - sscanf(hex_string.value().c_str(), "%16" SCNx64 "%16" SCNx64 "%n", - &quads[0], &quads[1], &chars_scanned) == kNumScanFields && + if (hex_string.size() == kAesStringLength && + sscanf(hex_string.c_str(), "%16" SCNx64 "%16" SCNx64 "%n", &quads[0], + &quads[1], &chars_scanned) == kNumScanFields && chars_scanned == kAesStringLength && - std::none_of(hex_string.value().begin(), hex_string.value().end(), + std::none_of(hex_string.begin(), hex_string.end(), [](char c) { return std::isspace(c); })) { - std::array<uint8_t, kAesBytesSize> bytes; - WriteBigEndian(quads[0], bytes.data()); - WriteBigEndian(quads[1], bytes.data() + 8); - return bytes; - } - return json::CreateParseError("AES hex string bytes"); -} - -ErrorOr<Stream> ParseStream(const Json::Value& value, Stream::Type type) { - auto index = json::ParseInt(value, "index"); - if (!index) { - return index.error(); - } - // If channel is omitted, the default value is used later. - auto channels = json::ParseInt(value, "channels"); - if (channels.is_value() && channels.value() <= 0) { - return json::CreateParameterError("channel"); - } - auto rtp_profile = json::ParseString(value, "rtpProfile"); - if (!rtp_profile) { - return rtp_profile.error(); - } - auto rtp_payload_type = ParseRtpPayloadType(value, "rtpPayloadType"); - if (!rtp_payload_type) { - return rtp_payload_type.error(); - } - auto ssrc = json::ParseUint(value, "ssrc"); - if (!ssrc) { - return ssrc.error(); - } - auto aes_key = ParseAesHexBytes(value, "aesKey"); - auto aes_iv_mask = ParseAesHexBytes(value, "aesIvMask"); - if (!aes_key || !aes_iv_mask) { - return Error(Error::Code::kUnencryptedOffer, - "Offer stream must have both a valid aesKey and aesIvMask"); - } - auto rtp_timebase = ParseRtpTimebase(value, "timeBase"); - if (!rtp_timebase) { - return rtp_timebase.error(); - } - if (rtp_timebase.value() < - std::min(capture_recommendations::kDefaultAudioMinSampleRate, - kRtpVideoTimebase) || - rtp_timebase.value() > kRtpVideoTimebase) { - return json::CreateParameterError("rtp_timebase (sample rate)"); + WriteBigEndian(quads[0], out->data()); + WriteBigEndian(quads[1], out->data() + 8); + return true; } - auto target_delay = json::ParseInt(value, "targetDelay"); - std::chrono::milliseconds target_delay_ms = kDefaultTargetPlayoutDelay; - if (target_delay) { - auto d = std::chrono::milliseconds(target_delay.value()); - if (kMinTargetPlayoutDelay <= d && d <= kMaxTargetPlayoutDelay) { - target_delay_ms = d; - } - } - - auto receiver_rtcp_event_log = json::ParseBool(value, "receiverRtcpEventLog"); - auto receiver_rtcp_dscp = json::ParseString(value, "receiverRtcpDscp"); - return Stream{index.value(), - type, - channels.value(type == Stream::Type::kAudioSource - ? kDefaultNumAudioChannels - : kDefaultNumVideoChannels), - rtp_payload_type.value(), - ssrc.value(), - target_delay_ms, - aes_key.value(), - aes_iv_mask.value(), - receiver_rtcp_event_log.value({}), - receiver_rtcp_dscp.value({}), - rtp_timebase.value()}; + return false; } -ErrorOr<AudioStream> ParseAudioStream(const Json::Value& value) { - auto stream = ParseStream(value, Stream::Type::kAudioSource); - if (!stream) { - return stream.error(); - } - auto bit_rate = json::ParseInt(value, "bitRate"); - if (!bit_rate) { - return bit_rate.error(); - } - - auto codec_name = json::ParseString(value, "codecName"); - if (!codec_name) { - return codec_name.error(); - } - ErrorOr<AudioCodec> codec = StringToAudioCodec(codec_name.value()); - if (!codec) { - return Error(Error::Code::kUnknownCodec, - "Codec is not known, can't use stream"); - } - - // A bit rate of 0 is valid for some codec types, so we don't enforce here. - if (bit_rate.value() < 0) { - return json::CreateParameterError("bit rate"); +absl::string_view ToString(Stream::Type type) { + switch (type) { + case Stream::Type::kAudioSource: + return kAudioSourceType; + case Stream::Type::kVideoSource: + return kVideoSourceType; + default: { + OSP_NOTREACHED(); + } } - return AudioStream{stream.value(), codec.value(), bit_rate.value()}; } -ErrorOr<Resolution> ParseResolution(const Json::Value& value) { - auto width = json::ParseInt(value, "width"); - if (!width) { - return width.error(); - } - auto height = json::ParseInt(value, "height"); - if (!height) { - return height.error(); - } - if (width.value() <= 0 || height.value() <= 0) { - return json::CreateParameterError("resolution"); - } - return Resolution{width.value(), height.value()}; -} +bool TryParseResolutions(const Json::Value& value, + std::vector<Resolution>* out) { + out->clear(); -ErrorOr<std::vector<Resolution>> ParseResolutions(const Json::Value& parent, - const std::string& field) { - std::vector<Resolution> resolutions; // Some legacy senders don't provide resolutions, so just return empty. - const Json::Value& value = parent[field]; if (!value.isArray() || value.empty()) { - return resolutions; + return false; } for (Json::ArrayIndex i = 0; i < value.size(); ++i) { - auto r = ParseResolution(value[i]); - if (!r) { - return r.error(); + Resolution resolution; + if (!Resolution::TryParse(value[i], &resolution)) { + out->clear(); + return false; } - resolutions.push_back(r.value()); + out->push_back(std::move(resolution)); } - return resolutions; + return true; } -ErrorOr<VideoStream> ParseVideoStream(const Json::Value& value) { - auto stream = ParseStream(value, Stream::Type::kVideoSource); - if (!stream) { - return stream.error(); +} // namespace + +Error Stream::TryParse(const Json::Value& value, + Stream::Type type, + Stream* out) { + out->type = type; + + if (!json::TryParseInt(value["index"], &out->index) || + !json::TryParseUint(value["ssrc"], &out->ssrc) || + !TryParseRtpPayloadType(value["rtpPayloadType"], + &out->rtp_payload_type) || + !TryParseRtpTimebase(value["timeBase"], &out->rtp_timebase)) { + return Error(Error::Code::kJsonParseError, + "Offer stream has missing or invalid mandatory field"); } - auto codec_name = json::ParseString(value, "codecName"); - if (!codec_name) { - return codec_name.error(); + + if (!json::TryParseInt(value["channels"], &out->channels)) { + out->channels = out->type == Stream::Type::kAudioSource + ? kDefaultNumAudioChannels + : kDefaultNumVideoChannels; + } else if (out->channels <= 0) { + return Error(Error::Code::kJsonParseError, "Invalid channel count"); } - ErrorOr<VideoCodec> codec = StringToVideoCodec(codec_name.value()); - if (!codec) { - return Error(Error::Code::kUnknownCodec, - "Codec is not known, can't use stream"); + + if (!TryParseAesHexBytes(value["aesKey"], &out->aes_key) || + !TryParseAesHexBytes(value["aesIvMask"], &out->aes_iv_mask)) { + return Error(Error::Code::kUnencryptedOffer, + "Offer stream must have both a valid aesKey and aesIvMask"); } - auto resolutions = ParseResolutions(value, "resolutions"); - if (!resolutions) { - return resolutions.error(); + if (out->rtp_timebase < + std::min(kDefaultAudioMinSampleRate, kRtpVideoTimebase) || + out->rtp_timebase > kRtpVideoTimebase) { + return Error(Error::Code::kJsonParseError, "rtp_timebase (sample rate)"); } - auto raw_max_frame_rate = json::ParseString(value, "maxFrameRate"); - SimpleFraction max_frame_rate{kDefaultMaxFrameRate, 1}; - if (raw_max_frame_rate.is_value()) { - auto parsed = SimpleFraction::FromString(raw_max_frame_rate.value()); - if (parsed.is_value() && parsed.value().is_positive()) { - max_frame_rate = parsed.value(); + out->target_delay = kDefaultTargetPlayoutDelay; + int target_delay; + if (json::TryParseInt(value["targetDelay"], &target_delay)) { + auto d = std::chrono::milliseconds(target_delay); + if (kMinTargetPlayoutDelay <= d && d <= kMaxTargetPlayoutDelay) { + out->target_delay = d; } } - auto profile = json::ParseString(value, "profile"); - auto protection = json::ParseString(value, "protection"); - auto max_bit_rate = json::ParseInt(value, "maxBitRate"); - auto level = json::ParseString(value, "level"); - auto error_recovery_mode = json::ParseString(value, "errorRecoveryMode"); - return VideoStream{stream.value(), - codec.value(), - max_frame_rate, - max_bit_rate.value(4 << 20), - protection.value({}), - profile.value({}), - level.value({}), - resolutions.value(), - error_recovery_mode.value({})}; -} + json::TryParseBool(value["receiverRtcpEventLog"], + &out->receiver_rtcp_event_log); + json::TryParseString(value["receiverRtcpDscp"], &out->receiver_rtcp_dscp); + json::TryParseString(value["codecParameter"], &out->codec_parameter); -absl::string_view ToString(Stream::Type type) { - switch (type) { - case Stream::Type::kAudioSource: - return kAudioSourceType; - case Stream::Type::kVideoSource: - return kVideoSourceType; - default: { - OSP_NOTREACHED(); - } - } + return Error::None(); } -EnumNameTable<CastMode, 2> kCastModeNames{ - {{"mirroring", CastMode::kMirroring}, {"remoting", CastMode::kRemoting}}}; - -} // namespace - -ErrorOr<Json::Value> Stream::ToJson() const { - if (channels < 1 || index < 0 || target_delay.count() <= 0 || - target_delay.count() > std::numeric_limits<int>::max() || - rtp_timebase < 1) { - return json::CreateParameterError("Stream"); - } +Json::Value Stream::ToJson() const { + OSP_DCHECK(IsValid()); Json::Value root; root["index"] = index; @@ -304,152 +237,212 @@ ErrorOr<Json::Value> Stream::ToJson() const { "this code assumes Ssrc fits in a Json::UInt"); root["ssrc"] = static_cast<Json::UInt>(ssrc); root["targetDelay"] = static_cast<int>(target_delay.count()); - root["aesKey"] = HexEncode(aes_key); - root["aesIvMask"] = HexEncode(aes_iv_mask); + root["aesKey"] = HexEncode(aes_key.data(), aes_key.size()); + root["aesIvMask"] = HexEncode(aes_iv_mask.data(), aes_iv_mask.size()); root["receiverRtcpEventLog"] = receiver_rtcp_event_log; root["receiverRtcpDscp"] = receiver_rtcp_dscp; root["timeBase"] = "1/" + std::to_string(rtp_timebase); + root["codecParameter"] = codec_parameter; return root; } -ErrorOr<Json::Value> AudioStream::ToJson() const { - // A bit rate of 0 is valid for some codec types, so we don't enforce here. - if (bit_rate < 0) { - return json::CreateParameterError("AudioStream"); +bool Stream::IsValid() const { + return channels >= 1 && index >= 0 && target_delay.count() > 0 && + target_delay.count() <= std::numeric_limits<int>::max() && + rtp_timebase >= 1; +} + +Error AudioStream::TryParse(const Json::Value& value, AudioStream* out) { + Error error = + Stream::TryParse(value, Stream::Type::kAudioSource, &out->stream); + if (!error.ok()) { + return error; } - auto error_or_stream = stream.ToJson(); - if (error_or_stream.is_error()) { - return error_or_stream; + std::string codec_name; + if (!json::TryParseInt(value["bitRate"], &out->bit_rate) || + out->bit_rate < 0 || + !json::TryParseString(value[kCodecName], &codec_name)) { + return Error(Error::Code::kJsonParseError, "Invalid audio stream field"); } + ErrorOr<AudioCodec> codec = StringToAudioCodec(codec_name); + if (!codec) { + return Error(Error::Code::kUnknownCodec, + "Codec is not known, can't use stream"); + } + out->codec = codec.value(); + if (!CodecParameterIsValid(codec.value(), out->stream.codec_parameter)) { + return Error(Error::Code::kInvalidCodecParameter, + StringPrintf("Invalid audio codec parameter (%s for codec %s)", + out->stream.codec_parameter.c_str(), + CodecToString(codec.value()))); + } + return Error::None(); +} + +Json::Value AudioStream::ToJson() const { + OSP_DCHECK(IsValid()); + + Json::Value out = stream.ToJson(); + out[kCodecName] = CodecToString(codec); + out["bitRate"] = bit_rate; + return out; +} - error_or_stream.value()["codecName"] = CodecToString(codec); - error_or_stream.value()["bitRate"] = bit_rate; - return error_or_stream; +bool AudioStream::IsValid() const { + return bit_rate >= 0 && stream.IsValid(); } -ErrorOr<Json::Value> Resolution::ToJson() const { - if (width <= 0 || height <= 0) { - return json::CreateParameterError("Resolution"); +Error VideoStream::TryParse(const Json::Value& value, VideoStream* out) { + Error error = + Stream::TryParse(value, Stream::Type::kVideoSource, &out->stream); + if (!error.ok()) { + return error; } - Json::Value root; - root["width"] = width; - root["height"] = height; - return root; -} + std::string codec_name; + if (!json::TryParseString(value[kCodecName], &codec_name)) { + return Error(Error::Code::kJsonParseError, "Video stream missing codec"); + } + ErrorOr<VideoCodec> codec = StringToVideoCodec(codec_name); + if (!codec) { + return Error(Error::Code::kUnknownCodec, + "Codec is not known, can't use stream"); + } + out->codec = codec.value(); + if (!CodecParameterIsValid(codec.value(), out->stream.codec_parameter)) { + return Error(Error::Code::kInvalidCodecParameter, + StringPrintf("Invalid video codec parameter (%s for codec %s)", + out->stream.codec_parameter.c_str(), + CodecToString(codec.value()))); + } -ErrorOr<Json::Value> VideoStream::ToJson() const { - if (max_bit_rate <= 0 || !max_frame_rate.is_positive()) { - return json::CreateParameterError("VideoStream"); + out->max_frame_rate = SimpleFraction{kDefaultMaxFrameRate, 1}; + std::string raw_max_frame_rate; + if (json::TryParseString(value["maxFrameRate"], &raw_max_frame_rate)) { + auto parsed = SimpleFraction::FromString(raw_max_frame_rate); + if (parsed.is_value() && parsed.value().is_positive()) { + out->max_frame_rate = parsed.value(); + } } - auto error_or_stream = stream.ToJson(); - if (error_or_stream.is_error()) { - return error_or_stream; + TryParseResolutions(value["resolutions"], &out->resolutions); + json::TryParseString(value["profile"], &out->profile); + json::TryParseString(value["protection"], &out->protection); + json::TryParseString(value["level"], &out->level); + json::TryParseString(value["errorRecoveryMode"], &out->error_recovery_mode); + if (!json::TryParseInt(value["maxBitRate"], &out->max_bit_rate)) { + out->max_bit_rate = 4 << 20; } - auto& stream = error_or_stream.value(); - stream["codecName"] = CodecToString(codec); - stream["maxFrameRate"] = max_frame_rate.ToString(); - stream["maxBitRate"] = max_bit_rate; - stream["protection"] = protection; - stream["profile"] = profile; - stream["level"] = level; - stream["errorRecoveryMode"] = error_recovery_mode; + return Error::None(); +} + +Json::Value VideoStream::ToJson() const { + OSP_DCHECK(IsValid()); + + Json::Value out = stream.ToJson(); + out["codecName"] = CodecToString(codec); + out["maxFrameRate"] = max_frame_rate.ToString(); + out["maxBitRate"] = max_bit_rate; + out["protection"] = protection; + out["profile"] = profile; + out["level"] = level; + out["errorRecoveryMode"] = error_recovery_mode; Json::Value rs; for (auto resolution : resolutions) { - auto eoj = resolution.ToJson(); - if (eoj.is_error()) { - return eoj; - } - rs.append(eoj.value()); + rs.append(resolution.ToJson()); } - stream["resolutions"] = std::move(rs); - return error_or_stream; + out["resolutions"] = std::move(rs); + return out; +} + +bool VideoStream::IsValid() const { + return max_bit_rate > 0 && max_frame_rate.is_positive(); } // static ErrorOr<Offer> Offer::Parse(const Json::Value& root) { + Offer out; + Error error = TryParse(root, &out); + return error.ok() ? ErrorOr<Offer>(std::move(out)) + : ErrorOr<Offer>(std::move(error)); +} + +// static +Error Offer::TryParse(const Json::Value& root, Offer* out) { if (!root.isObject()) { - return json::CreateParseError("null offer"); + return Error(Error::Code::kJsonParseError, "null offer"); } - ErrorOr<CastMode> cast_mode = + const ErrorOr<CastMode> cast_mode = GetEnum(kCastModeNames, root["castMode"].asString()); - const ErrorOr<bool> get_status = json::ParseBool(root, "receiverGetStatus"); - Json::Value supported_streams = root[kSupportedStreams]; if (!supported_streams.isArray()) { - return json::CreateParseError("supported streams in offer"); + return Error(Error::Code::kJsonParseError, "supported streams in offer"); } std::vector<AudioStream> audio_streams; std::vector<VideoStream> video_streams; for (Json::ArrayIndex i = 0; i < supported_streams.size(); ++i) { const Json::Value& fields = supported_streams[i]; - auto type = json::ParseString(fields, kStreamType); - if (!type) { - return type.error(); + std::string type; + if (!json::TryParseString(fields[kStreamType], &type)) { + return Error(Error::Code::kJsonParseError, "Missing stream type"); } - if (type.value() == kAudioSourceType) { - auto stream = ParseAudioStream(fields); - if (!stream) { - if (stream.error().code() == Error::Code::kUnknownCodec) { - OSP_DVLOG << "Dropping audio stream due to unknown codec: " - << stream.error(); - continue; - } else { - return stream.error(); - } + Error error; + if (type == kAudioSourceType) { + AudioStream stream; + error = AudioStream::TryParse(fields, &stream); + if (error.ok()) { + audio_streams.push_back(std::move(stream)); + } + } else if (type == kVideoSourceType) { + VideoStream stream; + error = VideoStream::TryParse(fields, &stream); + if (error.ok()) { + video_streams.push_back(std::move(stream)); } - audio_streams.push_back(std::move(stream.value())); - } else if (type.value() == kVideoSourceType) { - auto stream = ParseVideoStream(fields); - if (!stream) { - if (stream.error().code() == Error::Code::kUnknownCodec) { - OSP_DVLOG << "Dropping video stream due to unknown codec: " - << stream.error(); - continue; - } else { - return stream.error(); - } + } + + if (!error.ok()) { + if (error.code() == Error::Code::kUnknownCodec) { + OSP_VLOG << "Dropping audio stream due to unknown codec: " << error; + continue; + } else { + return error; } - video_streams.push_back(std::move(stream.value())); } } - return Offer{cast_mode.value(CastMode::kMirroring), get_status.value({}), - std::move(audio_streams), std::move(video_streams)}; + *out = Offer{cast_mode.value(CastMode::kMirroring), std::move(audio_streams), + std::move(video_streams)}; + return Error::None(); } -ErrorOr<Json::Value> Offer::ToJson() const { +Json::Value Offer::ToJson() const { + OSP_DCHECK(IsValid()); Json::Value root; - root["castMode"] = GetEnumName(kCastModeNames, cast_mode).value(); - root["receiverGetStatus"] = supports_wifi_status_reporting; - Json::Value streams; - for (auto& as : audio_streams) { - auto eoj = as.ToJson(); - if (eoj.is_error()) { - return eoj; - } - streams.append(eoj.value()); + for (auto& stream : audio_streams) { + streams.append(stream.ToJson()); } - for (auto& vs : video_streams) { - auto eoj = vs.ToJson(); - if (eoj.is_error()) { - return eoj; - } - streams.append(eoj.value()); + for (auto& stream : video_streams) { + streams.append(stream.ToJson()); } root[kSupportedStreams] = std::move(streams); return root; } +bool Offer::IsValid() const { + return std::all_of(audio_streams.begin(), audio_streams.end(), + [](const AudioStream& a) { return a.IsValid(); }) && + std::all_of(video_streams.begin(), video_streams.end(), + [](const VideoStream& v) { return v.IsValid(); }); +} } // namespace cast } // namespace openscreen diff --git a/cast/streaming/offer_messages.h b/cast/streaming/offer_messages.h index f62c156d..765bda2a 100644 --- a/cast/streaming/offer_messages.h +++ b/cast/streaming/offer_messages.h @@ -12,6 +12,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "cast/streaming/message_fields.h" +#include "cast/streaming/resolution.h" #include "cast/streaming/rtp_defines.h" #include "cast/streaming/session_config.h" #include "json/value.h" @@ -45,7 +46,11 @@ constexpr int kDefaultNumAudioChannels = 2; struct Stream { enum class Type : uint8_t { kAudioSource, kVideoSource }; - ErrorOr<Json::Value> ToJson() const; + static Error TryParse(const Json::Value& root, + Stream::Type type, + Stream* out); + Json::Value ToJson() const; + bool IsValid() const; int index = 0; Type type = {}; @@ -60,52 +65,52 @@ struct Stream { // must be converted to a 16 digit byte array. std::array<uint8_t, 16> aes_key = {}; std::array<uint8_t, 16> aes_iv_mask = {}; - bool receiver_rtcp_event_log = {}; - std::string receiver_rtcp_dscp = {}; + bool receiver_rtcp_event_log = false; + std::string receiver_rtcp_dscp; int rtp_timebase = 0; + + // The codec parameter field honors the format laid out in RFC 6381: + // https://datatracker.ietf.org/doc/html/rfc6381. + std::string codec_parameter; }; struct AudioStream { - ErrorOr<Json::Value> ToJson() const; + static Error TryParse(const Json::Value& root, AudioStream* out); + Json::Value ToJson() const; + bool IsValid() const; - Stream stream = {}; - AudioCodec codec; + Stream stream; + AudioCodec codec = AudioCodec::kNotSpecified; int bit_rate = 0; }; -struct Resolution { - ErrorOr<Json::Value> ToJson() const; - - int width = 0; - int height = 0; -}; struct VideoStream { - ErrorOr<Json::Value> ToJson() const; + static Error TryParse(const Json::Value& root, VideoStream* out); + Json::Value ToJson() const; + bool IsValid() const; - Stream stream = {}; - VideoCodec codec; + Stream stream; + VideoCodec codec = VideoCodec::kNotSpecified; SimpleFraction max_frame_rate; int max_bit_rate = 0; - std::string protection = {}; - std::string profile = {}; - std::string level = {}; - std::vector<Resolution> resolutions = {}; - std::string error_recovery_mode = {}; + std::string protection; + std::string profile; + std::string level; + std::vector<Resolution> resolutions; + std::string error_recovery_mode; }; -enum class CastMode : uint8_t { kMirroring, kRemoting }; - struct Offer { + // TODO(jophba): remove deprecated declaration in a separate patch. static ErrorOr<Offer> Parse(const Json::Value& root); - ErrorOr<Json::Value> ToJson() const; + static Error TryParse(const Json::Value& root, Offer* out); + Json::Value ToJson() const; + bool IsValid() const; CastMode cast_mode = CastMode::kMirroring; - // This field is poorly named in the spec (receiverGetStatus), so we use - // a more descriptive name here. - bool supports_wifi_status_reporting = {}; - std::vector<AudioStream> audio_streams = {}; - std::vector<VideoStream> video_streams = {}; + std::vector<AudioStream> audio_streams; + std::vector<VideoStream> video_streams; }; } // namespace cast diff --git a/cast/streaming/offer_messages_unittest.cc b/cast/streaming/offer_messages_unittest.cc index a2117f67..62685e4d 100644 --- a/cast/streaming/offer_messages_unittest.cc +++ b/cast/streaming/offer_messages_unittest.cc @@ -21,7 +21,6 @@ namespace { constexpr char kValidOffer[] = R"({ "castMode": "mirroring", - "receiverGetStatus": true, "supportedStreams": [ { "index": 0, @@ -82,6 +81,22 @@ constexpr char kValidOffer[] = R"({ "channels": 2, "aesKey": "51027e4e2347cbcb49d57ef10177aebc", "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1" + }, + { + "index": 3, + "type": "video_source", + "codecName": "av1", + "rtpProfile": "cast", + "rtpPayloadType": 104, + "ssrc": 19088744, + "maxFrameRate": "30000/1001", + "targetDelay": 1000, + "timeBase": "1/90000", + "maxBitRate": 5000000, + "profile": "main", + "level": "5", + "aesKey": "bbf109bf84513b456b13a184453b66ce", + "aesIvMask": "edaf9e4536e2b66191f560d9c04b2a69" } ] })"; @@ -91,24 +106,26 @@ void ExpectFailureOnParse( absl::optional<Error::Code> expected = absl::nullopt) { ErrorOr<Json::Value> root = json::Parse(body); ASSERT_TRUE(root.is_value()) << root.error(); - ErrorOr<Offer> error_or_offer = Offer::Parse(std::move(root.value())); - EXPECT_TRUE(error_or_offer.is_error()); + + Offer offer; + Error error = Offer::TryParse(std::move(root.value()), &offer); + EXPECT_FALSE(error.ok()); if (expected) { - EXPECT_EQ(expected, error_or_offer.error().code()); + EXPECT_EQ(expected, error.code()); } } void ExpectEqualsValidOffer(const Offer& offer) { EXPECT_EQ(CastMode::kMirroring, offer.cast_mode); - EXPECT_EQ(true, offer.supports_wifi_status_reporting); // Verify list of video streams. - EXPECT_EQ(2u, offer.video_streams.size()); + EXPECT_EQ(3u, offer.video_streams.size()); const auto& video_streams = offer.video_streams; const bool flipped = video_streams[0].stream.index != 0; - const VideoStream& vs_one = flipped ? video_streams[1] : video_streams[0]; - const VideoStream& vs_two = flipped ? video_streams[0] : video_streams[1]; + const VideoStream& vs_one = flipped ? video_streams[2] : video_streams[0]; + const VideoStream& vs_two = video_streams[1]; + const VideoStream& vs_three = flipped ? video_streams[0] : video_streams[2]; EXPECT_EQ(0, vs_one.stream.index); EXPECT_EQ(1, vs_one.stream.channels); @@ -163,6 +180,27 @@ void ExpectEqualsValidOffer(const Offer& offer) { const auto& resolutions_two = vs_two.resolutions; EXPECT_EQ(0u, resolutions_two.size()); + EXPECT_EQ(3, vs_three.stream.index); + EXPECT_EQ(1, vs_three.stream.channels); + EXPECT_EQ(Stream::Type::kVideoSource, vs_three.stream.type); + EXPECT_EQ(VideoCodec::kAv1, vs_three.codec); + EXPECT_EQ(RtpPayloadType::kVideoAv1, vs_three.stream.rtp_payload_type); + EXPECT_EQ(19088744u, vs_three.stream.ssrc); + EXPECT_EQ((SimpleFraction{30000, 1001}), vs_three.max_frame_rate); + EXPECT_EQ(90000, vs_three.stream.rtp_timebase); + EXPECT_EQ(5000000, vs_three.max_bit_rate); + EXPECT_EQ("main", vs_three.profile); + EXPECT_EQ("5", vs_three.level); + EXPECT_THAT(vs_three.stream.aes_key, + ElementsAre(0xbb, 0xf1, 0x09, 0xbf, 0x84, 0x51, 0x3b, 0x45, 0x6b, + 0x13, 0xa1, 0x84, 0x45, 0x3b, 0x66, 0xce)); + EXPECT_THAT(vs_three.stream.aes_iv_mask, + ElementsAre(0xed, 0xaf, 0x9e, 0x45, 0x36, 0xe2, 0xb6, 0x61, 0x91, + 0xf5, 0x60, 0xd9, 0xc0, 0x4b, 0x2a, 0x69)); + + const auto& resolutions_three = vs_three.resolutions; + EXPECT_EQ(0u, resolutions_three.size()); + // Verify list of audio streams. EXPECT_EQ(1u, offer.audio_streams.size()); const AudioStream& as = offer.audio_streams[0]; @@ -202,7 +240,9 @@ TEST(OfferTest, CanParseValidButStreamlessOffer) { "supportedStreams": [] })"); ASSERT_TRUE(root.is_value()) << root.error(); - EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value()); + + Offer offer; + EXPECT_TRUE(Offer::TryParse(std::move(root.value()), &offer).ok()); } TEST(OfferTest, ErrorOnMissingAudioStreamMandatoryField) { @@ -251,7 +291,8 @@ TEST(OfferTest, CanParseValidButMinimalAudioOffer) { }] })"); ASSERT_TRUE(root.is_value()); - EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value()); + Offer offer; + EXPECT_TRUE(Offer::TryParse(std::move(root.value()), &offer).ok()); } TEST(OfferTest, CanParseValidZeroBitRateAudioOffer) { @@ -272,8 +313,8 @@ TEST(OfferTest, CanParseValidZeroBitRateAudioOffer) { }] })"); ASSERT_TRUE(root.is_value()) << root.error(); - const auto offer = Offer::Parse(std::move(root.value())); - EXPECT_TRUE(offer.is_value()) << offer.error(); + Offer offer; + EXPECT_TRUE(Offer::TryParse(std::move(root.value()), &offer).ok()); } TEST(OfferTest, ErrorOnInvalidRtpTimebase) { @@ -422,6 +463,80 @@ TEST(OfferTest, ErrorOnMissingVideoStreamMandatoryField) { })"); } +TEST(OfferTest, ValidatesCodecParameterFormat) { + ExpectFailureOnParse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "audio_source", + "codecName": "aac", + "codecParameter": "vp08.123.332", + "rtpProfile": "cast", + "rtpPayloadType": 96, + "ssrc": 19088743, + "bitRate": 124000, + "timeBase": "1/10000000", + "channels": 2, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc", + "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1" + }] + })"); + + ExpectFailureOnParse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "video_source", + "codecName": "vp8", + "codecParameter": "vp09.11.23", + "rtpProfile": "cast", + "rtpPayloadType": 100, + "ssrc": 19088743, + "timeBase": "1/48000", + "resolutions": [], + "maxBitRate": 10000, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc" + }] + })"); + + const ErrorOr<Json::Value> audio_root = json::Parse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "audio_source", + "codecName": "aac", + "codecParameter": "mp4a.12", + "rtpProfile": "cast", + "rtpPayloadType": 96, + "ssrc": 19088743, + "bitRate": 124000, + "timeBase": "1/10000000", + "channels": 2, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc", + "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1" + }] + })"); + ASSERT_TRUE(audio_root.is_value()) << audio_root.error(); + + const ErrorOr<Json::Value> video_root = json::Parse(R"({ + "castMode": "mirroring", + "supportedStreams": [{ + "index": 2, + "type": "video_source", + "codecName": "vp9", + "codecParameter": "vp09.11.23", + "rtpProfile": "cast", + "rtpPayloadType": 100, + "ssrc": 19088743, + "timeBase": "1/48000", + "resolutions": [], + "maxBitRate": 10000, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc" + }] + })"); + ASSERT_TRUE(video_root.is_value()) << video_root.error(); +} + TEST(OfferTest, CanParseValidButMinimalVideoOffer) { ErrorOr<Json::Value> root = json::Parse(R"({ "castMode": "mirroring", @@ -441,62 +556,64 @@ TEST(OfferTest, CanParseValidButMinimalVideoOffer) { })"); ASSERT_TRUE(root.is_value()); - EXPECT_TRUE(Offer::Parse(std::move(root.value())).is_value()); + Offer offer; + EXPECT_TRUE(Offer::TryParse(std::move(root.value()), &offer).ok()); } TEST(OfferTest, CanParseValidOffer) { ErrorOr<Json::Value> root = json::Parse(kValidOffer); ASSERT_TRUE(root.is_value()); - ErrorOr<Offer> offer = Offer::Parse(std::move(root.value())); + Offer offer; + EXPECT_TRUE(Offer::TryParse(std::move(root.value()), &offer).ok()); - ExpectEqualsValidOffer(offer.value()); + ExpectEqualsValidOffer(offer); } TEST(OfferTest, ParseAndToJsonResultsInSameOffer) { ErrorOr<Json::Value> root = json::Parse(kValidOffer); ASSERT_TRUE(root.is_value()); - ErrorOr<Offer> offer = Offer::Parse(std::move(root.value())); - - ExpectEqualsValidOffer(offer.value()); + Offer offer; + EXPECT_TRUE(Offer::TryParse(std::move(root.value()), &offer).ok()); + ExpectEqualsValidOffer(offer); - auto eoj = offer.value().ToJson(); - EXPECT_TRUE(eoj.is_value()) << eoj.error(); - ErrorOr<Offer> reparsed_offer = Offer::Parse(std::move(eoj.value())); - ExpectEqualsValidOffer(reparsed_offer.value()); + Offer reparsed_offer; + EXPECT_TRUE(Offer::TryParse(std::move(root.value()), &reparsed_offer).ok()); + ExpectEqualsValidOffer(reparsed_offer); } // We don't want to enforce that a given offer must have both audio and // video, so we don't assert on either. -TEST(OfferTest, ToJsonSucceedsWithMissingStreams) { +TEST(OfferTest, IsValidWithMissingStreams) { ErrorOr<Json::Value> root = json::Parse(kValidOffer); ASSERT_TRUE(root.is_value()); - ErrorOr<Offer> offer = Offer::Parse(std::move(root.value())); - ExpectEqualsValidOffer(offer.value()); - const Offer valid_offer = std::move(offer.value()); + Offer offer; + EXPECT_TRUE(Offer::TryParse(std::move(root.value()), &offer).ok()); + ExpectEqualsValidOffer(offer); + const Offer valid_offer = std::move(offer); Offer missing_audio_streams = valid_offer; missing_audio_streams.audio_streams.clear(); - EXPECT_TRUE(missing_audio_streams.ToJson().is_value()); + EXPECT_TRUE(missing_audio_streams.IsValid()); Offer missing_video_streams = valid_offer; missing_video_streams.audio_streams.clear(); - EXPECT_TRUE(missing_video_streams.ToJson().is_value()); + EXPECT_TRUE(missing_video_streams.IsValid()); } -TEST(OfferTest, ToJsonFailsWithInvalidStreams) { +TEST(OfferTest, InvalidIfInvalidStreams) { ErrorOr<Json::Value> root = json::Parse(kValidOffer); ASSERT_TRUE(root.is_value()); - ErrorOr<Offer> offer = Offer::Parse(std::move(root.value())); - ExpectEqualsValidOffer(offer.value()); - const Offer valid_offer = std::move(offer.value()); + Offer offer; + EXPECT_TRUE(Offer::TryParse(std::move(root.value()), &offer).ok()); + ExpectEqualsValidOffer(offer); - Offer video_stream_invalid = valid_offer; - video_stream_invalid.video_streams[0].max_frame_rate.denominator = 0; - EXPECT_TRUE(video_stream_invalid.ToJson().is_error()); + Offer video_stream_invalid = offer; + video_stream_invalid.video_streams[0].max_frame_rate = SimpleFraction{1, 0}; + EXPECT_FALSE(video_stream_invalid.IsValid()); - Offer audio_stream_invalid = valid_offer; + Offer audio_stream_invalid = offer; video_stream_invalid.audio_streams[0].bit_rate = 0; - EXPECT_TRUE(video_stream_invalid.ToJson().is_error()); + EXPECT_FALSE(video_stream_invalid.IsValid()); } TEST(OfferTest, FailsIfUnencrypted) { diff --git a/cast/streaming/receiver.cc b/cast/streaming/receiver.cc index 0d3358b4..d08c181c 100644 --- a/cast/streaming/receiver.cc +++ b/cast/streaming/receiver.cc @@ -14,6 +14,7 @@ #include "util/chrono_helpers.h" #include "util/osp_logging.h" #include "util/std_util.h" +#include "util/trace_logging.h" namespace openscreen { namespace cast { @@ -22,10 +23,7 @@ namespace cast { // to help distinguish one out of multiple instances in a Cast Streaming // session. // -// TODO(miu): Replace RECEIVER_VLOG's with trace event logging once the tracing -// infrastructure is ready. #define RECEIVER_LOG(level) OSP_LOG_##level << "[SSRC:" << ssrc() << "] " -#define RECEIVER_VLOG OSP_VLOG << "[SSRC:" << ssrc() << "] " Receiver::Receiver(Environment* environment, ReceiverPacketRouter* packet_router, @@ -63,6 +61,16 @@ Receiver::~Receiver() { packet_router_->OnReceiverDestroyed(rtcp_session_.sender_ssrc()); } +const SessionConfig& Receiver::config() const { + return config_; +} +int Receiver::rtp_timebase() const { + return rtp_timebase_; +} +Ssrc Receiver::ssrc() const { + return rtcp_session_.receiver_ssrc(); +} + void Receiver::SetConsumer(Consumer* consumer) { consumer_ = consumer; ScheduleFrameReadyCheck(); @@ -85,6 +93,7 @@ void Receiver::RequestKeyFrame() { } int Receiver::AdvanceToNextFrame() { + TRACE_DEFAULT_SCOPED(TraceCategory::kReceiver); const FrameId immediate_next_frame = last_frame_consumed_ + 1; // Scan the queue for the next frame that should be consumed. Typically, this @@ -96,13 +105,11 @@ int Receiver::AdvanceToNextFrame() { const EncryptedFrame& encrypted_frame = entry.collector.PeekAtAssembledFrame(); if (f == immediate_next_frame) { // Typical case. - RECEIVER_VLOG << "AdvanceToNextFrame: Next in sequence (" << f << ')'; return FrameCrypto::GetPlaintextSize(encrypted_frame); } if (encrypted_frame.dependency != EncodedFrame::DEPENDS_ON_ANOTHER) { // Found a frame after skipping past some frames. Drop the ones being // skipped, advancing |last_frame_consumed_| before returning. - RECEIVER_VLOG << "AdvanceToNextFrame: Skipping-ahead → " << f; DropAllFramesBefore(f); return FrameCrypto::GetPlaintextSize(encrypted_frame); } @@ -130,12 +137,11 @@ int Receiver::AdvanceToNextFrame() { } } - RECEIVER_VLOG << "AdvanceToNextFrame: No frames ready. Last consumed was " - << last_frame_consumed_ << '.'; return kNoFramesReady; } EncodedFrame Receiver::ConsumeNextFrame(absl::Span<uint8_t> buffer) { + TRACE_DEFAULT_SCOPED(TraceCategory::kReceiver); // Assumption: The required call to AdvanceToNextFrame() ensures that // |last_frame_consumed_| is set to one before the frame to be consumed here. const FrameId frame_id = last_frame_consumed_ + 1; @@ -151,14 +157,13 @@ EncodedFrame Receiver::ConsumeNextFrame(absl::Span<uint8_t> buffer) { frame.reference_time = *entry.estimated_capture_time + ResolveTargetPlayoutDelay(frame_id); - RECEIVER_VLOG << "ConsumeNextFrame → " << frame.frame_id << ": " - << frame.data.size() << " payload bytes, RTP Timestamp " - << frame.rtp_timestamp - .ToTimeSinceOrigin<microseconds>(rtp_timebase_) - .count() - << " µs, to play-out " - << to_microseconds(frame.reference_time - now_()).count() - << " µs from now."; + OSP_VLOG << "ConsumeNextFrame → " << frame.frame_id << ": " + << frame.data.size() << " payload bytes, RTP Timestamp " + << frame.rtp_timestamp.ToTimeSinceOrigin<microseconds>(rtp_timebase_) + .count() + << " µs, to play-out " + << to_microseconds(frame.reference_time - now_()).count() + << " µs from now."; entry.Reset(); last_frame_consumed_ = frame_id; @@ -195,8 +200,6 @@ void Receiver::OnReceivedRtpPacket(Clock::time_point arrival_time, const FrameId max_allowed_frame_id = last_frame_consumed_ + kMaxUnackedFrames; if (part->frame_id > max_allowed_frame_id) { - RECEIVER_VLOG << "Dropping RTP packet for " << part->frame_id - << ": Too many frames are already in-flight."; return; } do { @@ -204,8 +207,6 @@ void Receiver::OnReceivedRtpPacket(Clock::time_point arrival_time, GetQueueEntry(latest_frame_expected_) .collector.set_frame_id(latest_frame_expected_); } while (latest_frame_expected_ < part->frame_id); - RECEIVER_VLOG << "Advanced latest frame expected to " - << latest_frame_expected_; } // Start-up edge case: Blatantly drop the first packet of all frames until the @@ -253,9 +254,6 @@ void Receiver::OnReceivedRtpPacket(Clock::time_point arrival_time, // If a target playout delay change was included in this packet, record it. if (part->new_playout_delay > milliseconds::zero()) { - RECEIVER_VLOG << "Target playout delay changes to " - << part->new_playout_delay.count() << " ms, as of " - << part->frame_id; RecordNewTargetPlayoutDelay(part->frame_id, part->new_playout_delay); } @@ -289,6 +287,7 @@ void Receiver::OnReceivedRtpPacket(Clock::time_point arrival_time, void Receiver::OnReceivedRtcpPacket(Clock::time_point arrival_time, std::vector<uint8_t> packet) { + TRACE_DEFAULT_SCOPED(TraceCategory::kReceiver); absl::optional<SenderReportParser::SenderReportWithId> parsed_report = rtcp_parser_.Parse(packet); if (!parsed_report) { @@ -311,10 +310,6 @@ void Receiver::OnReceivedRtcpPacket(Clock::time_point arrival_time, const Clock::duration measured_offset = arrival_time - last_sender_report_->reference_time; smoothed_clock_offset_.Update(arrival_time, measured_offset); - RECEIVER_VLOG - << "Received Sender Report: Local clock is ahead of Sender's by " - << to_microseconds(smoothed_clock_offset_.Current()).count() - << " µs (minus one-way network transit time)."; RtcpReportBlock report; report.ssrc = rtcp_session_.sender_ssrc(); @@ -347,7 +342,6 @@ void Receiver::SendRtcp() { packet_router_->SendRtcpPacket(rtcp_builder_.BuildPacket( last_rtcp_send_time_, absl::Span<uint8_t>(rtcp_buffer_.get(), rtcp_buffer_capacity_))); - RECEIVER_VLOG << "Sent RTCP packet."; // Schedule the automatic sending of another RTCP packet, if this method is // not called within some bounded amount of time. While incomplete frames @@ -413,6 +407,7 @@ milliseconds Receiver::ResolveTargetPlayoutDelay(FrameId frame_id) const { } void Receiver::AdvanceCheckpoint(FrameId new_checkpoint) { + TRACE_DEFAULT_SCOPED(TraceCategory::kReceiver); OSP_DCHECK_GT(new_checkpoint, checkpoint_frame()); OSP_DCHECK_LE(new_checkpoint, latest_frame_expected_); @@ -424,7 +419,6 @@ void Receiver::AdvanceCheckpoint(FrameId new_checkpoint) { new_checkpoint = next; } - RECEIVER_VLOG << "Advancing checkpoint to " << new_checkpoint; set_checkpoint_frame(new_checkpoint); rtcp_builder_.SetPlayoutDelay(ResolveTargetPlayoutDelay(new_checkpoint)); SendRtcp(); @@ -465,8 +459,6 @@ void Receiver::ScheduleFrameReadyCheck(Clock::time_point when) { when); } -Receiver::Consumer::~Consumer() = default; - Receiver::PendingFrame::PendingFrame() = default; Receiver::PendingFrame::~PendingFrame() = default; diff --git a/cast/streaming/receiver.h b/cast/streaming/receiver.h index 057c56d8..d7fd1c80 100644 --- a/cast/streaming/receiver.h +++ b/cast/streaming/receiver.h @@ -21,6 +21,7 @@ #include "cast/streaming/frame_collector.h" #include "cast/streaming/frame_id.h" #include "cast/streaming/packet_receive_stats_tracker.h" +#include "cast/streaming/receiver_base.h" #include "cast/streaming/rtcp_common.h" #include "cast/streaming/rtcp_session.h" #include "cast/streaming/rtp_packet_parser.h" @@ -103,20 +104,9 @@ class ReceiverPacketRouter; // 3. Last Frame Consumed: The FrameId of last frame consumed (see // ConsumeNextFrame()). Once a frame is consumed, all internal resources // related to the frame can be freed and/or re-used for later frames. -class Receiver { +class Receiver : public ReceiverBase { public: - class Consumer { - public: - virtual ~Consumer(); - - // Called whenever one or more frames have become ready for consumption. The - // |next_frame_buffer_size| argument is identical to the result of calling - // AdvanceToNextFrame(), and so the Consumer only needs to prepare a buffer - // and call ConsumeNextFrame(). It may then call AdvanceToNextFrame() to - // check whether there are any more frames ready, but this is not mandatory. - // See usage example in class-level comments. - virtual void OnFramesReady(int next_frame_buffer_size) = 0; - }; + using ReceiverBase::Consumer; // Constructs a Receiver that attaches to the given |environment| and // |packet_router|. The config contains the settings that were @@ -126,52 +116,17 @@ class Receiver { Receiver(Environment* environment, ReceiverPacketRouter* packet_router, SessionConfig config); - ~Receiver(); - - const SessionConfig& config() const { return config_; } - int rtp_timebase() const { return rtp_timebase_; } - Ssrc ssrc() const { return rtcp_session_.receiver_ssrc(); } - - // Set the Consumer receiving notifications when new frames are ready for - // consumption. Frames received before this method is called will remain in - // the queue indefinitely. - void SetConsumer(Consumer* consumer); - - // Sets how much time the consumer will need to decode/buffer/render/etc., and - // otherwise fully process a frame for on-time playback. This information is - // used by the Receiver to decide whether to skip past frames that have - // arrived too late. This method can be called repeatedly to make adjustments - // based on changing environmental conditions. - // - // Default setting: kDefaultPlayerProcessingTime - void SetPlayerProcessingTime(Clock::duration needed_time); - - // Propagates a "picture loss indicator" notification to the Sender, - // requesting a key frame so that decode/playout can recover. It is safe to - // call this redundantly. The Receiver will clear the picture loss condition - // automatically, once a key frame is received (i.e., before - // ConsumeNextFrame() is called to access it). - void RequestKeyFrame(); - - // Advances to the next frame ready for consumption. This may skip-over - // incomplete frames that will not play out on-time; but only if there are - // completed frames further down the queue that have no dependency - // relationship with them (e.g., key frames). - // - // This method returns kNoFramesReady if there is not currently a frame ready - // for consumption. The caller should wait for a Consumer::OnFramesReady() - // notification before trying again. Otherwise, the number of bytes of encoded - // data is returned, and the caller should use this to ensure the buffer it - // passes to ConsumeNextFrame() is large enough. - int AdvanceToNextFrame(); - - // Returns the next frame, both metadata and payload data. The Consumer calls - // this method after being notified via OnFramesReady(), and it can also call - // this whenever AdvanceToNextFrame() indicates another frame is ready. - // |buffer| must point to a sufficiently-sized buffer that will be populated - // with the frame's payload data. Upon return |frame->data| will be set to the - // portion of the buffer that was populated. - EncodedFrame ConsumeNextFrame(absl::Span<uint8_t> buffer); + ~Receiver() override; + + // ReceiverBase overrides. + const SessionConfig& config() const override; + int rtp_timebase() const override; + Ssrc ssrc() const override; + void SetConsumer(Consumer* consumer) override; + void SetPlayerProcessingTime(Clock::duration needed_time) override; + void RequestKeyFrame() override; + int AdvanceToNextFrame() override; + EncodedFrame ConsumeNextFrame(absl::Span<uint8_t> buffer) override; // Allows setting picture loss indication for testing. In production, this // should be done using the config. @@ -180,11 +135,12 @@ class Receiver { } // The default "player processing time" amount. See SetPlayerProcessingTime(). - static constexpr std::chrono::milliseconds kDefaultPlayerProcessingTime{5}; + static constexpr std::chrono::milliseconds kDefaultPlayerProcessingTime = + ReceiverBase::kDefaultPlayerProcessingTime; // Returned by AdvanceToNextFrame() when there are no frames currently ready // for consumption. - static constexpr int kNoFramesReady = -1; + static constexpr int kNoFramesReady = ReceiverBase::kNoFramesReady; protected: friend class ReceiverPacketRouter; @@ -346,8 +302,8 @@ class Receiver { // The interval between sending ACK/NACK feedback RTCP messages while // incomplete frames exist in the queue. // - // TODO(miu): This should be a function of the current target playout delay, - // similar to the Sender's kickstart interval logic. + // TODO(jophba): This should be a function of the current target playout + // delay, similar to the Sender's kickstart interval logic. static constexpr std::chrono::milliseconds kNackFeedbackInterval{30}; }; diff --git a/cast/streaming/receiver_base.cc b/cast/streaming/receiver_base.cc new file mode 100644 index 00000000..dd0067df --- /dev/null +++ b/cast/streaming/receiver_base.cc @@ -0,0 +1,17 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/receiver_base.h" + +namespace openscreen { +namespace cast { + +ReceiverBase::Consumer::~Consumer() = default; + +ReceiverBase::ReceiverBase() = default; + +ReceiverBase::~ReceiverBase() = default; + +} // namespace cast +} // namespace openscreen diff --git a/cast/streaming/receiver_base.h b/cast/streaming/receiver_base.h new file mode 100644 index 00000000..1a8f3981 --- /dev/null +++ b/cast/streaming/receiver_base.h @@ -0,0 +1,108 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_RECEIVER_BASE_H_ +#define CAST_STREAMING_RECEIVER_BASE_H_ + +#include <chrono> + +#include "absl/types/span.h" +#include "cast/streaming/encoded_frame.h" +#include "cast/streaming/session_config.h" +#include "cast/streaming/ssrc.h" +#include "platform/api/time.h" + +namespace openscreen { +namespace cast { + +// The Cast Streaming Receiver, a peer corresponding to some Cast Streaming +// Sender at the other end of a network link. +// +// Cast Streaming is a transport protocol which divides up the frames for one +// media stream (e.g., audio or video) into multiple RTP packets containing an +// encrypted payload. The Receiver is the peer responsible for collecting the +// RTP packets, decrypting the payload, and re-assembling a frame that can be +// passed to a decoder and played out. +// +// A Sender ↔ Receiver pair is used to transport each media stream. Typically, +// there are two pairs in a normal system, one for the audio stream and one for +// video stream. A local player is responsible for synchronizing the playout of +// the frames of each stream to achieve lip-sync. See the discussion in +// encoded_frame.h for how the |reference_time| and |rtp_timestamp| of the +// EncodedFrames are used to achieve this. +class ReceiverBase { + public: + class Consumer { + public: + virtual ~Consumer(); + + // Called whenever one or more frames have become ready for consumption. The + // |next_frame_buffer_size| argument is identical to the result of calling + // AdvanceToNextFrame(), and so the Consumer only needs to prepare a buffer + // and call ConsumeNextFrame(). It may then call AdvanceToNextFrame() to + // check whether there are any more frames ready, but this is not mandatory. + // See usage example in class-level comments. + virtual void OnFramesReady(int next_frame_buffer_size) = 0; + }; + + ReceiverBase(); + virtual ~ReceiverBase(); + + virtual const SessionConfig& config() const = 0; + virtual int rtp_timebase() const = 0; + virtual Ssrc ssrc() const = 0; + + // Set the Consumer receiving notifications when new frames are ready for + // consumption. Frames received before this method is called will remain in + // the queue indefinitely. + virtual void SetConsumer(Consumer* consumer) = 0; + + // Sets how much time the consumer will need to decode/buffer/render/etc., and + // otherwise fully process a frame for on-time playback. This information is + // used by the Receiver to decide whether to skip past frames that have + // arrived too late. This method can be called repeatedly to make adjustments + // based on changing environmental conditions. + // + // Default setting: kDefaultPlayerProcessingTime + virtual void SetPlayerProcessingTime(Clock::duration needed_time) = 0; + + // Propagates a "picture loss indicator" notification to the Sender, + // requesting a key frame so that decode/playout can recover. It is safe to + // call this redundantly. The Receiver will clear the picture loss condition + // automatically, once a key frame is received (i.e., before + // ConsumeNextFrame() is called to access it). + virtual void RequestKeyFrame() = 0; + + // Advances to the next frame ready for consumption. This may skip-over + // incomplete frames that will not play out on-time; but only if there are + // completed frames further down the queue that have no dependency + // relationship with them (e.g., key frames). + // + // This method returns kNoFramesReady if there is not currently a frame ready + // for consumption. The caller should wait for a Consumer::OnFramesReady() + // notification before trying again. Otherwise, the number of bytes of encoded + // data is returned, and the caller should use this to ensure the buffer it + // passes to ConsumeNextFrame() is large enough. + virtual int AdvanceToNextFrame() = 0; + + // Returns the next frame, both metadata and payload data. The Consumer calls + // this method after being notified via OnFramesReady(), and it can also call + // this whenever AdvanceToNextFrame() indicates another frame is ready. + // |buffer| must point to a sufficiently-sized buffer that will be populated + // with the frame's payload data. Upon return |frame->data| will be set to the + // portion of the buffer that was populated. + virtual EncodedFrame ConsumeNextFrame(absl::Span<uint8_t> buffer) = 0; + + // The default "player processing time" amount. See SetPlayerProcessingTime(). + static constexpr std::chrono::milliseconds kDefaultPlayerProcessingTime{5}; + + // Returned by AdvanceToNextFrame() when there are no frames currently ready + // for consumption. + static constexpr int kNoFramesReady = -1; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_RECEIVER_BASE_H_ diff --git a/cast/streaming/receiver_message.cc b/cast/streaming/receiver_message.cc index e2739220..7f0999f0 100644 --- a/cast/streaming/receiver_message.cc +++ b/cast/streaming/receiver_message.cc @@ -7,6 +7,7 @@ #include <utility> #include "absl/strings/ascii.h" +#include "absl/types/optional.h" #include "cast/streaming/message_fields.h" #include "json/reader.h" #include "json/writer.h" @@ -23,13 +24,24 @@ namespace { EnumNameTable<ReceiverMessage::Type, 5> kMessageTypeNames{ {{kMessageTypeAnswer, ReceiverMessage::Type::kAnswer}, - {"STATUS_RESPONSE", ReceiverMessage::Type::kStatusResponse}, {"CAPABILITIES_RESPONSE", ReceiverMessage::Type::kCapabilitiesResponse}, {"RPC", ReceiverMessage::Type::kRpc}}}; +EnumNameTable<MediaCapability, 10> kMediaCapabilityNames{ + {{"audio", MediaCapability::kAudio}, + {"aac", MediaCapability::kAac}, + {"opus", MediaCapability::kOpus}, + {"video", MediaCapability::kVideo}, + {"4k", MediaCapability::k4k}, + {"h264", MediaCapability::kH264}, + {"vp8", MediaCapability::kVp8}, + {"vp9", MediaCapability::kVp9}, + {"hevc", MediaCapability::kHevc}, + {"av1", MediaCapability::kAv1}}}; + ReceiverMessage::Type GetMessageType(const Json::Value& root) { std::string type; - if (!json::ParseAndValidateString(root[kMessageType], &type)) { + if (!json::TryParseString(root[kMessageType], &type)) { return ReceiverMessage::Type::kUnknown; } @@ -39,6 +51,21 @@ ReceiverMessage::Type GetMessageType(const Json::Value& root) { return parsed.value(ReceiverMessage::Type::kUnknown); } +bool TryParseCapability(const Json::Value& value, MediaCapability* out) { + std::string c; + if (!json::TryParseString(value, &c)) { + return false; + } + + const ErrorOr<MediaCapability> capability = GetEnum(kMediaCapabilityNames, c); + if (capability.is_error()) { + return false; + } + + *out = capability.value(); + return true; +} + } // namespace // static @@ -50,8 +77,8 @@ ErrorOr<ReceiverError> ReceiverError::Parse(const Json::Value& value) { int code; std::string description; - if (!json::ParseAndValidateInt(value[kErrorCode], &code) || - !json::ParseAndValidateString(value[kErrorDescription], &description)) { + if (!json::TryParseInt(value[kErrorCode], &code) || + !json::TryParseString(value[kErrorDescription], &description)) { return Error::Code::kJsonParseError; } return ReceiverError{code, description}; @@ -73,18 +100,18 @@ ErrorOr<ReceiverCapability> ReceiverCapability::Parse( } int remoting_version; - if (!json::ParseAndValidateInt(value["remoting"], &remoting_version)) { + if (!json::TryParseInt(value["remoting"], &remoting_version)) { remoting_version = ReceiverCapability::kRemotingVersionUnknown; } - std::vector<std::string> media_capabilities; - if (!json::ParseAndValidateStringArray(value["mediaCaps"], - &media_capabilities)) { + std::vector<MediaCapability> capabilities; + if (!json::TryParseArray<MediaCapability>( + value["mediaCaps"], TryParseCapability, &capabilities)) { return Error(Error::Code::kJsonParseError, "Failed to parse media capabilities"); } - return ReceiverCapability{remoting_version, std::move(media_capabilities)}; + return ReceiverCapability{remoting_version, std::move(capabilities)}; } Json::Value ReceiverCapability::ToJson() const { @@ -92,51 +119,21 @@ Json::Value ReceiverCapability::ToJson() const { root["remoting"] = remoting_version; Json::Value capabilities(Json::ValueType::arrayValue); for (const auto& capability : media_capabilities) { - capabilities.append(capability); + capabilities.append(GetEnumName(kMediaCapabilityNames, capability).value()); } root["mediaCaps"] = std::move(capabilities); return root; } // static -ErrorOr<ReceiverWifiStatus> ReceiverWifiStatus::Parse( - const Json::Value& value) { - if (!value) { - return Error(Error::Code::kParameterInvalid, - "Empty JSON in status parsing"); - } - - double wifi_snr; - std::vector<int32_t> wifi_speed; - if (!json::ParseAndValidateDouble(value["wifiSnr"], &wifi_snr, true) || - !json::ParseAndValidateIntArray(value["wifiSpeed"], &wifi_speed)) { - return Error::Code::kJsonParseError; - } - return ReceiverWifiStatus{wifi_snr, std::move(wifi_speed)}; -} - -Json::Value ReceiverWifiStatus::ToJson() const { - Json::Value root; - root["wifiSnr"] = wifi_snr; - Json::Value speeds(Json::ValueType::arrayValue); - for (const auto& speed : wifi_speed) { - speeds.append(speed); - } - root["wifiSpeed"] = std::move(speeds); - return root; -} - -// static ErrorOr<ReceiverMessage> ReceiverMessage::Parse(const Json::Value& value) { ReceiverMessage message; - if (!value || !json::ParseAndValidateInt(value[kSequenceNumber], - &(message.sequence_number))) { - return Error(Error::Code::kJsonParseError, - "Failed to parse sequence number"); + if (!value) { + return Error(Error::Code::kJsonParseError, "Invalid message body"); } std::string result; - if (!json::ParseAndValidateString(value[kResult], &result)) { + if (!json::TryParseString(value[kResult], &result)) { result = kResultError; } @@ -155,22 +152,13 @@ ErrorOr<ReceiverMessage> ReceiverMessage::Parse(const Json::Value& value) { switch (message.type) { case Type::kAnswer: { Answer answer; - if (openscreen::cast::Answer::ParseAndValidate(value[kAnswerMessageBody], - &answer)) { + if (openscreen::cast::Answer::TryParse(value[kAnswerMessageBody], + &answer)) { message.body = std::move(answer); message.valid = true; } } break; - case Type::kStatusResponse: { - ErrorOr<ReceiverWifiStatus> status = - ReceiverWifiStatus::Parse(value[kStatusMessageBody]); - if (status.is_value()) { - message.body = std::move(status.value()); - message.valid = true; - } - } break; - case Type::kCapabilitiesResponse: { ErrorOr<ReceiverCapability> capability = ReceiverCapability::Parse(value[kCapabilitiesMessageBody]); @@ -181,20 +169,25 @@ ErrorOr<ReceiverMessage> ReceiverMessage::Parse(const Json::Value& value) { } break; case Type::kRpc: { - std::string rpc; - if (json::ParseAndValidateString(value[kRpcMessageBody], &rpc) && - base64::Decode(rpc, &rpc)) { + std::string encoded_rpc; + std::vector<uint8_t> rpc; + if (json::TryParseString(value[kRpcMessageBody], &encoded_rpc) && + base64::Decode(encoded_rpc, &rpc)) { message.body = std::move(rpc); message.valid = true; } } break; - case Type::kUnknown: default: - message.valid = false; break; } + if (message.type != ReceiverMessage::Type::kRpc && + !json::TryParseInt(value[kSequenceNumber], &(message.sequence_number))) { + message.sequence_number = -1; + message.valid = false; + } + return message; } @@ -219,20 +212,21 @@ ErrorOr<Json::Value> ReceiverMessage::ToJson() const { } break; - case (ReceiverMessage::Type::kStatusResponse): - root[kResult] = kResultOk; - root[kStatusMessageBody] = absl::get<ReceiverWifiStatus>(body).ToJson(); - break; - case ReceiverMessage::Type::kCapabilitiesResponse: - root[kResult] = kResultOk; - root[kCapabilitiesMessageBody] = - absl::get<ReceiverCapability>(body).ToJson(); + if (valid) { + root[kResult] = kResultOk; + root[kCapabilitiesMessageBody] = + absl::get<ReceiverCapability>(body).ToJson(); + } else { + root[kResult] = kResultError; + root[kErrorMessageBody] = absl::get<ReceiverError>(body).ToJson(); + } break; // NOTE: RPC messages do NOT have a result field. case ReceiverMessage::Type::kRpc: - root[kRpcMessageBody] = base64::Encode(absl::get<std::string>(body)); + root[kRpcMessageBody] = + base64::Encode(absl::get<std::vector<uint8_t>>(body)); break; default: diff --git a/cast/streaming/receiver_message.h b/cast/streaming/receiver_message.h index 59aa9750..f4adbfb3 100644 --- a/cast/streaming/receiver_message.h +++ b/cast/streaming/receiver_message.h @@ -17,16 +17,17 @@ namespace openscreen { namespace cast { -struct ReceiverWifiStatus { - Json::Value ToJson() const; - static ErrorOr<ReceiverWifiStatus> Parse(const Json::Value& value); - - // Current WiFi signal to noise ratio in decibels. - double wifi_snr = 0.0; - - // Min, max, average, and current bandwidth in bps in order of the WiFi link. - // Example: [1200, 1300, 1250, 1230]. - std::vector<int32_t> wifi_speed; +enum class MediaCapability { + kAudio, + kAac, + kOpus, + kVideo, + k4k, + kH264, + kVp8, + kVp9, + kHevc, + kAv1 }; struct ReceiverCapability { @@ -39,7 +40,7 @@ struct ReceiverCapability { int remoting_version = kRemotingVersionUnknown; // Set of capabilities (e.g., ac3, 4k, hevc, vp9, dolby_vision, etc.). - std::vector<std::string> media_capabilities; + std::vector<MediaCapability> media_capabilities; }; struct ReceiverError { @@ -47,6 +48,8 @@ struct ReceiverError { static ErrorOr<ReceiverError> Parse(const Json::Value& value); // Error code. + // TODO(issuetracker.google.com/184766188): Error codes should be well + // defined. int32_t code = -1; // Error description. @@ -63,9 +66,6 @@ struct ReceiverMessage { // Response to OFFER message. kAnswer, - // Response to GET_STATUS message. - kStatusResponse, - // Response to GET_CAPABILITIES message. kCapabilitiesResponse, @@ -84,8 +84,7 @@ struct ReceiverMessage { absl::variant<absl::monostate, Answer, - std::string, - ReceiverWifiStatus, + std::vector<uint8_t>, // Binary-encoded RPC message. ReceiverCapability, ReceiverError> body; diff --git a/cast/streaming/receiver_packet_router.cc b/cast/streaming/receiver_packet_router.cc index 1ac4266a..23b99ce4 100644 --- a/cast/streaming/receiver_packet_router.cc +++ b/cast/streaming/receiver_packet_router.cc @@ -73,10 +73,11 @@ void ReceiverPacketRouter::OnReceivedPacket(const IPEndpoint& source, InspectPacketForRouting(packet); if (seems_like.first == ApparentPacketType::UNKNOWN) { constexpr int kMaxPartiaHexDumpSize = 96; + const std::size_t encode_size = + std::min(packet.size(), static_cast<size_t>(kMaxPartiaHexDumpSize)); OSP_LOG_WARN << "UNKNOWN packet of " << packet.size() << " bytes. Partial hex dump: " - << HexEncode(absl::Span<const uint8_t>(packet).subspan( - 0, kMaxPartiaHexDumpSize)); + << HexEncode(packet.data(), encode_size); return; } auto it = receivers_.find(seems_like.second); diff --git a/cast/streaming/receiver_session.cc b/cast/streaming/receiver_session.cc index 68c1ee4c..bda6d984 100644 --- a/cast/streaming/receiver_session.cc +++ b/cast/streaming/receiver_session.cc @@ -13,6 +13,7 @@ #include "absl/strings/numbers.h" #include "cast/common/channel/message_util.h" #include "cast/common/public/message_port.h" +#include "cast/streaming/answer_messages.h" #include "cast/streaming/environment.h" #include "cast/streaming/message_fields.h" #include "cast/streaming/offer_messages.h" @@ -23,22 +24,21 @@ namespace openscreen { namespace cast { - -// Using statements for constructor readability. -using Preferences = ReceiverSession::Preferences; -using ConfiguredReceivers = ReceiverSession::ConfiguredReceivers; - namespace { template <typename Stream, typename Codec> std::unique_ptr<Stream> SelectStream( const std::vector<Codec>& preferred_codecs, + ReceiverSession::Client* client, const std::vector<Stream>& offered_streams) { for (auto codec : preferred_codecs) { for (const Stream& offered_stream : offered_streams) { - if (offered_stream.codec == codec) { - OSP_DVLOG << "Selected " << CodecToString(codec) - << " as codec for streaming"; + if (offered_stream.codec == codec && + (offered_stream.stream.codec_parameter.empty() || + client->SupportsCodecParameter( + offered_stream.stream.codec_parameter))) { + OSP_VLOG << "Selected " << CodecToString(codec) + << " as codec for streaming"; return std::make_unique<Stream>(offered_stream); } } @@ -46,31 +46,174 @@ std::unique_ptr<Stream> SelectStream( return nullptr; } -DisplayResolution ToDisplayResolution(const Resolution& resolution) { - return DisplayResolution{resolution.width, resolution.height}; +MediaCapability ToCapability(AudioCodec codec) { + switch (codec) { + case AudioCodec::kAac: + return MediaCapability::kAac; + case AudioCodec::kOpus: + return MediaCapability::kOpus; + default: + OSP_DLOG_FATAL << "Invalid audio codec: " << static_cast<int>(codec); + OSP_NOTREACHED(); + } +} + +MediaCapability ToCapability(VideoCodec codec) { + switch (codec) { + case VideoCodec::kVp8: + return MediaCapability::kVp8; + case VideoCodec::kVp9: + return MediaCapability::kVp9; + case VideoCodec::kH264: + return MediaCapability::kH264; + case VideoCodec::kHevc: + return MediaCapability::kHevc; + case VideoCodec::kAv1: + return MediaCapability::kAv1; + default: + OSP_DLOG_FATAL << "Invalid video codec: " << static_cast<int>(codec); + OSP_NOTREACHED(); + } +} + +// Calculates whether any codecs present in |second| are not present in |first|. +template <typename T> +bool IsMissingCodecs(const std::vector<T>& first, + const std::vector<T>& second) { + if (second.size() > first.size()) { + return true; + } + + for (auto codec : second) { + if (std::find(first.begin(), first.end(), codec) == first.end()) { + return true; + } + } + + return false; +} + +// Calculates whether the limits defined by |first| are less restrictive than +// those defined by |second|. +// NOTE: These variables are intentionally passed by copy - the function will +// mutate them. +template <typename T> +bool HasLessRestrictiveLimits(std::vector<T> first, std::vector<T> second) { + // Sort both vectors to allow for element-by-element comparison between the + // two. All elements with |applies_to_all_codecs| set are sorted to the front. + std::function<bool(const T&, const T&)> sorter = [](const T& first, + const T& second) { + if (first.applies_to_all_codecs != second.applies_to_all_codecs) { + return first.applies_to_all_codecs; + } + return static_cast<int>(first.codec) < static_cast<int>(second.codec); + }; + std::sort(first.begin(), first.end(), sorter); + std::sort(second.begin(), second.end(), sorter); + auto first_it = first.begin(); + auto second_it = second.begin(); + + // |applies_to_all_codecs| is a special case, so handle that first. + T fake_applies_to_all_codecs_struct; + fake_applies_to_all_codecs_struct.applies_to_all_codecs = true; + T* first_applies_to_all_codecs_struct = + !first.empty() && first.front().applies_to_all_codecs + ? &(*first_it++) + : &fake_applies_to_all_codecs_struct; + T* second_applies_to_all_codecs_struct = + !second.empty() && second.front().applies_to_all_codecs + ? &(*second_it++) + : &fake_applies_to_all_codecs_struct; + if (!first_applies_to_all_codecs_struct->IsSupersetOf( + *second_applies_to_all_codecs_struct)) { + return false; + } + + // Now all elements of the vectors can be assumed to NOT have + // |applies_to_all_codecs| set. So iterate through all codecs set in either + // vector and check that the first has the less restrictive configuration set. + while (first_it != first.end() || second_it != second.end()) { + // Calculate the current codec to process, and whether each vector contains + // an instance of this codec. + decltype(T::codec) current_codec; + bool use_first_fake = false; + bool use_second_fake = false; + if (first_it == first.end()) { + current_codec = second_it->codec; + use_first_fake = true; + } else if (second_it == second.end()) { + current_codec = first_it->codec; + use_second_fake = true; + } else { + current_codec = std::min(first_it->codec, second_it->codec); + use_first_fake = first_it->codec != current_codec; + use_second_fake = second_it->codec != current_codec; + } + + // Compare each vector's limit associated with this codec, or compare + // against the default limits if no such codec limits are set. + T fake_codecs_struct; + fake_codecs_struct.codec = current_codec; + T* first_codec_struct = + use_first_fake ? &fake_codecs_struct : &(*first_it++); + T* second_codec_struct = + use_second_fake ? &fake_codecs_struct : &(*second_it++); + OSP_DCHECK(!first_codec_struct->applies_to_all_codecs); + OSP_DCHECK(!second_codec_struct->applies_to_all_codecs); + if (!first_codec_struct->IsSupersetOf(*second_codec_struct)) { + return false; + } + } + + return true; } } // namespace ReceiverSession::Client::~Client() = default; +using RemotingPreferences = ReceiverSession::RemotingPreferences; + +using Preferences = ReceiverSession::Preferences; + Preferences::Preferences() = default; Preferences::Preferences(std::vector<VideoCodec> video_codecs, std::vector<AudioCodec> audio_codecs) - : Preferences(video_codecs, audio_codecs, nullptr, nullptr) {} + : video_codecs(std::move(video_codecs)), + audio_codecs(std::move(audio_codecs)) {} Preferences::Preferences(std::vector<VideoCodec> video_codecs, std::vector<AudioCodec> audio_codecs, - std::unique_ptr<Constraints> constraints, - std::unique_ptr<DisplayDescription> description) + std::vector<AudioLimits> audio_limits, + std::vector<VideoLimits> video_limits, + std::unique_ptr<Display> description) : video_codecs(std::move(video_codecs)), audio_codecs(std::move(audio_codecs)), - constraints(std::move(constraints)), + audio_limits(std::move(audio_limits)), + video_limits(std::move(video_limits)), display_description(std::move(description)) {} Preferences::Preferences(Preferences&&) noexcept = default; Preferences& Preferences::operator=(Preferences&&) noexcept = default; +Preferences::Preferences(const Preferences& other) { + *this = other; +} + +Preferences& Preferences::operator=(const Preferences& other) { + video_codecs = other.video_codecs; + audio_codecs = other.audio_codecs; + audio_limits = other.audio_limits; + video_limits = other.video_limits; + if (other.display_description) { + display_description = std::make_unique<Display>(*other.display_description); + } + if (other.remoting) { + remoting = std::make_unique<RemotingPreferences>(*other.remoting); + } + return *this; +} + ReceiverSession::ReceiverSession(Client* const client, Environment* environment, MessagePort* message_port, @@ -79,19 +222,34 @@ ReceiverSession::ReceiverSession(Client* const client, environment_(environment), preferences_(std::move(preferences)), session_id_(MakeUniqueSessionId("streaming_receiver")), - messager_(message_port, - session_id_, - [this](Error error) { - OSP_DLOG_WARN << "Got a session messager error: " << error; - client_->OnError(this, error); - }), + messenger_(message_port, + session_id_, + [this](Error error) { + OSP_DLOG_WARN << "Got a session messenger error: " << error; + client_->OnError(this, error); + }), packet_router_(environment_) { OSP_DCHECK(client_); OSP_DCHECK(environment_); - messager_.SetHandler( + OSP_DCHECK(!std::any_of( + preferences_.video_codecs.begin(), preferences_.video_codecs.end(), + [](VideoCodec c) { return c == VideoCodec::kNotSpecified; })); + OSP_DCHECK(!std::any_of( + preferences_.audio_codecs.begin(), preferences_.audio_codecs.end(), + [](AudioCodec c) { return c == AudioCodec::kNotSpecified; })); + + messenger_.SetHandler( SenderMessage::Type::kOffer, [this](SenderMessage message) { OnOffer(std::move(message)); }); + messenger_.SetHandler(SenderMessage::Type::kGetCapabilities, + [this](SenderMessage message) { + OnCapabilitiesRequest(std::move(message)); + }); + messenger_.SetHandler(SenderMessage::Type::kRpc, + [this](SenderMessage message) { + this->OnRpcMessage(std::move(message)); + }); environment_->SetSocketSubscriber(this); } @@ -144,16 +302,16 @@ void ReceiverSession::OnOffer(SenderMessage message) { properties->sequence_number = message.sequence_number; const Offer& offer = absl::get<Offer>(message.body); - if (!offer.audio_streams.empty() && !preferences_.audio_codecs.empty()) { - properties->selected_audio = - SelectStream(preferences_.audio_codecs, offer.audio_streams); - } - - if (!offer.video_streams.empty() && !preferences_.video_codecs.empty()) { - properties->selected_video = - SelectStream(preferences_.video_codecs, offer.video_streams); + if (offer.cast_mode == CastMode::kRemoting) { + if (!preferences_.remoting) { + SendErrorAnswerReply(message.sequence_number, + "This receiver does not have remoting enabled."); + return; + } } + properties->mode = offer.cast_mode; + SelectStreams(offer, properties.get()); if (!properties->IsValid()) { SendErrorAnswerReply(message.sequence_number, "Failed to select any streams from OFFER"); @@ -180,6 +338,72 @@ void ReceiverSession::OnOffer(SenderMessage message) { } } +void ReceiverSession::OnCapabilitiesRequest(SenderMessage message) { + if (message.sequence_number < 0) { + OSP_DLOG_WARN + << "Dropping offer with missing sequence number, can't respond"; + return; + } + + ReceiverMessage response{ + ReceiverMessage::Type::kCapabilitiesResponse, message.sequence_number, + true /* valid */ + }; + if (preferences_.remoting) { + response.body = CreateRemotingCapabilityV2(); + } else { + response.valid = false; + response.body = + ReceiverError{static_cast<int>(Error::Code::kRemotingNotSupported), + "Remoting is not supported"}; + } + + const Error result = messenger_.SendMessage(std::move(response)); + if (!result.ok()) { + client_->OnError(this, std::move(result)); + } +} + +void ReceiverSession::OnRpcMessage(SenderMessage message) { + if (!message.valid) { + OSP_DLOG_WARN + << "Bad RPC message. This may or may not represent a serious problem."; + return; + } + + const auto& body = absl::get<std::vector<uint8_t>>(message.body); + if (!rpc_messenger_) { + OSP_DLOG_INFO << "Received an RPC message without having a messenger."; + return; + } + rpc_messenger_->ProcessMessageFromRemote(body.data(), body.size()); +} + +void ReceiverSession::SelectStreams(const Offer& offer, + SessionProperties* properties) { + if (offer.cast_mode == CastMode::kMirroring) { + if (!offer.audio_streams.empty() && !preferences_.audio_codecs.empty()) { + properties->selected_audio = + SelectStream(preferences_.audio_codecs, client_, offer.audio_streams); + } + if (!offer.video_streams.empty() && !preferences_.video_codecs.empty()) { + properties->selected_video = + SelectStream(preferences_.video_codecs, client_, offer.video_streams); + } + } else { + OSP_DCHECK(offer.cast_mode == CastMode::kRemoting); + + if (offer.audio_streams.size() == 1) { + properties->selected_audio = + std::make_unique<AudioStream>(offer.audio_streams[0]); + } + if (offer.video_streams.size() == 1) { + properties->selected_video = + std::make_unique<VideoStream>(offer.video_streams[0]); + } + } +} + void ReceiverSession::InitializeSession(const SessionProperties& properties) { Answer answer = ConstructAnswer(properties); if (!answer.IsValid()) { @@ -192,8 +416,23 @@ void ReceiverSession::InitializeSession(const SessionProperties& properties) { // Only spawn receivers if we know we have a valid answer message. ConfiguredReceivers receivers = SpawnReceivers(properties); - client_->OnMirroringNegotiated(this, std::move(receivers)); - const Error result = messager_.SendMessage(ReceiverMessage{ + if (properties.mode == CastMode::kMirroring) { + client_->OnNegotiated(this, std::move(receivers)); + } else { + // TODO(jophba): cleanup sequence number usage. + rpc_messenger_ = std::make_unique<RpcMessenger>([this](std::vector<uint8_t> message) { + Error error = this->messenger_.SendMessage( + ReceiverMessage{ReceiverMessage::Type::kRpc, -1, true /* valid */, + std::move(message)}); + + if (!error.ok()) { + OSP_LOG_WARN << "Failed to send RPC message: " << error; + } + }); + client_->OnRemotingNegotiated( + this, RemotingNegotiation{std::move(receivers), rpc_messenger_.get()}); + } + const Error result = messenger_.SendMessage(ReceiverMessage{ ReceiverMessage::Type::kAnswer, properties.sequence_number, true /* valid */, std::move(answer)}); if (!result.ok()) { @@ -212,7 +451,7 @@ std::unique_ptr<Receiver> ReceiverSession::ConstructReceiver( std::move(config)); } -ConfiguredReceivers ReceiverSession::SpawnReceivers( +ReceiverSession::ConfiguredReceivers ReceiverSession::SpawnReceivers( const SessionProperties& properties) { OSP_DCHECK(properties.IsValid()); ResetReceivers(Client::kRenegotiated); @@ -226,24 +465,21 @@ ConfiguredReceivers ReceiverSession::SpawnReceivers( properties.selected_audio->stream.channels, properties.selected_audio->bit_rate, properties.selected_audio->stream.rtp_timebase, - properties.selected_audio->stream.target_delay}; + properties.selected_audio->stream.target_delay, + properties.selected_audio->stream.codec_parameter}; } VideoCaptureConfig video_config; if (properties.selected_video) { current_video_receiver_ = ConstructReceiver(properties.selected_video->stream); - std::vector<DisplayResolution> display_resolutions; - std::transform(properties.selected_video->resolutions.begin(), - properties.selected_video->resolutions.end(), - std::back_inserter(display_resolutions), - ToDisplayResolution); - video_config = VideoCaptureConfig{ - properties.selected_video->codec, - FrameRate{properties.selected_video->max_frame_rate.numerator, - properties.selected_video->max_frame_rate.denominator}, - properties.selected_video->max_bit_rate, std::move(display_resolutions), - properties.selected_video->stream.target_delay}; + video_config = + VideoCaptureConfig{properties.selected_video->codec, + properties.selected_video->max_frame_rate, + properties.selected_video->max_bit_rate, + properties.selected_video->resolutions, + properties.selected_video->stream.target_delay, + properties.selected_video->stream.codec_parameter}; } return ConfiguredReceivers{ @@ -256,6 +492,7 @@ void ReceiverSession::ResetReceivers(Client::ReceiversDestroyingReason reason) { client_->OnReceiversDestroying(this, reason); current_audio_receiver_.reset(); current_video_receiver_.reset(); + rpc_messenger_.reset(); } } @@ -264,42 +501,88 @@ Answer ReceiverSession::ConstructAnswer(const SessionProperties& properties) { std::vector<int> stream_indexes; std::vector<Ssrc> stream_ssrcs; + Constraints constraints; if (properties.selected_audio) { stream_indexes.push_back(properties.selected_audio->stream.index); stream_ssrcs.push_back(properties.selected_audio->stream.ssrc + 1); + + for (const auto& limit : preferences_.audio_limits) { + if (limit.codec == properties.selected_audio->codec || + limit.applies_to_all_codecs) { + constraints.audio = AudioConstraints{ + limit.max_sample_rate, limit.max_channels, limit.min_bit_rate, + limit.max_bit_rate, limit.max_delay, + }; + break; + } + } } if (properties.selected_video) { stream_indexes.push_back(properties.selected_video->stream.index); stream_ssrcs.push_back(properties.selected_video->stream.ssrc + 1); - } - absl::optional<Constraints> constraints; - if (preferences_.constraints) { - constraints = absl::optional<Constraints>(*preferences_.constraints); + for (const auto& limit : preferences_.video_limits) { + if (limit.codec == properties.selected_video->codec || + limit.applies_to_all_codecs) { + constraints.video = VideoConstraints{ + limit.max_pixels_per_second, absl::nullopt, /* min dimensions */ + limit.max_dimensions, limit.min_bit_rate, + limit.max_bit_rate, limit.max_delay, + }; + break; + } + } } absl::optional<DisplayDescription> display; if (preferences_.display_description) { - display = - absl::optional<DisplayDescription>(*preferences_.display_description); + const auto* d = preferences_.display_description.get(); + display = DisplayDescription{d->dimensions, absl::nullopt, + d->can_scale_content + ? AspectRatioConstraint::kVariable + : AspectRatioConstraint::kFixed}; } + // Only set the constraints in the answer if they are valid (meaning we + // successfully found limits above). + absl::optional<Constraints> answer_constraints; + if (constraints.IsValid()) { + answer_constraints = std::move(constraints); + } return Answer{environment_->GetBoundLocalEndpoint().port, - std::move(stream_indexes), - std::move(stream_ssrcs), - std::move(constraints), - std::move(display), - std::vector<int>{}, // receiver_rtcp_event_log - std::vector<int>{}, // receiver_rtcp_dscp - supports_wifi_status_reporting_}; + std::move(stream_indexes), std::move(stream_ssrcs), + answer_constraints, std::move(display)}; +} + +ReceiverCapability ReceiverSession::CreateRemotingCapabilityV2() { + // If we don't support remoting, there is no reason to respond to + // capability requests--they are not used for mirroring. + OSP_DCHECK(preferences_.remoting); + ReceiverCapability capability; + capability.remoting_version = kSupportedRemotingVersion; + + for (const AudioCodec& codec : preferences_.audio_codecs) { + capability.media_capabilities.push_back(ToCapability(codec)); + } + for (const VideoCodec& codec : preferences_.video_codecs) { + capability.media_capabilities.push_back(ToCapability(codec)); + } + + if (preferences_.remoting->supports_chrome_audio_codecs) { + capability.media_capabilities.push_back(MediaCapability::kAudio); + } + if (preferences_.remoting->supports_4k) { + capability.media_capabilities.push_back(MediaCapability::k4k); + } + return capability; } void ReceiverSession::SendErrorAnswerReply(int sequence_number, const char* message) { const Error error(Error::Code::kParseError, message); OSP_DLOG_WARN << message; - const Error result = messager_.SendMessage(ReceiverMessage{ + const Error result = messenger_.SendMessage(ReceiverMessage{ ReceiverMessage::Type::kAnswer, sequence_number, false /* valid */, ReceiverError{static_cast<int>(Error::Code::kParseError), message}}); if (!result.ok()) { @@ -307,5 +590,64 @@ void ReceiverSession::SendErrorAnswerReply(int sequence_number, } } +bool ReceiverSession::VideoLimits::IsSupersetOf( + const ReceiverSession::VideoLimits& second) const { + return (applies_to_all_codecs == second.applies_to_all_codecs) && + (applies_to_all_codecs || codec == second.codec) && + (max_pixels_per_second >= second.max_pixels_per_second) && + (min_bit_rate <= second.min_bit_rate) && + (max_bit_rate >= second.max_bit_rate) && + (max_delay >= second.max_delay) && + (max_dimensions.IsSupersetOf(second.max_dimensions)); +} + +bool ReceiverSession::AudioLimits::IsSupersetOf( + const ReceiverSession::AudioLimits& second) const { + return (applies_to_all_codecs == second.applies_to_all_codecs) && + (applies_to_all_codecs || codec == second.codec) && + (max_sample_rate >= second.max_sample_rate) && + (max_channels >= second.max_channels) && + (min_bit_rate <= second.min_bit_rate) && + (max_bit_rate >= second.max_bit_rate) && + (max_delay >= second.max_delay); +} + +bool ReceiverSession::Display::IsSupersetOf( + const ReceiverSession::Display& other) const { + return dimensions.IsSupersetOf(other.dimensions) && + (can_scale_content || !other.can_scale_content); +} + +bool ReceiverSession::RemotingPreferences::IsSupersetOf( + const ReceiverSession::RemotingPreferences& other) const { + return (supports_chrome_audio_codecs || + !other.supports_chrome_audio_codecs) && + (supports_4k || !other.supports_4k); +} + +bool ReceiverSession::Preferences::IsSupersetOf( + const ReceiverSession::Preferences& other) const { + // Check simple cases first. + if ((!!display_description != !!other.display_description) || + (display_description && + !display_description->IsSupersetOf(*other.display_description))) { + return false; + } else if (other.remoting && + (!remoting || !remoting->IsSupersetOf(*other.remoting))) { + return false; + } + + // Then check set codecs. + if (IsMissingCodecs(video_codecs, other.video_codecs) || + IsMissingCodecs(audio_codecs, other.audio_codecs)) { + return false; + } + + // Then check limits. Do this last because it's the most resource intensive to + // check. + return HasLessRestrictiveLimits(video_limits, other.video_limits) && + HasLessRestrictiveLimits(audio_limits, other.audio_limits); +} + } // namespace cast } // namespace openscreen diff --git a/cast/streaming/receiver_session.h b/cast/streaming/receiver_session.h index b8365a36..caf271b6 100644 --- a/cast/streaming/receiver_session.h +++ b/cast/streaming/receiver_session.h @@ -11,14 +11,15 @@ #include <vector> #include "cast/common/public/message_port.h" -#include "cast/streaming/answer_messages.h" #include "cast/streaming/capture_configs.h" +#include "cast/streaming/constants.h" #include "cast/streaming/offer_messages.h" #include "cast/streaming/receiver_packet_router.h" +#include "cast/streaming/resolution.h" +#include "cast/streaming/rpc_messenger.h" #include "cast/streaming/sender_message.h" #include "cast/streaming/session_config.h" -#include "cast/streaming/session_messager.h" -#include "util/json/json_serialization.h" +#include "cast/streaming/session_messenger.h" namespace openscreen { namespace cast { @@ -26,6 +27,18 @@ namespace cast { class Environment; class Receiver; +// This class is responsible for listening for streaming requests from Cast +// Sender devices, then negotiating capture constraints and instantiating audio +// and video Receiver objects. +// The owner of this session is expected to provide a client for +// updates, an environment for getting UDP socket information (as well as +// other OS dependencies), and a set of preferences to be used for +// negotiation. +// +// NOTE: In some cases, the session initialization may be pending waiting for +// the UDP socket to be ready. In this case, the receivers and the answer +// message will not be configured and sent until the UDP socket has finished +// binding. class ReceiverSession final : public Environment::SocketSubscriber { public: // Upon successful negotiation, a set of configured receivers is constructed @@ -50,35 +63,184 @@ class ReceiverSession final : public Environment::SocketSubscriber { VideoCaptureConfig video_config; }; + // This struct contains all of the information necessary to begin remoting + // once we get a remoting request from a Sender. + struct RemotingNegotiation { + // The configured receivers set to be used for handling audio and + // video streams. Unlike in the general streaming case, when we are remoting + // we don't know the codec and other information about the stream until + // the sender provices that information through the + // DemuxerStreamInitializeCallback RPC method. + ConfiguredReceivers receivers; + + // The RPC messenger to be used for subscribing to remoting proto messages. + // Unlike the SenderSession API, the RPC messenger is negotiation specific. + // The messenger is torn down when |OnReceiversDestroying| is called, and + // is owned by the ReceiverSession. + RpcMessenger* messenger; + }; + // The embedder should provide a client for handling connections. - // When a connection is established, the OnMirroringNegotiated callback is - // called. + // When a connection is established, the OnNegotiated callback is called. class Client { public: + // Currently we only care about the session ending or being renegotiated, + // which means that we don't have to tear down as much state. enum ReceiversDestroyingReason { kEndOfSession, kRenegotiated }; - // Called when a new set of receivers has been negotiated. This may be - // called multiple times during a session, as renegotiations occur. - virtual void OnMirroringNegotiated(const ReceiverSession* session, - ConfiguredReceivers receivers) = 0; + // Called when a set of streaming receivers has been negotiated. Both this + // and |OnRemotingNegotiated| may be called repeatedly as negotiations occur + // through the life of a session. + virtual void OnNegotiated(const ReceiverSession* session, + ConfiguredReceivers receivers) = 0; + + // Called when a set of remoting receivers has been negotiated. This will + // only be called if |RemotingPreferences| are provided as part of + // constructing the ReceiverSession object. + virtual void OnRemotingNegotiated(const ReceiverSession* session, + RemotingNegotiation negotiation) {} // Called immediately preceding the destruction of this session's receivers. - // If |reason| is |kEndOfSession|, OnMirroringNegotiated() will never be - // called again; if it is |kRenegotiated|, OnMirroringNegotiated() will be - // called again soon with a new set of Receivers to use. + // If |reason| is |kEndOfSession|, OnNegotiated() will never be called + // again; if it is |kRenegotiated|, OnNegotiated() will be called again + // soon with a new set of Receivers to use. // // Before returning, the implementation must ensure that all references to - // the Receivers, from the last call to OnMirroringNegotiated(), have been - // cleared. + // the Receivers, from the last call to OnNegotiated(), have been cleared. virtual void OnReceiversDestroying(const ReceiverSession* session, ReceiversDestroyingReason reason) = 0; + // Called whenever an error that the client may care about occurs. + // Recoverable errors are usually logged by the receiver session instead + // of reported here. virtual void OnError(const ReceiverSession* session, Error error) = 0; + // Called to verify whether a given codec parameter is supported by + // this client. If not overriden, this always assumes true. + // This method is used only for secondary matching, e.g. + // if you don't add VideoCodec::kHevc to the VideoCaptureConfig, then + // supporting codec parameter "hev1.1.6.L153.B0" does not matter. + // + // The codec parameter support callback is optional, however if provided + // then any offered streams that have a non-empty codec parameter field must + // match. If a stream does not have a codec parameter, this callback will + // not be called. + virtual bool SupportsCodecParameter(const std::string& parameter) { + return true; + } + protected: virtual ~Client(); }; + // Information about the display the receiver is attached to. + struct Display { + // Returns true if all configurations supported by |other| are also + // supported by this instance. + bool IsSupersetOf(const Display& other) const; + + // The display limitations of the actual screen, used to provide upper + // bounds on streams. For example, we will never + // send 60FPS if it is going to be displayed on a 30FPS screen. + // Note that we may exceed the display width and height for standard + // content sizes like 720p or 1080p. + Dimensions dimensions; + + // Whether the embedder is capable of scaling content. If set to false, + // the sender will manage the aspect ratio scaling. + bool can_scale_content = false; + }; + + // Codec-specific audio limits for playback. + struct AudioLimits { + // Returns true if all configurations supported by |other| are also + // supported by this instance. + bool IsSupersetOf(const AudioLimits& other) const; + + // Whether or not these limits apply to all codecs. + bool applies_to_all_codecs = false; + + // Audio codec these limits apply to. Note that if |applies_to_all_codecs| + // is true this field is ignored. + AudioCodec codec; + + // Maximum audio sample rate. + int max_sample_rate = kDefaultAudioSampleRate; + + // Maximum audio channels, default is currently stereo. + int max_channels = kDefaultAudioChannels; + + // Minimum and maximum bitrates. Generally capture is done at the maximum + // bit rate, since audio bandwidth is much lower than video for most + // content. + int min_bit_rate = kDefaultAudioMinBitRate; + int max_bit_rate = kDefaultAudioMaxBitRate; + + // Max playout delay in milliseconds. + std::chrono::milliseconds max_delay = kDefaultMaxDelayMs; + }; + + // Codec-specific video limits for playback. + struct VideoLimits { + // Returns true if all configurations supported by |other| are also + // supported by this instance. + bool IsSupersetOf(const VideoLimits& other) const; + + // Whether or not these limits apply to all codecs. + bool applies_to_all_codecs = false; + + // Video codec these limits apply to. Note that if |applies_to_all_codecs| + // is true this field is ignored. + VideoCodec codec; + + // Maximum pixels per second. Value is the standard amount of pixels + // for 1080P at 30FPS. + int max_pixels_per_second = 1920 * 1080 * 30; + + // Maximum dimensions. Minimum dimensions try to use the same aspect + // ratio and are generated from the spec. + Dimensions max_dimensions = {1920, 1080, {kDefaultFrameRate, 1}}; + + // Minimum and maximum bitrates. Default values are based on default min and + // max dimensions, embedders that support different display dimensions + // should strongly consider setting these fields. + int min_bit_rate = kDefaultVideoMinBitRate; + int max_bit_rate = kDefaultVideoMaxBitRate; + + // Max playout delay in milliseconds. + std::chrono::milliseconds max_delay = kDefaultMaxDelayMs; + }; + + // This struct is used to provide preferences for setting up and running + // remoting streams. These properties are based on the current control + // protocol and allow remoting with current senders. + struct RemotingPreferences { + // Returns true if all configurations supported by |other| are also + // supported by this instance. + bool IsSupersetOf(const RemotingPreferences& other) const; + + // Current remoting senders take an "all or nothing" support for audio + // codec support. While Opus and AAC support is handled in our Preferences' + // |audio_codecs| property, support for the following codecs must be + // enabled or disabled all together: + // MP3 + // PCM, including Mu-Law, S16BE, S24BE, and ALAW variants + // Ogg Vorbis + // FLAC + // AMR, including narrow band (NB) and wide band (WB) variants + // GSM Mobile Station (MS) + // EAC3 (Dolby Digital Plus) + // ALAC (Apple Lossless) + // AC-3 (Dolby Digital) + // These properties are tied directly to what Chrome supports. See: + // https://source.chromium.org/chromium/chromium/src/+/master:media/base/audio_codecs.h + bool supports_chrome_audio_codecs = false; + + // Current remoting senders assume that the receiver supports 4K for all + // video codecs supplied in |video_codecs|, or none of them. + bool supports_4k = false; + }; + // Note: embedders are required to implement the following // codecs to be Cast V2 compliant: H264, VP8, AAC, Opus. struct Preferences { @@ -87,22 +249,39 @@ class ReceiverSession final : public Environment::SocketSubscriber { std::vector<AudioCodec> audio_codecs); Preferences(std::vector<VideoCodec> video_codecs, std::vector<AudioCodec> audio_codecs, - std::unique_ptr<Constraints> constraints, - std::unique_ptr<DisplayDescription> description); + std::vector<AudioLimits> audio_limits, + std::vector<VideoLimits> video_limits, + std::unique_ptr<Display> description); Preferences(Preferences&&) noexcept; - Preferences(const Preferences&) = delete; + Preferences(const Preferences&); Preferences& operator=(Preferences&&) noexcept; - Preferences& operator=(const Preferences&) = delete; + Preferences& operator=(const Preferences&); + // Returns true if all configurations supported by |other| are also + // supported by this instance. + bool IsSupersetOf(const Preferences& other) const; + + // Audio and video codec preferences. Should be supplied in order of + // preference, e.g. in this example if we get both VP8 and H264 we will + // generally select the VP8 offer. If a codec is omitted from these fields + // it will never be selected in the OFFER/ANSWER negotiation. std::vector<VideoCodec> video_codecs{VideoCodec::kVp8, VideoCodec::kH264}; std::vector<AudioCodec> audio_codecs{AudioCodec::kOpus, AudioCodec::kAac}; - // The embedder has the option of directly specifying the display - // information and video/audio constraints that will be passed along to - // senders during the offer/answer exchange. If nullptr, these are ignored. - std::unique_ptr<Constraints> constraints; - std::unique_ptr<DisplayDescription> display_description; + // Optional limitation fields that help the sender provide a delightful + // cast experience. Although optional, highly recommended. + // NOTE: embedders that wish to apply the same limits for all codecs can + // pass a vector of size 1 with the |applies_to_all_codecs| field set to + // true. + std::vector<AudioLimits> audio_limits; + std::vector<VideoLimits> video_limits; + std::unique_ptr<Display> display_description; + + // Libcast remoting support is opt-in: embedders wishing to field remoting + // offers may provide a set of remoting preferences, or leave nullptr for + // all remoting OFFERs to be rejected in favor of continuing streaming. + std::unique_ptr<RemotingPreferences> remoting; }; ReceiverSession(Client* const client, @@ -122,9 +301,18 @@ class ReceiverSession final : public Environment::SocketSubscriber { void OnSocketInvalid(Error error) override; private: + // In some cases, such as waiting for the UDP socket to be bound, we + // may have a pending session that cannot start yet. This class provides + // all necessary info to instantiate a session. struct SessionProperties { + // The cast mode the OFFER was sent for. + CastMode mode; + + // The selected audio and video streams from the original OFFER message. std::unique_ptr<AudioStream> selected_audio; std::unique_ptr<VideoStream> selected_video; + + // The sequence number of the OFFER that produced these properties. int sequence_number; // To be valid either the audio or video must be selected, and we must @@ -134,6 +322,12 @@ class ReceiverSession final : public Environment::SocketSubscriber { // Specific message type handler methods. void OnOffer(SenderMessage message); + void OnCapabilitiesRequest(SenderMessage message); + void OnRpcMessage(SenderMessage message); + + // Selects streams from an offer based on its configuration, and sets + // them in the session properties. + void SelectStreams(const Offer& offer, SessionProperties* properties); // Creates receivers and sends an appropriate Answer message using the // session properties. @@ -146,9 +340,13 @@ class ReceiverSession final : public Environment::SocketSubscriber { // video streams. NOTE: either audio or video may be null, but not both. ConfiguredReceivers SpawnReceivers(const SessionProperties& properties); - // Callers of this method should ensure at least one stream is non-null. + // Creates an ANSWER object. Assumes at least one stream is not nullptr. Answer ConstructAnswer(const SessionProperties& properties); + // Creates a ReceiverCapability version 2 object. This will be deprecated + // as part of https://issuetracker.google.com/184429130. + ReceiverCapability CreateRemotingCapabilityV2(); + // Handles resetting receivers and notifying the client. void ResetReceivers(Client::ReceiversDestroyingReason reason); @@ -158,21 +356,27 @@ class ReceiverSession final : public Environment::SocketSubscriber { Client* const client_; Environment* const environment_; const Preferences preferences_; + // The sender_id of this session. const std::string session_id_; - ReceiverSessionMessager messager_; - // In some cases, the session initialization may be pending waiting for the - // UDP socket to be ready. In this case, the receivers and the answer - // message will not be configured and sent until the UDP socket has finished - // binding. - std::unique_ptr<SessionProperties> pending_session_; + // The session messenger used for the lifetime of this session. + ReceiverSessionMessenger messenger_; - bool supports_wifi_status_reporting_ = false; + // The packet router to be used for all Receivers spawned by this session. ReceiverPacketRouter packet_router_; + // Any session pending while the UDP socket is being bound. + std::unique_ptr<SessionProperties> pending_session_; + + // The negotiated receivers we own, clients are notified of destruction + // through |Client::OnReceiversDestroying|. std::unique_ptr<Receiver> current_audio_receiver_; std::unique_ptr<Receiver> current_video_receiver_; + + // If remoting, we store the RpcMessenger used by the embedder to send RPC + // messages from the remoting protobuf specification. + std::unique_ptr<RpcMessenger> rpc_messenger_; }; } // namespace cast diff --git a/cast/streaming/receiver_session_unittest.cc b/cast/streaming/receiver_session_unittest.cc index 1914cbd5..098695a8 100644 --- a/cast/streaming/receiver_session_unittest.cc +++ b/cast/streaming/receiver_session_unittest.cc @@ -15,6 +15,7 @@ #include "platform/test/fake_clock.h" #include "platform/test/fake_task_runner.h" #include "util/chrono_helpers.h" +#include "util/json/json_serialization.h" using ::testing::_; using ::testing::InSequence; @@ -33,7 +34,6 @@ constexpr char kValidOfferMessage[] = R"({ "seqNum": 1337, "offer": { "castMode": "mirroring", - "receiverGetStatus": true, "supportedStreams": [ { "index": 31337, @@ -78,6 +78,26 @@ constexpr char kValidOfferMessage[] = R"({ ] }, { + "index": 31339, + "type": "video_source", + "codecName": "hevc", + "codecParameter": "hev1.1.6.L150.B0", + "rtpProfile": "cast", + "rtpPayloadType": 127, + "ssrc": 19088746, + "maxFrameRate": "120", + "timeBase": "1/90000", + "maxBitRate": 5000000, + "aesKey": "040d756791711fd3adb939066e6d8690", + "aesIvMask": "9ff0f022a959150e70a2d05a6c184aed", + "resolutions": [ + { + "width": 1920, + "height": 1080 + } + ] + }, + { "index": 1337, "type": "audio_source", "codecName": "opus", @@ -94,12 +114,53 @@ constexpr char kValidOfferMessage[] = R"({ } })"; +constexpr char kValidRemotingOfferMessage[] = R"({ + "type": "OFFER", + "seqNum": 419, + "offer": { + "castMode": "remoting", + "supportedStreams": [ + { + "index": 31339, + "type": "video_source", + "codecName": "REMOTE_VIDEO", + "rtpProfile": "cast", + "rtpPayloadType": 127, + "ssrc": 19088745, + "maxFrameRate": "60000/1000", + "timeBase": "1/90000", + "maxBitRate": 5432101, + "aesKey": "040d756791711fd3adb939066e6d8690", + "aesIvMask": "9ff0f022a959150e70a2d05a6c184aed", + "resolutions": [ + { + "width": 1920, + "height":1080 + } + ] + }, + { + "index": 31340, + "type": "audio_source", + "codecName": "REMOTE_AUDIO", + "rtpProfile": "cast", + "rtpPayloadType": 97, + "ssrc": 19088747, + "bitRate": 125000, + "timeBase": "1/48000", + "channels": 2, + "aesKey": "51027e4e2347cbcb49d57ef10177aebc", + "aesIvMask": "7f12a19be62a36c04ae4116caaeff6d1" + } + ] + } +})"; + constexpr char kNoAudioOfferMessage[] = R"({ "type": "OFFER", "seqNum": 1337, "offer": { "castMode": "mirroring", - "receiverGetStatus": true, "supportedStreams": [ { "index": 31338, @@ -131,7 +192,6 @@ constexpr char kInvalidCodecOfferMessage[] = R"({ "seqNum": 1337, "offer": { "castMode": "mirroring", - "receiverGetStatus": true, "supportedStreams": [ { "index": 31338, @@ -163,7 +223,6 @@ constexpr char kNoVideoOfferMessage[] = R"({ "seqNum": 1337, "offer": { "castMode": "mirroring", - "receiverGetStatus": true, "supportedStreams": [ { "index": 1337, @@ -187,7 +246,6 @@ constexpr char kNoAudioOrVideoOfferMessage[] = R"({ "seqNum": 1337, "offer": { "castMode": "mirroring", - "receiverGetStatus": true, "supportedStreams": [] } })"; @@ -197,7 +255,6 @@ constexpr char kInvalidJsonOfferMessage[] = R"({ "seqNum": 1337, "offer": { "castMode": "mirroring", - "receiverGetStatus": true, "supportedStreams": [ } })"; @@ -211,7 +268,6 @@ constexpr char kMissingSeqNumOfferMessage[] = R"({ "type": "OFFER", "offer": { "castMode": "mirroring", - "receiverGetStatus": true, "supportedStreams": [] } })"; @@ -221,7 +277,6 @@ constexpr char kValidJsonInvalidFormatOfferMessage[] = R"({ "seqNum": 1337, "offer": { "castMode": "mirroring", - "receiverGetStatus": true, "supportedStreams": "anything" } })"; @@ -246,17 +301,36 @@ constexpr char kInvalidTypeMessage[] = R"({ "seqNum": 1337 })"; +constexpr char kGetCapabilitiesMessage[] = R"({ + "seqNum": 820263770, + "type": "GET_CAPABILITIES" +})"; + +constexpr char kRpcMessage[] = R"({ + "rpc" : "CGQQnBiCGQgSAggMGgIIBg==", + "seqNum" : 2, + "type" : "RPC" +})"; + class FakeClient : public ReceiverSession::Client { public: MOCK_METHOD(void, - OnMirroringNegotiated, + OnNegotiated, (const ReceiverSession*, ReceiverSession::ConfiguredReceivers), (override)); MOCK_METHOD(void, + OnRemotingNegotiated, + (const ReceiverSession*, ReceiverSession::RemotingNegotiation), + (override)); + MOCK_METHOD(void, OnReceiversDestroying, (const ReceiverSession*, ReceiversDestroyingReason), (override)); MOCK_METHOD(void, OnError, (const ReceiverSession*, Error error), (override)); + MOCK_METHOD(bool, + SupportsCodecParameter, + (const std::string& parameter), + (override)); }; void ExpectIsErrorAnswerMessage(const ErrorOr<Json::Value>& message_or_error) { @@ -288,12 +362,17 @@ class ReceiverSessionTest : public ::testing::Test { return environment_; } - void SetUp() { + void SetUp() { SetUpWithPreferences(ReceiverSession::Preferences{}); } + + // Since preferences are constant throughout the life of a session, + // changing them requires configuring a new session. + void SetUpWithPreferences(ReceiverSession::Preferences preferences) { + session_.reset(); message_port_ = std::make_unique<SimpleMessagePort>("sender-12345"); environment_ = MakeEnvironment(); - session_ = std::make_unique<ReceiverSession>( - &client_, environment_.get(), message_port_.get(), - ReceiverSession::Preferences{}); + session_ = std::make_unique<ReceiverSession>(&client_, environment_.get(), + message_port_.get(), + std::move(preferences)); } protected: @@ -315,7 +394,7 @@ class ReceiverSessionTest : public ::testing::Test { TEST_F(ReceiverSessionTest, CanNegotiateWithDefaultPreferences) { InSequence s; - EXPECT_CALL(client_, OnMirroringNegotiated(session_.get(), _)) + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)) .WillOnce([](const ReceiverSession* session_, ReceiverSession::ConfiguredReceivers cr) { EXPECT_TRUE(cr.audio_receiver); @@ -364,9 +443,6 @@ TEST_F(ReceiverSessionTest, CanNegotiateWithDefaultPreferences) { EXPECT_LT(0, answer_body["udpPort"].asInt()); EXPECT_GT(65535, answer_body["udpPort"].asInt()); - // Get status should always be false, as we have no plans to implement it. - EXPECT_EQ(false, answer_body["receiverGetStatus"].asBool()); - // Constraints and display should not be present with no preferences. EXPECT_TRUE(answer_body["constraints"].isNull()); EXPECT_TRUE(answer_body["display"].isNull()); @@ -378,7 +454,7 @@ TEST_F(ReceiverSessionTest, CanNegotiateWithCustomCodecPreferences) { ReceiverSession::Preferences{{VideoCodec::kVp9}, {AudioCodec::kOpus}}); InSequence s; - EXPECT_CALL(client_, OnMirroringNegotiated(&session, _)) + EXPECT_CALL(client_, OnNegotiated(&session, _)) .WillOnce([](const ReceiverSession* session_, ReceiverSession::ConfiguredReceivers cr) { EXPECT_TRUE(cr.audio_receiver); @@ -400,28 +476,88 @@ TEST_F(ReceiverSessionTest, CanNegotiateWithCustomCodecPreferences) { message_port_->ReceiveMessage(kValidOfferMessage); } -TEST_F(ReceiverSessionTest, CanNegotiateWithCustomConstraints) { - auto constraints = std::make_unique<Constraints>(Constraints{ - AudioConstraints{48001, 2, 32001, 32002, milliseconds(3001)}, - VideoConstraints{3.14159, - absl::optional<Dimensions>( - Dimensions{320, 240, SimpleFraction{24, 1}}), - Dimensions{1920, 1080, SimpleFraction{144, 1}}, 300000, - 90000000, milliseconds(1000)}}); +TEST_F(ReceiverSessionTest, RejectsStreamWithUnsupportedCodecParameter) { + ReceiverSession::Preferences preferences({VideoCodec::kHevc}, + {AudioCodec::kOpus}); + EXPECT_CALL(client_, SupportsCodecParameter(_)).WillRepeatedly(Return(false)); + ReceiverSession session(&client_, environment_.get(), message_port_.get(), + preferences); + InSequence s; + EXPECT_CALL(client_, OnNegotiated(&session, _)) + .WillOnce([](const ReceiverSession* session_, + ReceiverSession::ConfiguredReceivers cr) { + EXPECT_FALSE(cr.video_receiver); + }); + EXPECT_CALL(client_, OnReceiversDestroying( + &session, ReceiverSession::Client::kEndOfSession)); + message_port_->ReceiveMessage(kValidOfferMessage); +} + +TEST_F(ReceiverSessionTest, AcceptsStreamWithNoCodecParameter) { + ReceiverSession::Preferences preferences( + {VideoCodec::kHevc, VideoCodec::kVp9}, {AudioCodec::kOpus}); + EXPECT_CALL(client_, SupportsCodecParameter(_)).WillRepeatedly(Return(false)); + + ReceiverSession session(&client_, environment_.get(), message_port_.get(), + std::move(preferences)); + InSequence s; + EXPECT_CALL(client_, OnNegotiated(&session, _)) + .WillOnce([](const ReceiverSession* session_, + ReceiverSession::ConfiguredReceivers cr) { + EXPECT_TRUE(cr.video_receiver); + EXPECT_EQ(cr.video_config.codec, VideoCodec::kVp9); + }); + EXPECT_CALL(client_, OnReceiversDestroying( + &session, ReceiverSession::Client::kEndOfSession)); + message_port_->ReceiveMessage(kValidOfferMessage); +} + +TEST_F(ReceiverSessionTest, AcceptsStreamWithMatchingParameter) { + ReceiverSession::Preferences preferences({VideoCodec::kHevc}, + {AudioCodec::kOpus}); + EXPECT_CALL(client_, SupportsCodecParameter(_)) + .WillRepeatedly( + [](const std::string& param) { return param == "hev1.1.6.L150.B0"; }); + + ReceiverSession session(&client_, environment_.get(), message_port_.get(), + std::move(preferences)); + InSequence s; + EXPECT_CALL(client_, OnNegotiated(&session, _)) + .WillOnce([](const ReceiverSession* session_, + ReceiverSession::ConfiguredReceivers cr) { + EXPECT_TRUE(cr.video_receiver); + EXPECT_EQ(cr.video_config.codec, VideoCodec::kHevc); + }); + EXPECT_CALL(client_, OnReceiversDestroying( + &session, ReceiverSession::Client::kEndOfSession)); + message_port_->ReceiveMessage(kValidOfferMessage); +} - auto display = std::make_unique<DisplayDescription>(DisplayDescription{ - absl::optional<Dimensions>(Dimensions{640, 480, SimpleFraction{60, 1}}), - absl::optional<AspectRatio>(AspectRatio{16, 9}), - absl::optional<AspectRatioConstraint>(AspectRatioConstraint::kFixed)}); +TEST_F(ReceiverSessionTest, CanNegotiateWithLimits) { + std::vector<ReceiverSession::AudioLimits> audio_limits = { + {false, AudioCodec::kOpus, 48001, 2, 32001, 32002, milliseconds(3001)}}; + std::vector<ReceiverSession::VideoLimits> video_limits = { + {true, + VideoCodec::kVp9, + 62208000, + {1920, 1080, {144, 1}}, + 300000, + 90000000, + milliseconds(1000)}}; + + auto display = + std::make_unique<ReceiverSession::Display>(ReceiverSession::Display{ + {640, 480, {60, 1}}, false /* can scale content */}); ReceiverSession session(&client_, environment_.get(), message_port_.get(), ReceiverSession::Preferences{{VideoCodec::kVp9}, {AudioCodec::kOpus}, - std::move(constraints), + std::move(audio_limits), + std::move(video_limits), std::move(display)}); InSequence s; - EXPECT_CALL(client_, OnMirroringNegotiated(&session, _)); + EXPECT_CALL(client_, OnNegotiated(&session, _)); EXPECT_CALL(client_, OnReceiversDestroying( &session, ReceiverSession::Client::kEndOfSession)); message_port_->ReceiveMessage(kValidOfferMessage); @@ -434,14 +570,13 @@ TEST_F(ReceiverSessionTest, CanNegotiateWithCustomConstraints) { const Json::Value answer = std::move(message_body.value()); const Json::Value& answer_body = answer["answer"]; - ASSERT_TRUE(answer_body.isObject()); + ASSERT_TRUE(answer_body.isObject()) << messages[0]; // Constraints and display should be valid with valid preferences. ASSERT_FALSE(answer_body["constraints"].isNull()); ASSERT_FALSE(answer_body["display"].isNull()); const Json::Value& display_json = answer_body["display"]; - EXPECT_EQ("16:9", display_json["aspectRatio"].asString()); EXPECT_EQ("60", display_json["dimensions"]["frameRate"].asString()); EXPECT_EQ(640, display_json["dimensions"]["width"].asInt()); EXPECT_EQ(480, display_json["dimensions"]["height"].asInt()); @@ -465,16 +600,12 @@ TEST_F(ReceiverSessionTest, CanNegotiateWithCustomConstraints) { EXPECT_EQ("144", video["maxDimensions"]["frameRate"].asString()); EXPECT_EQ(1920, video["maxDimensions"]["width"].asInt()); EXPECT_EQ(1080, video["maxDimensions"]["height"].asInt()); - EXPECT_DOUBLE_EQ(3.14159, video["maxPixelsPerSecond"].asDouble()); EXPECT_EQ(300000, video["minBitRate"].asInt()); - EXPECT_EQ("24", video["minDimensions"]["frameRate"].asString()); - EXPECT_EQ(320, video["minDimensions"]["width"].asInt()); - EXPECT_EQ(240, video["minDimensions"]["height"].asInt()); } TEST_F(ReceiverSessionTest, HandlesNoValidAudioStream) { InSequence s; - EXPECT_CALL(client_, OnMirroringNegotiated(session_.get(), _)); + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)); EXPECT_CALL(client_, OnReceiversDestroying(session_.get(), ReceiverSession::Client::kEndOfSession)); @@ -511,7 +642,7 @@ TEST_F(ReceiverSessionTest, HandlesInvalidCodec) { TEST_F(ReceiverSessionTest, HandlesNoValidVideoStream) { InSequence s; - EXPECT_CALL(client_, OnMirroringNegotiated(session_.get(), _)); + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)); EXPECT_CALL(client_, OnReceiversDestroying(session_.get(), ReceiverSession::Client::kEndOfSession)); @@ -533,8 +664,7 @@ TEST_F(ReceiverSessionTest, HandlesNoValidVideoStream) { } TEST_F(ReceiverSessionTest, HandlesNoValidStreams) { - // We shouldn't call OnMirroringNegotiated if we failed to negotiate any - // streams. + // We shouldn't call OnNegotiated if we failed to negotiate any streams. message_port_->ReceiveMessage(kNoAudioOrVideoOfferMessage); AssertGotAnErrorAnswerResponse(); } @@ -596,11 +726,11 @@ TEST_F(ReceiverSessionTest, DoesNotCrashOnMessagePortError) { TEST_F(ReceiverSessionTest, NotifiesReceiverDestruction) { InSequence s; - EXPECT_CALL(client_, OnMirroringNegotiated(session_.get(), _)); + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)); EXPECT_CALL(client_, OnReceiversDestroying(session_.get(), ReceiverSession::Client::kRenegotiated)); - EXPECT_CALL(client_, OnMirroringNegotiated(session_.get(), _)); + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)); EXPECT_CALL(client_, OnReceiversDestroying(session_.get(), ReceiverSession::Client::kEndOfSession)); @@ -638,7 +768,7 @@ TEST_F(ReceiverSessionTest, DelaysAnswerUntilEnvironmentIsReady) { // state() will not be called again--we just need to get the bind event. EXPECT_CALL(*environment_, GetBoundLocalEndpoint()) .WillOnce(Return(IPEndpoint{{10, 0, 0, 2}, 4567})); - EXPECT_CALL(client_, OnMirroringNegotiated(session_.get(), _)); + EXPECT_CALL(client_, OnNegotiated(session_.get(), _)); EXPECT_CALL(client_, OnReceiversDestroying(session_.get(), ReceiverSession::Client::kEndOfSession)); @@ -691,5 +821,387 @@ TEST_F(ReceiverSessionTest, ReturnsErrorAnswerIfEnvironmentIsInvalidated) { EXPECT_EQ("error", message_body.value()["result"].asString()); } +TEST_F(ReceiverSessionTest, ReturnsErrorCapabilitiesIfRemotingDisabled) { + message_port_->ReceiveMessage(kGetCapabilitiesMessage); + const auto& messages = message_port_->posted_messages(); + ASSERT_EQ(1u, messages.size()); + + // We should have an error response. + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + EXPECT_EQ("CAPABILITIES_RESPONSE", message_body.value()["type"].asString()); + EXPECT_EQ("error", message_body.value()["result"].asString()); +} + +TEST_F(ReceiverSessionTest, ReturnsCapabilitiesWithRemotingDefaults) { + ReceiverSession::Preferences preferences; + preferences.remoting = + std::make_unique<ReceiverSession::RemotingPreferences>(); + + SetUpWithPreferences(std::move(preferences)); + message_port_->ReceiveMessage(kGetCapabilitiesMessage); + const auto& messages = message_port_->posted_messages(); + ASSERT_EQ(1u, messages.size()); + + // We should have an error response. + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + EXPECT_EQ("CAPABILITIES_RESPONSE", message_body.value()["type"].asString()); + EXPECT_EQ("ok", message_body.value()["result"].asString()); + const ReceiverCapability response = + ReceiverCapability::Parse(message_body.value()["capabilities"]).value(); + + EXPECT_THAT( + response.media_capabilities, + testing::ElementsAre(MediaCapability::kOpus, MediaCapability::kAac, + MediaCapability::kVp8, MediaCapability::kH264)); +} + +TEST_F(ReceiverSessionTest, ReturnsCapabilitiesWithRemotingPreferences) { + ReceiverSession::Preferences preferences; + preferences.video_codecs = {VideoCodec::kH264}; + preferences.remoting = + std::make_unique<ReceiverSession::RemotingPreferences>(); + preferences.remoting->supports_chrome_audio_codecs = true; + preferences.remoting->supports_4k = true; + + SetUpWithPreferences(std::move(preferences)); + message_port_->ReceiveMessage(kGetCapabilitiesMessage); + const auto& messages = message_port_->posted_messages(); + ASSERT_EQ(1u, messages.size()); + + // We should have an error response. + auto message_body = json::Parse(messages[0]); + EXPECT_TRUE(message_body.is_value()); + EXPECT_EQ("CAPABILITIES_RESPONSE", message_body.value()["type"].asString()); + EXPECT_EQ("ok", message_body.value()["result"].asString()); + const ReceiverCapability response = + ReceiverCapability::Parse(message_body.value()["capabilities"]).value(); + + EXPECT_THAT( + response.media_capabilities, + testing::ElementsAre(MediaCapability::kOpus, MediaCapability::kAac, + MediaCapability::kH264, MediaCapability::kAudio, + MediaCapability::k4k)); +} + +TEST_F(ReceiverSessionTest, CanNegotiateRemoting) { + ReceiverSession::Preferences preferences; + preferences.remoting = + std::make_unique<ReceiverSession::RemotingPreferences>(); + preferences.remoting->supports_chrome_audio_codecs = true; + preferences.remoting->supports_4k = true; + SetUpWithPreferences(std::move(preferences)); + + InSequence s; + EXPECT_CALL(client_, OnRemotingNegotiated(session_.get(), _)) + .WillOnce([](const ReceiverSession* session_, + ReceiverSession::RemotingNegotiation negotiation) { + const auto& cr = negotiation.receivers; + EXPECT_TRUE(cr.audio_receiver); + EXPECT_EQ(cr.audio_receiver->config().sender_ssrc, 19088747u); + EXPECT_EQ(cr.audio_receiver->config().receiver_ssrc, 19088748u); + EXPECT_EQ(cr.audio_receiver->config().channels, 2); + EXPECT_EQ(cr.audio_receiver->config().rtp_timebase, 48000); + EXPECT_EQ(cr.audio_config.codec, AudioCodec::kNotSpecified); + + EXPECT_TRUE(cr.video_receiver); + EXPECT_EQ(cr.video_receiver->config().sender_ssrc, 19088745u); + EXPECT_EQ(cr.video_receiver->config().receiver_ssrc, 19088746u); + EXPECT_EQ(cr.video_receiver->config().channels, 1); + EXPECT_EQ(cr.video_receiver->config().rtp_timebase, 90000); + EXPECT_EQ(cr.video_config.codec, VideoCodec::kNotSpecified); + }); + EXPECT_CALL(client_, + OnReceiversDestroying(session_.get(), + ReceiverSession::Client::kEndOfSession)); + + message_port_->ReceiveMessage(kValidRemotingOfferMessage); +} + +TEST_F(ReceiverSessionTest, HandlesRpcMessage) { + ReceiverSession::Preferences preferences; + preferences.remoting = + std::make_unique<ReceiverSession::RemotingPreferences>(); + preferences.remoting->supports_chrome_audio_codecs = true; + preferences.remoting->supports_4k = true; + SetUpWithPreferences(std::move(preferences)); + + message_port_->ReceiveMessage(kRpcMessage); + const auto& messages = message_port_->posted_messages(); + // Nothing should happen yet, the session doesn't have a messenger. + ASSERT_EQ(0u, messages.size()); + + // We don't need to fully test that the subscription model on the RpcMessenger + // works, but we do want to test that the ReceiverSession has properly wired + // the RpcMessenger up to the backing SessionMessenger and can properly + // handle received RPC messages. + InSequence s; + bool received_initialize_message = false; + EXPECT_CALL(client_, OnRemotingNegotiated(session_.get(), _)) + .WillOnce([this, &received_initialize_message]( + const ReceiverSession* session_, + ReceiverSession::RemotingNegotiation negotiation) mutable { + negotiation.messenger->RegisterMessageReceiverCallback( + 100, [&received_initialize_message]( + std::unique_ptr<RpcMessage> message) mutable { + ASSERT_EQ(100, message->handle()); + ASSERT_EQ(RpcMessage::RPC_DS_INITIALIZE_CALLBACK, + message->proc()); + ASSERT_EQ(0, message->integer_value()); + received_initialize_message = true; + }); + + message_port_->ReceiveMessage(kRpcMessage); + }); + EXPECT_CALL(client_, + OnReceiversDestroying(session_.get(), + ReceiverSession::Client::kEndOfSession)); + + message_port_->ReceiveMessage(kValidRemotingOfferMessage); + ASSERT_TRUE(received_initialize_message); +} + +TEST_F(ReceiverSessionTest, VideoLimitsIsSupersetOf) { + ReceiverSession::VideoLimits first{}; + ReceiverSession::VideoLimits second = first; + + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + + first.max_pixels_per_second += 1; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + first.max_pixels_per_second = second.max_pixels_per_second; + + first.max_dimensions = {1921, 1090, {kDefaultFrameRate, 1}}; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + + second.max_dimensions = {1921, 1090, {kDefaultFrameRate + 1, 1}}; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + + second.max_dimensions = {2000, 1000, {kDefaultFrameRate, 1}}; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second.max_dimensions = first.max_dimensions; + + first.min_bit_rate += 1; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + first.min_bit_rate = second.min_bit_rate; + + first.max_bit_rate += 1; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + first.max_bit_rate = second.max_bit_rate; + + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + + first.applies_to_all_codecs = true; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second.applies_to_all_codecs = true; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + first.codec = VideoCodec::kVp8; + second.codec = VideoCodec::kVp9; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + first.applies_to_all_codecs = false; + second.applies_to_all_codecs = false; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); +} + +TEST_F(ReceiverSessionTest, AudioLimitsIsSupersetOf) { + ReceiverSession::AudioLimits first{}; + ReceiverSession::AudioLimits second = first; + + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + + first.max_sample_rate += 1; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + first.max_sample_rate = second.max_sample_rate; + + first.max_channels += 1; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + first.max_channels = second.max_channels; + + first.min_bit_rate += 1; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + first.min_bit_rate = second.min_bit_rate; + + first.max_bit_rate += 1; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + first.max_bit_rate = second.max_bit_rate; + + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + + first.applies_to_all_codecs = true; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second.applies_to_all_codecs = true; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + first.codec = AudioCodec::kOpus; + second.codec = AudioCodec::kAac; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + first.applies_to_all_codecs = false; + second.applies_to_all_codecs = false; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); +} + +TEST_F(ReceiverSessionTest, DisplayIsSupersetOf) { + ReceiverSession::Display first; + ReceiverSession::Display second = first; + + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + + first.dimensions = {1921, 1090, {kDefaultFrameRate, 1}}; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + + second.dimensions = {1921, 1090, {kDefaultFrameRate + 1, 1}}; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + + second.dimensions = {2000, 1000, {kDefaultFrameRate, 1}}; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second.dimensions = first.dimensions; + + first.can_scale_content = true; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); +} + +TEST_F(ReceiverSessionTest, RemotingPreferencesIsSupersetOf) { + ReceiverSession::RemotingPreferences first; + ReceiverSession::RemotingPreferences second = first; + + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + + first.supports_chrome_audio_codecs = true; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + + second.supports_4k = true; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + + second.supports_chrome_audio_codecs = true; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); +} + +TEST_F(ReceiverSessionTest, PreferencesIsSupersetOf) { + ReceiverSession::Preferences first; + ReceiverSession::Preferences second(first); + + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + + // Modified |display_description|. + first.display_description = std::make_unique<ReceiverSession::Display>(); + first.display_description->dimensions = {1920, 1080, {kDefaultFrameRate, 1}}; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second = first; + + first.display_description->dimensions = {192, 1080, {kDefaultFrameRate, 1}}; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + second = first; + + // Modified |remoting|. + first.remoting = std::make_unique<ReceiverSession::RemotingPreferences>(); + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second = first; + + second.remoting->supports_4k = true; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + second = first; + + // Modified |video_codecs|. + first.video_codecs = {VideoCodec::kVp8, VideoCodec::kVp9}; + second.video_codecs = {}; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second.video_codecs = {VideoCodec::kHevc}; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + first.video_codecs.emplace_back(VideoCodec::kHevc); + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + first = second; + + // Modified |audio_codecs|. + first.audio_codecs = {AudioCodec::kOpus}; + second.audio_codecs = {}; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second.audio_codecs = {AudioCodec::kAac}; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + first.audio_codecs.emplace_back(AudioCodec::kAac); + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + first = second; + + // Modified |video_limits|. + first.video_limits.push_back({true, VideoCodec::kVp8}); + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + first.video_limits.front().min_bit_rate = -1; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second.video_limits.push_back({true, VideoCodec::kVp9}); + second.video_limits.front().min_bit_rate = -1; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + first.video_limits.front().applies_to_all_codecs = false; + first.video_limits.push_back({false, VideoCodec::kHevc, 123}); + second.video_limits.front().applies_to_all_codecs = false; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second.video_limits.front().min_bit_rate = kDefaultVideoMinBitRate; + first.video_limits.front().min_bit_rate = kDefaultVideoMinBitRate; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + second = first; + + // Modified |audio_limits|. + first.audio_limits.push_back({true, AudioCodec::kOpus}); + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + first.audio_limits.front().min_bit_rate = -1; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); + second.audio_limits.push_back({true, AudioCodec::kAac}); + second.audio_limits.front().min_bit_rate = -1; + EXPECT_TRUE(first.IsSupersetOf(second)); + EXPECT_TRUE(second.IsSupersetOf(first)); + first.audio_limits.front().applies_to_all_codecs = false; + first.audio_limits.push_back({false, AudioCodec::kOpus, -1}); + second.audio_limits.front().applies_to_all_codecs = false; + EXPECT_FALSE(first.IsSupersetOf(second)); + EXPECT_FALSE(second.IsSupersetOf(first)); +} + } // namespace cast } // namespace openscreen diff --git a/cast/streaming/remoting.proto b/cast/streaming/remoting.proto index 0ce73012..84729d60 100644 --- a/cast/streaming/remoting.proto +++ b/cast/streaming/remoting.proto @@ -76,6 +76,7 @@ message AudioDecoderConfig { kSampleFormatAc3 = 9; kSampleFormatEac3 = 10; kSampleFormatMpegHAudio = 11; + kSampleFormatPlanarU8 = 12; }; // Proto version of Chrome's media::ChannelLayout. @@ -238,10 +239,15 @@ message VideoDecoderConfig { optional bytes extra_data = 9; } -message PipelineDecoderInfo { +message AudioDecoderInfo { reserved 3; - reserved "has_decrypting_demuxer_stream"; - optional string decoder_name = 1; + optional int64 decoder_type = 1; + optional bool is_platform_decoder = 2; +}; + +message VideoDecoderInfo { + reserved 3; + optional int64 decoder_type = 1; optional bool is_platform_decoder = 2; }; @@ -253,8 +259,8 @@ message PipelineStatistics { optional int64 audio_memory_usage = 5; optional int64 video_memory_usage = 6; optional int64 video_frame_duration_average_usec = 7; - optional PipelineDecoderInfo audio_decoder_info = 8; - optional PipelineDecoderInfo video_decoder_info = 9; + optional AudioDecoderInfo audio_decoder_info = 8; + optional VideoDecoderInfo video_decoder_info = 9; }; message AcquireDemuxer { @@ -446,4 +452,4 @@ message RpcMessage { // RPC_DS_READUNTIL_CALLBACK DemuxerStreamReadUntilCallback demuxerstream_readuntilcb_rpc = 401; }; -}
\ No newline at end of file +} diff --git a/cast/streaming/remoting_capabilities.h b/cast/streaming/remoting_capabilities.h new file mode 100644 index 00000000..a2f7b173 --- /dev/null +++ b/cast/streaming/remoting_capabilities.h @@ -0,0 +1,61 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef CAST_STREAMING_REMOTING_CAPABILITIES_H_ +#define CAST_STREAMING_REMOTING_CAPABILITIES_H_ + +#include <string> +#include <vector> + +namespace openscreen { +namespace cast { + +// Audio capabilities are how receivers indicate support for remoting codecs-- +// as remoting does not include the actual codec in the OFFER message. +enum class AudioCapability { + // The "baseline set" is used in Chrome to check support for a wide + // variety of audio codecs in media/remoting/renderer_controller.cc, including + // but not limited to MP3, PCM, Ogg Vorbis, and FLAC. + kBaselineSet, + kAac, + kOpus, +}; + +// Similar to audio capabilities, video capabilities are how the receiver +// indicates support for certain video codecs, as well as support for streaming +// 4k content. It is assumed by the sender that the receiver can support 4k +// on all supported codecs. +enum class VideoCapability { + // |kSupports4k| indicates that the receiver wants and can support 4k remoting + // content--both decoding/rendering and either a native 4k display or + // downscaling to the display's native resolution. + // TODO(issuetracker.google.com/184429130): |kSupports4k| is not super helpful + // for enabling 4k support, as receivers may not support 4k for all types of + // content. + kSupports4k, + kH264, + kVp8, + kVp9, + kHevc, + kAv1 +}; + +// This class is similar to the RemotingSinkMetadata in Chrome, however +// it is focused around our needs and is not mojom-based. This contains +// a rough set of capabilities of the receiver to give the sender an idea of +// what features are suppported for remoting. +// TODO(issuetracker.google.com/184189100): this object should be expanded to +// allow more specific constraint tracking. +struct RemotingCapabilities { + // Receiver audio-specific capabilities. + std::vector<AudioCapability> audio; + + // Receiver video-specific capabilities. + std::vector<VideoCapability> video; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_REMOTING_CAPABILITIES_H_ diff --git a/cast/streaming/resolution.cc b/cast/streaming/resolution.cc new file mode 100644 index 00000000..9c763cfe --- /dev/null +++ b/cast/streaming/resolution.cc @@ -0,0 +1,122 @@ +// Copyright 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/resolution.h" + +#include <utility> + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "cast/streaming/message_fields.h" +#include "platform/base/error.h" +#include "util/json/json_helpers.h" +#include "util/osp_logging.h" + +namespace openscreen { +namespace cast { + +namespace { + +/// Dimension properties. +// Width in pixels. +static constexpr char kWidth[] = "width"; + +// Height in pixels. +static constexpr char kHeight[] = "height"; + +// Frame rate as a rational decimal number or fraction. +// E.g. 30 and "3000/1001" are both valid representations. +static constexpr char kFrameRate[] = "frameRate"; + +// Choice of epsilon for double comparison allows for proper comparison +// for both aspect ratios and frame rates. For frame rates, it is based on the +// broadcast rate of 29.97fps, which is actually 29.976. For aspect ratios, it +// allows for a one-pixel difference at a 4K resolution, we want it to be +// relatively high to avoid false negative comparison results. +bool FrameRateEquals(double a, double b) { + const double kEpsilonForFrameRateComparisons = .0001; + return std::abs(a - b) < kEpsilonForFrameRateComparisons; +} + +} // namespace + +bool Resolution::TryParse(const Json::Value& root, Resolution* out) { + if (!json::TryParseInt(root[kWidth], &(out->width)) || + !json::TryParseInt(root[kHeight], &(out->height))) { + return false; + } + return out->IsValid(); +} + +bool Resolution::IsValid() const { + return width > 0 && height > 0; +} + +Json::Value Resolution::ToJson() const { + OSP_DCHECK(IsValid()); + Json::Value root; + root[kWidth] = width; + root[kHeight] = height; + + return root; +} + +bool Resolution::operator==(const Resolution& other) const { + return std::tie(width, height) == std::tie(other.width, other.height); +} + +bool Resolution::operator!=(const Resolution& other) const { + return !(*this == other); +} + +bool Resolution::IsSupersetOf(const Resolution& other) const { + return width >= other.width && height >= other.height; +} + +bool Dimensions::TryParse(const Json::Value& root, Dimensions* out) { + if (!json::TryParseInt(root[kWidth], &(out->width)) || + !json::TryParseInt(root[kHeight], &(out->height)) || + !(root[kFrameRate].isNull() || + json::TryParseSimpleFraction(root[kFrameRate], &(out->frame_rate)))) { + return false; + } + return out->IsValid(); +} + +bool Dimensions::IsValid() const { + return width > 0 && height > 0 && frame_rate.is_positive(); +} + +Json::Value Dimensions::ToJson() const { + OSP_DCHECK(IsValid()); + Json::Value root; + root[kWidth] = width; + root[kHeight] = height; + root[kFrameRate] = frame_rate.ToString(); + + return root; +} + +bool Dimensions::operator==(const Dimensions& other) const { + return (std::tie(width, height) == std::tie(other.width, other.height) && + FrameRateEquals(static_cast<double>(frame_rate), + static_cast<double>(other.frame_rate))); +} + +bool Dimensions::operator!=(const Dimensions& other) const { + return !(*this == other); +} + +bool Dimensions::IsSupersetOf(const Dimensions& other) const { + if (static_cast<double>(frame_rate) != + static_cast<double>(other.frame_rate)) { + return static_cast<double>(frame_rate) >= + static_cast<double>(other.frame_rate); + } + + return ToResolution().IsSupersetOf(other.ToResolution()); +} + +} // namespace cast +} // namespace openscreen diff --git a/cast/streaming/resolution.h b/cast/streaming/resolution.h new file mode 100644 index 00000000..47815556 --- /dev/null +++ b/cast/streaming/resolution.h @@ -0,0 +1,70 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// Resolutions and dimensions (resolutions with a frame rate) are used +// extensively throughout cast streaming. Since their serialization to and +// from JSON is stable and standard, we have a single place definition for +// these for use both in our public APIs and private messages. + +#ifndef CAST_STREAMING_RESOLUTION_H_ +#define CAST_STREAMING_RESOLUTION_H_ + +#include "absl/types/optional.h" +#include "json/value.h" +#include "util/simple_fraction.h" + +namespace openscreen { +namespace cast { + +// A resolution in pixels. +struct Resolution { + static bool TryParse(const Json::Value& value, Resolution* out); + bool IsValid() const; + Json::Value ToJson() const; + + // Returns true if both |width| and |height| of this instance are greater than + // or equal to that of |other|. + bool IsSupersetOf(const Resolution& other) const; + + bool operator==(const Resolution& other) const; + bool operator!=(const Resolution& other) const; + + // Width and height in pixels. + int width = 0; + int height = 0; +}; + +// A resolution in pixels and a frame rate. +struct Dimensions { + static bool TryParse(const Json::Value& value, Dimensions* out); + bool IsValid() const; + Json::Value ToJson() const; + + // Returns true if all properties of this instance are greater than or equal + // to those of |other|. + bool IsSupersetOf(const Dimensions& other) const; + + bool operator==(const Dimensions& other) const; + bool operator!=(const Dimensions& other) const; + + // Get just the width and height fields (for comparisons). + constexpr Resolution ToResolution() const { return {width, height}; } + + // The effective bit rate is the width * height * frame rate. + constexpr int effective_bit_rate() const { + return width * height * static_cast<double>(frame_rate); + } + + // Width and height in pixels. + int width = 0; + int height = 0; + + // |frame_rate| is the maximum maintainable frame rate. + SimpleFraction frame_rate{0, 1}; +}; + +} // namespace cast +} // namespace openscreen + +#endif // CAST_STREAMING_RESOLUTION_H_ diff --git a/cast/streaming/rpc_broker.cc b/cast/streaming/rpc_broker.cc deleted file mode 100644 index 6e79a4dd..00000000 --- a/cast/streaming/rpc_broker.cc +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "cast/streaming/rpc_broker.h" - -#include <utility> - -#include "util/osp_logging.h" - -namespace openscreen { -namespace cast { - -namespace { - -std::ostream& operator<<(std::ostream& out, const RpcMessage& message) { - out << "handle=" << message.handle() << ", proc=" << message.proc(); - switch (message.rpc_oneof_case()) { - case RpcMessage::kIntegerValue: - return out << ", integer_value=" << message.integer_value(); - case RpcMessage::kInteger64Value: - return out << ", integer64_value=" << message.integer64_value(); - case RpcMessage::kDoubleValue: - return out << ", double_value=" << message.double_value(); - case RpcMessage::kBooleanValue: - return out << ", boolean_value=" << message.boolean_value(); - case RpcMessage::kStringValue: - return out << ", string_value=" << message.string_value(); - default: - return out << ", rpc_oneof=" << message.rpc_oneof_case(); - } - - OSP_NOTREACHED(); -} - -} // namespace - -RpcBroker::RpcBroker(SendMessageCallback send_message_cb) - : next_handle_(kFirstHandle), - send_message_cb_(std::move(send_message_cb)) {} - -RpcBroker::~RpcBroker() { - receive_callbacks_.clear(); -} - -RpcBroker::Handle RpcBroker::GetUniqueHandle() { - return next_handle_++; -} - -void RpcBroker::RegisterMessageReceiverCallback( - RpcBroker::Handle handle, - ReceiveMessageCallback callback) { - OSP_DCHECK(receive_callbacks_.find(handle) == receive_callbacks_.end()) - << "must deregister before re-registering"; - OSP_DVLOG << "registering handle: " << handle; - receive_callbacks_.emplace_back(handle, std::move(callback)); -} - -void RpcBroker::UnregisterMessageReceiverCallback(RpcBroker::Handle handle) { - OSP_DVLOG << "unregistering handle: " << handle; - receive_callbacks_.erase_key(handle); -} - -void RpcBroker::ProcessMessageFromRemote(const RpcMessage& message) { - OSP_DVLOG << "received message: " << message; - const auto entry = receive_callbacks_.find(message.handle()); - if (entry == receive_callbacks_.end()) { - OSP_DVLOG << "unregistered handle: " << message.handle(); - return; - } - entry->second(message); -} - -void RpcBroker::SendMessageToRemote(const RpcMessage& message) { - OSP_DVLOG << "sending message message: " << message; - std::vector<uint8_t> serialized_message(message.ByteSizeLong()); - OSP_CHECK(message.SerializeToArray(serialized_message.data(), - serialized_message.size())); - send_message_cb_(std::move(serialized_message)); -} - -bool RpcBroker::IsRegisteredForTesting(RpcBroker::Handle handle) { - return receive_callbacks_.find(handle) != receive_callbacks_.end(); -} - -} // namespace cast -} // namespace openscreen diff --git a/cast/streaming/rpc_broker_unittest.cc b/cast/streaming/rpc_broker_unittest.cc deleted file mode 100644 index 5dacfbb2..00000000 --- a/cast/streaming/rpc_broker_unittest.cc +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2020 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "cast/streaming/rpc_broker.h" - -#include <memory> -#include <vector> - -#include "cast/streaming/remoting.pb.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using testing::_; -using testing::Invoke; -using testing::Return; - -namespace openscreen { -namespace cast { - -namespace { - -class FakeMessager { - public: - void OnReceivedRpc(const RpcMessage& message) { - received_rpc_ = message; - received_count_++; - } - - void OnSentRpc(const std::vector<uint8_t>& message) { - EXPECT_TRUE(sent_rpc_.ParseFromArray(message.data(), message.size())); - sent_count_++; - } - - int received_count() const { return received_count_; } - const RpcMessage& received_rpc() const { return received_rpc_; } - - int sent_count() const { return sent_count_; } - const RpcMessage& sent_rpc() const { return sent_rpc_; } - - void set_handle(RpcBroker::Handle handle) { handle_ = handle; } - RpcBroker::Handle handle() { return handle_; } - - private: - RpcMessage received_rpc_; - int received_count_ = 0; - - RpcMessage sent_rpc_; - int sent_count_ = 0; - - RpcBroker::Handle handle_ = -1; -}; - -} // namespace - -class RpcBrokerTest : public testing::Test { - protected: - void SetUp() override { - fake_messager_ = std::make_unique<FakeMessager>(); - ASSERT_FALSE(fake_messager_->received_count()); - - rpc_broker_ = std::make_unique<RpcBroker>( - [p = fake_messager_.get()](std::vector<uint8_t> message) { - p->OnSentRpc(message); - }); - - const auto handle = rpc_broker_->GetUniqueHandle(); - fake_messager_->set_handle(handle); - rpc_broker_->RegisterMessageReceiverCallback( - handle, [p = fake_messager_.get()](const RpcMessage& message) { - p->OnReceivedRpc(message); - }); - } - - std::unique_ptr<FakeMessager> fake_messager_; - std::unique_ptr<RpcBroker> rpc_broker_; -}; - -TEST_F(RpcBrokerTest, TestProcessMessageFromRemoteRegistered) { - RpcMessage rpc; - rpc.set_handle(fake_messager_->handle()); - rpc_broker_->ProcessMessageFromRemote(rpc); - ASSERT_EQ(1, fake_messager_->received_count()); -} - -TEST_F(RpcBrokerTest, TestProcessMessageFromRemoteUnregistered) { - RpcMessage rpc; - rpc_broker_->UnregisterMessageReceiverCallback(fake_messager_->handle()); - rpc_broker_->ProcessMessageFromRemote(rpc); - ASSERT_EQ(0, fake_messager_->received_count()); -} - -TEST_F(RpcBrokerTest, CanSendMultipleMessages) { - for (int i = 0; i < 10; ++i) { - rpc_broker_->SendMessageToRemote(RpcMessage{}); - } - EXPECT_EQ(10, fake_messager_->sent_count()); -} - -TEST_F(RpcBrokerTest, SendMessageCallback) { - // Send message for RPC broker to process. - RpcMessage sent_rpc; - sent_rpc.set_handle(fake_messager_->handle()); - sent_rpc.set_proc(RpcMessage::RPC_R_SETVOLUME); - sent_rpc.set_double_value(2.3); - rpc_broker_->SendMessageToRemote(sent_rpc); - - // Check if received message is identical to the one sent earlier. - ASSERT_EQ(1, fake_messager_->sent_count()); - const RpcMessage& message = fake_messager_->sent_rpc(); - ASSERT_EQ(fake_messager_->handle(), message.handle()); - ASSERT_EQ(RpcMessage::RPC_R_SETVOLUME, message.proc()); - ASSERT_EQ(2.3, message.double_value()); -} - -TEST_F(RpcBrokerTest, ProcessMessageWithRegisteredHandle) { - // Send message for RPC broker to process. - RpcMessage sent_rpc; - sent_rpc.set_handle(fake_messager_->handle()); - sent_rpc.set_proc(RpcMessage::RPC_R_SETVOLUME); - sent_rpc.set_double_value(3.4); - rpc_broker_->ProcessMessageFromRemote(sent_rpc); - - // Checks if received message is identical to the one sent earlier. - ASSERT_EQ(1, fake_messager_->received_count()); - const RpcMessage& received_rpc = fake_messager_->received_rpc(); - ASSERT_EQ(fake_messager_->handle(), received_rpc.handle()); - ASSERT_EQ(RpcMessage::RPC_R_SETVOLUME, received_rpc.proc()); - ASSERT_EQ(3.4, received_rpc.double_value()); -} - -TEST_F(RpcBrokerTest, ProcessMessageWithUnregisteredHandle) { - // Send message for RPC broker to process. - RpcMessage sent_rpc; - RpcBroker::Handle different_handle = fake_messager_->handle() + 1; - sent_rpc.set_handle(different_handle); - sent_rpc.set_proc(RpcMessage::RPC_R_SETVOLUME); - sent_rpc.set_double_value(4.5); - rpc_broker_->ProcessMessageFromRemote(sent_rpc); - - // We shouldn't have gotten the message since the handle is different. - ASSERT_EQ(0, fake_messager_->received_count()); -} - -TEST_F(RpcBrokerTest, Registration) { - const auto handle = fake_messager_->handle(); - ASSERT_TRUE(rpc_broker_->IsRegisteredForTesting(handle)); - - rpc_broker_->UnregisterMessageReceiverCallback(handle); - ASSERT_FALSE(rpc_broker_->IsRegisteredForTesting(handle)); -} - -} // namespace cast -} // namespace openscreen diff --git a/cast/streaming/rpc_messenger.cc b/cast/streaming/rpc_messenger.cc new file mode 100644 index 00000000..360b5572 --- /dev/null +++ b/cast/streaming/rpc_messenger.cc @@ -0,0 +1,106 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/rpc_messenger.h" + +#include <memory> +#include <string> +#include <utility> + +#include "util/osp_logging.h" + +namespace openscreen { +namespace cast { + +namespace { + +std::ostream& operator<<(std::ostream& out, const RpcMessage& message) { + out << "handle=" << message.handle() << ", proc=" << message.proc(); + switch (message.rpc_oneof_case()) { + case RpcMessage::kIntegerValue: + return out << ", integer_value=" << message.integer_value(); + case RpcMessage::kInteger64Value: + return out << ", integer64_value=" << message.integer64_value(); + case RpcMessage::kDoubleValue: + return out << ", double_value=" << message.double_value(); + case RpcMessage::kBooleanValue: + return out << ", boolean_value=" << message.boolean_value(); + case RpcMessage::kStringValue: + return out << ", string_value=" << message.string_value(); + default: + return out << ", rpc_oneof=" << message.rpc_oneof_case(); + } + + OSP_NOTREACHED(); +} + +} // namespace + +constexpr RpcMessenger::Handle RpcMessenger::kInvalidHandle; +constexpr RpcMessenger::Handle RpcMessenger::kAcquireRendererHandle; +constexpr RpcMessenger::Handle RpcMessenger::kAcquireDemuxerHandle; +constexpr RpcMessenger::Handle RpcMessenger::kFirstHandle; + +RpcMessenger::RpcMessenger(SendMessageCallback send_message_cb) + : next_handle_(kFirstHandle), + send_message_cb_(std::move(send_message_cb)) {} + +RpcMessenger::~RpcMessenger() { + receive_callbacks_.clear(); +} + +RpcMessenger::Handle RpcMessenger::GetUniqueHandle() { + return next_handle_++; +} + +void RpcMessenger::RegisterMessageReceiverCallback( + RpcMessenger::Handle handle, + ReceiveMessageCallback callback) { + OSP_DCHECK(receive_callbacks_.find(handle) == receive_callbacks_.end()) + << "must deregister before re-registering"; + receive_callbacks_.emplace_back(handle, std::move(callback)); +} + +void RpcMessenger::UnregisterMessageReceiverCallback(RpcMessenger::Handle handle) { + receive_callbacks_.erase_key(handle); +} + +void RpcMessenger::ProcessMessageFromRemote(const uint8_t* message, + std::size_t message_len) { + auto rpc = std::make_unique<RpcMessage>(); + if (!rpc->ParseFromArray(message, message_len)) { + OSP_DLOG_WARN << "Failed to parse RPC message from remote: \"" << message + << "\""; + return; + } + ProcessMessageFromRemote(std::move(rpc)); +} + +void RpcMessenger::ProcessMessageFromRemote(std::unique_ptr<RpcMessage> message) { + const auto entry = receive_callbacks_.find(message->handle()); + if (entry == receive_callbacks_.end()) { + OSP_VLOG << "Dropping message due to unregistered handle: " + << message->handle(); + return; + } + entry->second(std::move(message)); +} + +void RpcMessenger::SendMessageToRemote(const RpcMessage& rpc) { + OSP_VLOG << "Sending RPC message: " << rpc; + std::vector<uint8_t> message(rpc.ByteSizeLong()); + rpc.SerializeToArray(message.data(), message.size()); + send_message_cb_(std::move(message)); +} + +bool RpcMessenger::IsRegisteredForTesting(RpcMessenger::Handle handle) { + return receive_callbacks_.find(handle) != receive_callbacks_.end(); +} + +WeakPtr<RpcMessenger> RpcMessenger::GetWeakPtr() { + return weak_factory_.GetWeakPtr(); +} + +} // namespace cast +} // namespace openscreen diff --git a/cast/streaming/rpc_broker.h b/cast/streaming/rpc_messenger.h index 596aba13..dd1b1932 100644 --- a/cast/streaming/rpc_broker.h +++ b/cast/streaming/rpc_messenger.h @@ -2,41 +2,46 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef CAST_STREAMING_RPC_BROKER_H_ -#define CAST_STREAMING_RPC_BROKER_H_ +#ifndef CAST_STREAMING_RPC_MESSENGER_H_ +#define CAST_STREAMING_RPC_MESSENGER_H_ +#include <memory> +#include <string> +#include <utility> #include <vector> #include "cast/streaming/remoting.pb.h" #include "util/flat_map.h" +#include "util/weak_ptr.h" namespace openscreen { namespace cast { // Processes incoming and outgoing RPC messages and links them to desired -// components on both end points. For outgoing messages, the messager +// components on both end points. For outgoing messages, the messenger // must send an RPC message with associated handle value. On the messagee side, // the message is sent to a pre-registered component using that handle. // Before RPC communication starts, both sides need to negotiate the handle // value in the existing RPC communication channel using the special handles // |kAcquire*Handle|. // -// NOTE: RpcBroker doesn't actually send RPC messages to the remote. The session -// messager needs to set SendMessageCallback, and call ProcessMessageFromRemote -// as appropriate. The RpcBroker then distributes each RPC message to the +// NOTE: RpcMessenger doesn't actually send RPC messages to the remote. The session +// messenger needs to set SendMessageCallback, and call ProcessMessageFromRemote +// as appropriate. The RpcMessenger then distributes each RPC message to the // subscribed component. -class RpcBroker { +class RpcMessenger { public: using Handle = int; - using ReceiveMessageCallback = std::function<void(const RpcMessage&)>; + using ReceiveMessageCallback = + std::function<void(std::unique_ptr<RpcMessage>)>; using SendMessageCallback = std::function<void(std::vector<uint8_t>)>; - explicit RpcBroker(SendMessageCallback send_message_cb); - RpcBroker(const RpcBroker&) = delete; - RpcBroker(RpcBroker&&) noexcept; - RpcBroker& operator=(const RpcBroker&) = delete; - RpcBroker& operator=(RpcBroker&&); - ~RpcBroker(); + explicit RpcMessenger(SendMessageCallback send_message_cb); + RpcMessenger(const RpcMessenger&) = delete; + RpcMessenger(RpcMessenger&&) noexcept; + RpcMessenger& operator=(const RpcMessenger&) = delete; + RpcMessenger& operator=(RpcMessenger&&); + ~RpcMessenger(); // Get unique handle value for RPC message handles. Handle GetUniqueHandle(); @@ -53,14 +58,29 @@ class RpcBroker { void UnregisterMessageReceiverCallback(Handle handle); // Distributes an incoming RPC message to the registered (if any) component. - void ProcessMessageFromRemote(const RpcMessage& message); + // The |serialized_message| should be already base64-decoded and ready for + // deserialization by protobuf. + void ProcessMessageFromRemote(const uint8_t* message, + std::size_t message_len); + // This overload distributes an already-deserialized message to the + // registered component. + void ProcessMessageFromRemote(std::unique_ptr<RpcMessage> message); - // Executes the |send_message_cb_| using |message|. - void SendMessageToRemote(const RpcMessage& message); + // Executes the |send_message_cb_| using |rpc|. + void SendMessageToRemote(const RpcMessage& rpc); // Checks if the handle is registered for receiving messages. Test-only. bool IsRegisteredForTesting(Handle handle); + // Weak pointer creator. + WeakPtr<RpcMessenger> GetWeakPtr(); + + // Consumers of RPCMessenger may set the send message callback post-hoc + // in order to simulate different scenarios. + void set_send_message_cb_for_testing(SendMessageCallback cb) { + send_message_cb_ = std::move(cb); + } + // Predefined invalid handle value for RPC message. static constexpr Handle kInvalidHandle = -1; @@ -81,9 +101,11 @@ class RpcBroker { // Callback that is ran to send a serialized message. SendMessageCallback send_message_cb_; + + WeakPtrFactory<RpcMessenger> weak_factory_{this}; }; } // namespace cast } // namespace openscreen -#endif // CAST_STREAMING_RPC_BROKER_H_ +#endif // CAST_STREAMING_RPC_MESSENGER_H_ diff --git a/cast/streaming/rpc_messenger_unittest.cc b/cast/streaming/rpc_messenger_unittest.cc new file mode 100644 index 00000000..758a9036 --- /dev/null +++ b/cast/streaming/rpc_messenger_unittest.cc @@ -0,0 +1,162 @@ +// Copyright 2020 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "cast/streaming/rpc_messenger.h" + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "cast/streaming/remoting.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::_; +using testing::Invoke; +using testing::Return; + +namespace openscreen { +namespace cast { +namespace { + +class FakeMessenger { + public: + void OnReceivedRpc(std::unique_ptr<RpcMessage> message) { + received_rpc_ = std::move(message); + received_count_++; + } + + void OnSentRpc(const std::vector<uint8_t>& message) { + EXPECT_TRUE(sent_rpc_.ParseFromArray(message.data(), message.size())); + sent_count_++; + } + + int received_count() const { return received_count_; } + const RpcMessage& received_rpc() const { return *received_rpc_; } + + int sent_count() const { return sent_count_; } + const RpcMessage& sent_rpc() const { return sent_rpc_; } + + void set_handle(RpcMessenger::Handle handle) { handle_ = handle; } + RpcMessenger::Handle handle() { return handle_; } + + private: + std::unique_ptr<RpcMessage> received_rpc_; + int received_count_ = 0; + + RpcMessage sent_rpc_; + int sent_count_ = 0; + + RpcMessenger::Handle handle_ = -1; +}; + +} // namespace + +class RpcMessengerTest : public testing::Test { + protected: + void SetUp() override { + fake_messenger_ = std::make_unique<FakeMessenger>(); + ASSERT_FALSE(fake_messenger_->received_count()); + + rpc_messenger_ = std::make_unique<RpcMessenger>( + [p = fake_messenger_.get()](std::vector<uint8_t> message) { + p->OnSentRpc(message); + }); + + const auto handle = rpc_messenger_->GetUniqueHandle(); + fake_messenger_->set_handle(handle); + rpc_messenger_->RegisterMessageReceiverCallback( + handle, + [p = fake_messenger_.get()](std::unique_ptr<RpcMessage> message) { + p->OnReceivedRpc(std::move(message)); + }); + } + + void ProcessMessage(const RpcMessage& rpc) { + std::vector<uint8_t> message(rpc.ByteSizeLong()); + rpc.SerializeToArray(message.data(), message.size()); + rpc_messenger_->ProcessMessageFromRemote(message.data(), message.size()); + } + + std::unique_ptr<FakeMessenger> fake_messenger_; + std::unique_ptr<RpcMessenger> rpc_messenger_; +}; + +TEST_F(RpcMessengerTest, TestProcessMessageFromRemoteRegistered) { + RpcMessage rpc; + rpc.set_handle(fake_messenger_->handle()); + ProcessMessage(rpc); + ASSERT_EQ(1, fake_messenger_->received_count()); +} + +TEST_F(RpcMessengerTest, TestProcessMessageFromRemoteUnregistered) { + RpcMessage rpc; + rpc_messenger_->UnregisterMessageReceiverCallback(fake_messenger_->handle()); + ProcessMessage(rpc); + ASSERT_EQ(0, fake_messenger_->received_count()); +} + +TEST_F(RpcMessengerTest, CanSendMultipleMessages) { + for (int i = 0; i < 10; ++i) { + rpc_messenger_->SendMessageToRemote(RpcMessage{}); + } + EXPECT_EQ(10, fake_messenger_->sent_count()); +} + +TEST_F(RpcMessengerTest, SendMessageCallback) { + // Send message for RPC messenger to process. + RpcMessage sent_rpc; + sent_rpc.set_handle(fake_messenger_->handle()); + sent_rpc.set_proc(RpcMessage::RPC_R_SETVOLUME); + sent_rpc.set_double_value(2.3); + rpc_messenger_->SendMessageToRemote(sent_rpc); + + // Check if received message is identical to the one sent earlier. + ASSERT_EQ(1, fake_messenger_->sent_count()); + const RpcMessage& message = fake_messenger_->sent_rpc(); + ASSERT_EQ(fake_messenger_->handle(), message.handle()); + ASSERT_EQ(RpcMessage::RPC_R_SETVOLUME, message.proc()); + ASSERT_EQ(2.3, message.double_value()); +} + +TEST_F(RpcMessengerTest, ProcessMessageWithRegisteredHandle) { + // Send message for RPC messenger to process. + RpcMessage sent_rpc; + sent_rpc.set_handle(fake_messenger_->handle()); + sent_rpc.set_proc(RpcMessage::RPC_DS_INITIALIZE); + sent_rpc.set_integer_value(4004); + ProcessMessage(sent_rpc); + + // Checks if received message is identical to the one sent earlier. + ASSERT_EQ(1, fake_messenger_->received_count()); + const RpcMessage& received_rpc = fake_messenger_->received_rpc(); + ASSERT_EQ(fake_messenger_->handle(), received_rpc.handle()); + ASSERT_EQ(RpcMessage::RPC_DS_INITIALIZE, received_rpc.proc()); + ASSERT_EQ(4004, received_rpc.integer_value()); +} + +TEST_F(RpcMessengerTest, ProcessMessageWithUnregisteredHandle) { + // Send message for RPC messenger to process. + RpcMessage sent_rpc; + RpcMessenger::Handle different_handle = fake_messenger_->handle() + 1; + sent_rpc.set_handle(different_handle); + sent_rpc.set_proc(RpcMessage::RPC_R_SETVOLUME); + sent_rpc.set_double_value(4.5); + ProcessMessage(sent_rpc); + + // We shouldn't have gotten the message since the handle is different. + ASSERT_EQ(0, fake_messenger_->received_count()); +} + +TEST_F(RpcMessengerTest, Registration) { + const auto handle = fake_messenger_->handle(); + ASSERT_TRUE(rpc_messenger_->IsRegisteredForTesting(handle)); + + rpc_messenger_->UnregisterMessageReceiverCallback(handle); + ASSERT_FALSE(rpc_messenger_->IsRegisteredForTesting(handle)); +} + +} // namespace cast +} // namespace openscreen diff --git a/cast/streaming/rtcp_common.cc b/cast/streaming/rtcp_common.cc index ce1e42d9..03562c42 100644 --- a/cast/streaming/rtcp_common.cc +++ b/cast/streaming/rtcp_common.cc @@ -43,14 +43,12 @@ void RtcpCommonHeader::AppendFields(absl::Span<uint8_t>* buffer) const { break; default: OSP_NOTREACHED(); - break; } break; case RtcpPacketType::kExtendedReports: break; case RtcpPacketType::kNull: OSP_NOTREACHED(); - break; } AppendField<uint8_t>(byte0, buffer); diff --git a/cast/streaming/rtcp_common.h b/cast/streaming/rtcp_common.h index 25e2c2ed..5fbf5f72 100644 --- a/cast/streaming/rtcp_common.h +++ b/cast/streaming/rtcp_common.h @@ -161,10 +161,6 @@ struct PacketNack { FrameId frame_id; FramePacketId packet_id; - // Comparison operators. Define more when you need them! - // TODO(miu): In C++20, just - // replace all of this with one operator<=>() definition to get them all for - // free. constexpr bool operator==(const PacketNack& other) const { return frame_id == other.frame_id && packet_id == other.packet_id; } diff --git a/cast/streaming/rtp_defines.cc b/cast/streaming/rtp_defines.cc index d64773d5..6ab389eb 100644 --- a/cast/streaming/rtp_defines.cc +++ b/cast/streaming/rtp_defines.cc @@ -4,15 +4,56 @@ #include "cast/streaming/rtp_defines.h" +#include "util/osp_logging.h" + namespace openscreen { namespace cast { -RtpPayloadType GetPayloadType(AudioCodec codec) { - return RtpPayloadType::kAudioHackForAndroidTV; +RtpPayloadType GetPayloadType(AudioCodec codec, bool use_android_rtp_hack) { + if (use_android_rtp_hack) { + return codec == AudioCodec::kNotSpecified + ? RtpPayloadType::kAudioVarious + : RtpPayloadType::kAudioHackForAndroidTV; + } + + switch (codec) { + case AudioCodec::kAac: + return RtpPayloadType::kAudioAac; + case AudioCodec::kOpus: + return RtpPayloadType::kAudioOpus; + case AudioCodec::kNotSpecified: + return RtpPayloadType::kAudioVarious; + default: + OSP_NOTREACHED(); + } } -RtpPayloadType GetPayloadType(VideoCodec codec) { - return RtpPayloadType::kVideoHackForAndroidTV; +RtpPayloadType GetPayloadType(VideoCodec codec, bool use_android_rtp_hack) { + if (use_android_rtp_hack) { + return codec == VideoCodec::kNotSpecified + ? RtpPayloadType::kVideoVarious + : RtpPayloadType::kVideoHackForAndroidTV; + } + switch (codec) { + // VP8 and VP9 share the same payload type. + case VideoCodec::kVp9: + case VideoCodec::kVp8: + return RtpPayloadType::kVideoVp8; + + // H264 and HEVC/H265 share the same payload type. + case VideoCodec::kHevc: // fallthrough + case VideoCodec::kH264: + return RtpPayloadType::kVideoH264; + + case VideoCodec::kAv1: + return RtpPayloadType::kVideoAv1; + + case VideoCodec::kNotSpecified: + return RtpPayloadType::kVideoVarious; + + default: + OSP_NOTREACHED(); + } } bool IsRtpPayloadType(uint8_t raw_byte) { @@ -23,6 +64,8 @@ bool IsRtpPayloadType(uint8_t raw_byte) { case RtpPayloadType::kAudioVarious: case RtpPayloadType::kVideoVp8: case RtpPayloadType::kVideoH264: + case RtpPayloadType::kVideoVp9: + case RtpPayloadType::kVideoAv1: case RtpPayloadType::kVideoVarious: case RtpPayloadType::kAudioHackForAndroidTV: // Note: RtpPayloadType::kVideoHackForAndroidTV has the same value as diff --git a/cast/streaming/rtp_defines.h b/cast/streaming/rtp_defines.h index 82c91c31..43005714 100644 --- a/cast/streaming/rtp_defines.h +++ b/cast/streaming/rtp_defines.h @@ -92,26 +92,25 @@ enum class RtpPayloadType : uint8_t { kVideoVp8 = 100, kVideoH264 = 101, kVideoVarious = 102, // Codec being used is not fixed. - kVideoLast = 102, + kVideoVp9 = 103, + kVideoAv1 = 104, + kVideoLast = kVideoAv1, // Some AndroidTV receivers require the payload type for audio to be 127, and // video to be 96; regardless of the codecs actually being used. This is // definitely out-of-spec, and inconsistent with the audio versus video range // of values, but must be taken into account for backwards-compatibility. - // TODO(crbug.com/1127978): RTP payload types need to represent actual type, - // as well as have options for new codecs like VP9. kAudioHackForAndroidTV = 127, kVideoHackForAndroidTV = 96, }; -// NOTE: currently we match the legacy Chrome sender's behavior of always -// sending the audio and video hacks for AndroidTV, however we should migrate -// to using proper rtp payload types. New payload types for new codecs, such -// as VP9, should also be defined. -// TODO(crbug.com/1127978): RTP payload types need to represent actual type, -// as well as have options for new codecs like VP9. -RtpPayloadType GetPayloadType(AudioCodec codec); -RtpPayloadType GetPayloadType(VideoCodec codec); +// Setting |use_android_rtp_hack| to true means that we match the legacy Chrome +// sender's behavior of always sending the audio and video hacks for AndroidTV, +// as some legacy android receivers require these. +// TODO(issuetracker.google.com/184438154): we need to figure out what receivers +// need this still, if any. The hack should be removed when possible. +RtpPayloadType GetPayloadType(AudioCodec codec, bool use_android_rtp_hack); +RtpPayloadType GetPayloadType(VideoCodec codec, bool use_android_rtp_hack); // Returns true if the |raw_byte| can be type-casted to a RtpPayloadType, and is // also not RtpPayloadType::kNull. The caller should mask the byte, to select diff --git a/cast/streaming/sender.cc b/cast/streaming/sender.cc index ba42bcb4..fd3e3b75 100644 --- a/cast/streaming/sender.cc +++ b/cast/streaming/sender.cc @@ -12,6 +12,7 @@ #include "util/chrono_helpers.h" #include "util/osp_logging.h" #include "util/std_util.h" +#include "util/trace_logging.h" namespace openscreen { namespace cast { @@ -311,10 +312,11 @@ void Sender::OnReceiverReport(const RtcpReportBlock& receiver_report) { round_trip_time_ = (kInertia * round_trip_time_ + measurement) / (kInertia + 1); } - // TODO(miu): Add tracing event here to note the updated RTT. + TRACE_SCOPED(TraceCategory::kSender, "UpdatedRTT"); } void Sender::OnReceiverIndicatesPictureLoss() { + TRACE_DEFAULT_SCOPED(TraceCategory::kSender); // The Receiver will continue the PLI notifications until it has received a // key frame. Thus, if a key frame is already in-flight, don't make a state // change that would cause this Sender to force another expensive key frame. @@ -342,6 +344,7 @@ void Sender::OnReceiverIndicatesPictureLoss() { void Sender::OnReceiverCheckpoint(FrameId frame_id, milliseconds playout_delay) { + TRACE_DEFAULT_SCOPED(TraceCategory::kSender); if (frame_id > last_enqueued_frame_id_) { OSP_LOG_ERROR << "Ignoring checkpoint for " << latest_expected_frame_id_ @@ -415,14 +418,13 @@ void Sender::OnReceiverIsMissingPackets(std::vector<PacketNack> nacks) { // happen in rare cases where RTCP packets arrive out-of-order (i.e., the // network shuffled them). if (!slot) { - // TODO(miu): Add tracing event here to record this. + TRACE_SCOPED(TraceCategory::kSender, "MissingNackSlot"); for (++nack_it; nack_it != nacks.end() && nack_it->frame_id == frame_id; ++nack_it) { } continue; } - // NOLINTNEXTLINE latest_expected_frame_id_ = std::max(latest_expected_frame_id_, frame_id); const auto HandleIndividualNack = [&](FramePacketId packet_id) { diff --git a/cast/streaming/sender_message.cc b/cast/streaming/sender_message.cc index 9c2b5388..2526f96a 100644 --- a/cast/streaming/sender_message.cc +++ b/cast/streaming/sender_message.cc @@ -20,13 +20,12 @@ namespace { EnumNameTable<SenderMessage::Type, 4> kMessageTypeNames{ {{kMessageTypeOffer, SenderMessage::Type::kOffer}, - {"GET_STATUS", SenderMessage::Type::kGetStatus}, {"GET_CAPABILITIES", SenderMessage::Type::kGetCapabilities}, {"RPC", SenderMessage::Type::kRpc}}}; SenderMessage::Type GetMessageType(const Json::Value& root) { std::string type; - if (!json::ParseAndValidateString(root[kMessageType], &type)) { + if (!json::TryParseString(root[kMessageType], &type)) { return SenderMessage::Type::kUnknown; } @@ -45,29 +44,36 @@ ErrorOr<SenderMessage> SenderMessage::Parse(const Json::Value& value) { } SenderMessage message; - message.type = GetMessageType(value); - if (!json::ParseAndValidateInt(value[kSequenceNumber], - &(message.sequence_number))) { + if (!json::TryParseInt(value[kSequenceNumber], &(message.sequence_number))) { message.sequence_number = -1; } - if (message.type == SenderMessage::Type::kOffer) { - ErrorOr<Offer> offer = Offer::Parse(value[kOfferMessageBody]); - if (offer.is_value()) { - message.body = std::move(offer.value()); - message.valid = true; - } - } else if (message.type == SenderMessage::Type::kRpc) { - std::string rpc_body; - if (json::ParseAndValidateString(value[kRpcMessageBody], &rpc_body) && - base64::Decode(rpc_body, &rpc_body)) { - message.body = rpc_body; + message.type = GetMessageType(value); + switch (message.type) { + case Type::kOffer: { + Offer offer; + if (Offer::TryParse(value[kOfferMessageBody], &offer).ok()) { + message.body = std::move(offer); + message.valid = true; + } + } break; + + case Type::kRpc: { + std::string rpc_body; + std::vector<uint8_t> rpc; + if (json::TryParseString(value[kRpcMessageBody], &rpc_body) && + base64::Decode(rpc_body, &rpc)) { + message.body = rpc; + message.valid = true; + } + } break; + + case Type::kGetCapabilities: message.valid = true; - } - } else if (message.type == SenderMessage::Type::kGetStatus || - message.type == SenderMessage::Type::kGetCapabilities) { - // These types of messages just don't have a body. - message.valid = true; + break; + + default: + break; } return message; @@ -86,15 +92,15 @@ ErrorOr<Json::Value> SenderMessage::ToJson() const { switch (type) { case SenderMessage::Type::kOffer: - root[kOfferMessageBody] = absl::get<Offer>(body).ToJson().value(); + root[kOfferMessageBody] = absl::get<Offer>(body).ToJson(); break; case SenderMessage::Type::kRpc: - root[kRpcMessageBody] = base64::Encode(absl::get<std::string>(body)); + root[kRpcMessageBody] = + base64::Encode(absl::get<std::vector<uint8_t>>(body)); break; - case SenderMessage::Type::kGetCapabilities: // fallthrough - case SenderMessage::Type::kGetStatus: + case SenderMessage::Type::kGetCapabilities: break; default: diff --git a/cast/streaming/sender_message.h b/cast/streaming/sender_message.h index c016b31c..adb0b09e 100644 --- a/cast/streaming/sender_message.h +++ b/cast/streaming/sender_message.h @@ -28,9 +28,6 @@ struct SenderMessage { // OFFER request message. kOffer, - // GET_STATUS request message. - kGetStatus, - // GET_CAPABILITIES request message. kGetCapabilities, @@ -44,7 +41,11 @@ struct SenderMessage { Type type = Type::kUnknown; int32_t sequence_number = -1; bool valid = false; - absl::variant<absl::monostate, Offer, std::string> body; + absl::variant<absl::monostate, + std::vector<uint8_t>, // Binary-encoded RPC message. + Offer, + std::string> + body; }; } // namespace cast diff --git a/cast/streaming/sender_packet_router.cc b/cast/streaming/sender_packet_router.cc index c2b23dbf..684b1fb2 100644 --- a/cast/streaming/sender_packet_router.cc +++ b/cast/streaming/sender_packet_router.cc @@ -102,10 +102,11 @@ void SenderPacketRouter::OnReceivedPacket(const IPEndpoint& source, InspectPacketForRouting(packet); if (seems_like.first != ApparentPacketType::RTCP) { constexpr int kMaxPartiaHexDumpSize = 96; + const std::size_t encode_size = + std::min(packet.size(), static_cast<size_t>(kMaxPartiaHexDumpSize)); OSP_LOG_WARN << "UNKNOWN packet of " << packet.size() << " bytes. Partial hex dump: " - << HexEncode(absl::Span<const uint8_t>(packet).subspan( - 0, kMaxPartiaHexDumpSize)); + << HexEncode(packet.data(), encode_size); return; } const auto it = FindEntry(seems_like.second); diff --git a/cast/streaming/sender_session.cc b/cast/streaming/sender_session.cc index 91ed975f..c47e9667 100644 --- a/cast/streaming/sender_session.cc +++ b/cast/streaming/sender_session.cc @@ -24,64 +24,58 @@ #include "util/json/json_helpers.h" #include "util/json/json_serialization.h" #include "util/osp_logging.h" +#include "util/stringprintf.h" namespace openscreen { namespace cast { namespace { -AudioStream CreateStream(int index, const AudioCaptureConfig& config) { - return AudioStream{ - Stream{index, - Stream::Type::kAudioSource, - config.channels, - GetPayloadType(config.codec), - GenerateSsrc(true /*high_priority*/), - config.target_playout_delay, - GenerateRandomBytes16(), - GenerateRandomBytes16(), - false /* receiver_rtcp_event_log */, - {} /* receiver_rtcp_dscp */, - config.sample_rate}, - config.codec, - (config.bit_rate >= capture_recommendations::kDefaultAudioMinBitRate) - ? config.bit_rate - : capture_recommendations::kDefaultAudioMaxBitRate}; -} - -Resolution ToResolution(const DisplayResolution& display_resolution) { - return Resolution{display_resolution.width, display_resolution.height}; +AudioStream CreateStream(int index, + const AudioCaptureConfig& config, + bool use_android_rtp_hack) { + return AudioStream{Stream{index, + Stream::Type::kAudioSource, + config.channels, + GetPayloadType(config.codec, use_android_rtp_hack), + GenerateSsrc(true /*high_priority*/), + config.target_playout_delay, + GenerateRandomBytes16(), + GenerateRandomBytes16(), + false /* receiver_rtcp_event_log */, + {} /* receiver_rtcp_dscp */, + config.sample_rate, + config.codec_parameter}, + config.codec, + std::max(config.bit_rate, kDefaultAudioMinBitRate)}; } -VideoStream CreateStream(int index, const VideoCaptureConfig& config) { - std::vector<Resolution> resolutions; - std::transform(config.resolutions.begin(), config.resolutions.end(), - std::back_inserter(resolutions), ToResolution); - +VideoStream CreateStream(int index, + const VideoCaptureConfig& config, + bool use_android_rtp_hack) { constexpr int kVideoStreamChannelCount = 1; return VideoStream{ Stream{index, Stream::Type::kVideoSource, kVideoStreamChannelCount, - GetPayloadType(config.codec), + GetPayloadType(config.codec, use_android_rtp_hack), GenerateSsrc(false /*high_priority*/), config.target_playout_delay, GenerateRandomBytes16(), GenerateRandomBytes16(), false /* receiver_rtcp_event_log */, {} /* receiver_rtcp_dscp */, - kRtpVideoTimebase}, + kRtpVideoTimebase, + config.codec_parameter}, config.codec, - SimpleFraction{config.max_frame_rate.numerator, - config.max_frame_rate.denominator}, - (config.max_bit_rate > - capture_recommendations::kDefaultVideoBitRateLimits.minimum) + config.max_frame_rate, + (config.max_bit_rate >= kDefaultVideoMinBitRate) ? config.max_bit_rate - : capture_recommendations::kDefaultVideoBitRateLimits.maximum, + : kDefaultVideoMaxBitRate, {}, // protection {}, // profile {}, // protection - std::move(resolutions), + config.resolutions, {} /* error_recovery mode, always "castv2" */ }; } @@ -89,21 +83,50 @@ VideoStream CreateStream(int index, const VideoCaptureConfig& config) { template <typename S, typename C> void CreateStreamList(int offset_index, const std::vector<C>& configs, + bool use_android_rtp_hack, std::vector<S>* out) { out->reserve(configs.size()); for (size_t i = 0; i < configs.size(); ++i) { - out->emplace_back(CreateStream(i + offset_index, configs[i])); + out->emplace_back( + CreateStream(i + offset_index, configs[i], use_android_rtp_hack)); } } -Offer CreateOffer(const std::vector<AudioCaptureConfig>& audio_configs, - const std::vector<VideoCaptureConfig>& video_configs) { +Offer CreateMirroringOffer(const std::vector<AudioCaptureConfig>& audio_configs, + const std::vector<VideoCaptureConfig>& video_configs, + bool use_android_rtp_hack) { Offer offer; + offer.cast_mode = CastMode::kMirroring; // NOTE here: IDs will always follow the pattern: // [0.. audio streams... N - 1][N.. video streams.. K] - CreateStreamList(0, audio_configs, &offer.audio_streams); - CreateStreamList(audio_configs.size(), video_configs, &offer.video_streams); + CreateStreamList(0, audio_configs, use_android_rtp_hack, + &offer.audio_streams); + CreateStreamList(audio_configs.size(), video_configs, use_android_rtp_hack, + &offer.video_streams); + + return offer; +} + +Offer CreateRemotingOffer(const AudioCaptureConfig& audio_config, + const VideoCaptureConfig& video_config, + bool use_android_rtp_hack) { + Offer offer; + offer.cast_mode = CastMode::kRemoting; + + AudioStream audio_stream = + CreateStream(0, audio_config, use_android_rtp_hack); + audio_stream.codec = AudioCodec::kNotSpecified; + audio_stream.stream.rtp_payload_type = + GetPayloadType(AudioCodec::kNotSpecified, use_android_rtp_hack); + offer.audio_streams.push_back(std::move(audio_stream)); + + VideoStream video_stream = + CreateStream(1, video_config, use_android_rtp_hack); + video_stream.codec = VideoCodec::kNotSpecified; + video_stream.stream.rtp_payload_type = + GetPayloadType(VideoCodec::kNotSpecified, use_android_rtp_hack); + offer.video_streams.push_back(std::move(video_stream)); return offer; } @@ -112,20 +135,19 @@ bool IsValidAudioCaptureConfig(const AudioCaptureConfig& config) { return config.channels >= 1 && config.bit_rate >= 0; } -bool IsValidResolution(const DisplayResolution& resolution) { +// We don't support resolutions below our minimums. +bool IsSupportedResolution(const Resolution& resolution) { return resolution.width > kMinVideoWidth && resolution.height > kMinVideoHeight; } bool IsValidVideoCaptureConfig(const VideoCaptureConfig& config) { - return config.max_frame_rate.numerator > 0 && - config.max_frame_rate.denominator > 0 && + return config.max_frame_rate.is_positive() && ((config.max_bit_rate == 0) || - (config.max_bit_rate >= - capture_recommendations::kDefaultVideoBitRateLimits.minimum)) && + (config.max_bit_rate >= kDefaultVideoMinBitRate)) && !config.resolutions.empty() && std::all_of(config.resolutions.begin(), config.resolutions.end(), - IsValidResolution); + IsSupportedResolution); } bool AreAllValid(const std::vector<AudioCaptureConfig>& audio_configs, @@ -136,38 +158,83 @@ bool AreAllValid(const std::vector<AudioCaptureConfig>& audio_configs, IsValidVideoCaptureConfig); } +RemotingCapabilities ToCapabilities(const ReceiverCapability& capability) { + RemotingCapabilities out; + for (MediaCapability c : capability.media_capabilities) { + switch (c) { + case MediaCapability::kAudio: + out.audio.push_back(AudioCapability::kBaselineSet); + break; + case MediaCapability::kAac: + out.audio.push_back(AudioCapability::kAac); + break; + case MediaCapability::kOpus: + out.audio.push_back(AudioCapability::kOpus); + break; + case MediaCapability::k4k: + out.video.push_back(VideoCapability::kSupports4k); + break; + case MediaCapability::kH264: + out.video.push_back(VideoCapability::kH264); + break; + case MediaCapability::kVp8: + out.video.push_back(VideoCapability::kVp8); + break; + case MediaCapability::kVp9: + out.video.push_back(VideoCapability::kVp9); + break; + case MediaCapability::kHevc: + out.video.push_back(VideoCapability::kHevc); + break; + case MediaCapability::kAv1: + out.video.push_back(VideoCapability::kAv1); + break; + case MediaCapability::kVideo: + // noop, as "video" is ignored by Chrome remoting. + break; + + default: + OSP_NOTREACHED(); + } + } + return out; +} + } // namespace SenderSession::Client::~Client() = default; -SenderSession::SenderSession(IPAddress remote_address, - Client* const client, - Environment* environment, - MessagePort* message_port, - std::string message_source_id, - std::string message_destination_id) - : remote_address_(remote_address), - client_(client), - environment_(environment), - messager_( - message_port, - std::move(message_source_id), - std::move(message_destination_id), +SenderSession::SenderSession(Configuration config) + : config_(config), + messenger_( + config_.message_port, + config_.message_source_id, + config_.message_destination_id, [this](Error error) { OSP_DLOG_WARN << "SenderSession message port error: " << error; - client_->OnError(this, error); + config_.client->OnError(this, error); }, - environment->task_runner()), - packet_router_(environment_) { - OSP_DCHECK(client_); - OSP_DCHECK(environment_); + config_.environment->task_runner()), + rpc_messenger_([this](std::vector<uint8_t> message) { + SendRpcMessage(std::move(message)); + }), + packet_router_(config_.environment) { + OSP_DCHECK(config_.client); + OSP_DCHECK(config_.environment); + + // We may or may not do remoting this session, however our RPC handler + // is not negotiation-specific and registering on construction here allows us + // to record any unexpected RPC messages. + messenger_.SetHandler(ReceiverMessage::Type::kRpc, + [this](ReceiverMessage message) { + this->OnRpcMessage(std::move(message)); + }); } SenderSession::~SenderSession() = default; -Error SenderSession::NegotiateMirroring( - std::vector<AudioCaptureConfig> audio_configs, - std::vector<VideoCaptureConfig> video_configs) { +Error SenderSession::Negotiate(std::vector<AudioCaptureConfig> audio_configs, + std::vector<VideoCaptureConfig> video_configs) { // Negotiating with no streams doesn't make any sense. if (audio_configs.empty() && video_configs.empty()) { return Error(Error::Code::kParameterInvalid, @@ -177,45 +244,160 @@ Error SenderSession::NegotiateMirroring( return Error(Error::Code::kParameterInvalid, "Invalid configs provided."); } - Offer offer = CreateOffer(audio_configs, video_configs); - current_negotiation_ = std::unique_ptr<Negotiation>(new Negotiation{ - offer, std::move(audio_configs), std::move(video_configs)}); + Offer offer = CreateMirroringOffer(audio_configs, video_configs, + config_.use_android_rtp_hack); + return StartNegotiation(std::move(audio_configs), std::move(video_configs), + std::move(offer)); +} - return messager_.SendRequest( - SenderMessage{SenderMessage::Type::kOffer, ++current_sequence_number_, - true, std::move(offer)}, - ReceiverMessage::Type::kAnswer, - [this](ReceiverMessage message) { OnAnswer(message); }); +Error SenderSession::NegotiateRemoting(AudioCaptureConfig audio_config, + VideoCaptureConfig video_config) { + // Remoting requires both an audio and a video configuration. + if (!IsValidAudioCaptureConfig(audio_config) || + !IsValidVideoCaptureConfig(video_config)) { + return Error(Error::Code::kParameterInvalid, + "Passed invalid audio or video config."); + } + + Offer offer = CreateRemotingOffer(audio_config, video_config, + config_.use_android_rtp_hack); + return StartNegotiation({audio_config}, {video_config}, std::move(offer)); } int SenderSession::GetEstimatedNetworkBandwidth() const { return packet_router_.ComputeNetworkBandwidth(); } +void SenderSession::ResetState() { + state_ = State::kIdle; + current_negotiation_.reset(); + current_audio_sender_.reset(); + current_video_sender_.reset(); +} + +Error SenderSession::StartNegotiation( + std::vector<AudioCaptureConfig> audio_configs, + std::vector<VideoCaptureConfig> video_configs, + Offer offer) { + current_negotiation_ = + std::unique_ptr<InProcessNegotiation>(new InProcessNegotiation{ + offer, std::move(audio_configs), std::move(video_configs)}); + + return messenger_.SendRequest( + SenderMessage{SenderMessage::Type::kOffer, ++current_sequence_number_, + true, std::move(offer)}, + ReceiverMessage::Type::kAnswer, + [this](ReceiverMessage message) { OnAnswer(message); }); +} + void SenderSession::OnAnswer(ReceiverMessage message) { - OSP_LOG_WARN << "Message sn: " << message.sequence_number - << ", current: " << current_sequence_number_; if (!message.valid) { - if (absl::holds_alternative<ReceiverError>(message.body)) { - client_->OnError( - this, Error(Error::Code::kParameterInvalid, - absl::get<ReceiverError>(message.body).description)); - } else { - client_->OnError(this, Error(Error::Code::kJsonParseError, - "Received invalid answer message")); - } + HandleErrorMessage(message, "Invalid answer response message"); return; } + // There isn't an obvious way to tell from the Answer whether it is mirroring + // or remoting specific--the only clues are in the original offer message. const Answer& answer = absl::get<Answer>(message.body); - ConfiguredSenders senders = SpawnSenders(answer); + if (current_negotiation_->offer.cast_mode == CastMode::kMirroring) { + ConfiguredSenders senders = SpawnSenders(answer); + // If we didn't select any senders, the negotiation was unsuccessful. + if (senders.audio_sender == nullptr && senders.video_sender == nullptr) { + return; + } + + state_ = State::kStreaming; + config_.client->OnNegotiated( + this, std::move(senders), + capture_recommendations::GetRecommendations(answer)); + } else { + state_ = State::kRemoting; + + // We don't want to spawn senders yet, since we don't know what the + // receiver's capabilities are. So, we cache the Answer until the + // capabilites request is completed. + current_negotiation_->answer = answer; + const Error result = messenger_.SendRequest( + SenderMessage{SenderMessage::Type::kGetCapabilities, + ++current_sequence_number_, true}, + ReceiverMessage::Type::kCapabilitiesResponse, + [this](ReceiverMessage message) { OnCapabilitiesResponse(message); }); + if (!result.ok()) { + config_.client->OnError( + this, Error(Error::Code::kNegotiationFailure, + "Failed to set a GET_CAPABILITIES request")); + } + } +} + +void SenderSession::OnCapabilitiesResponse(ReceiverMessage message) { + if (!current_negotiation_ || !current_negotiation_->answer.IsValid()) { + OSP_LOG_INFO + << "Received a capabilities response, but not negotiating anything."; + return; + } + + if (!message.valid) { + HandleErrorMessage( + message, + "Bad CAPABILITIES_RESPONSE, assuming remoting is not supported"); + return; + } + + const ReceiverCapability& caps = absl::get<ReceiverCapability>(message.body); + int remoting_version = caps.remoting_version; + // If not set, we assume it is version 1. + if (remoting_version == ReceiverCapability::kRemotingVersionUnknown) { + remoting_version = 1; + } + + if (remoting_version > kSupportedRemotingVersion) { + std::string message = StringPrintf( + "Receiver is using too new of a version for remoting (%d > %d)", + remoting_version, kSupportedRemotingVersion); + config_.client->OnError( + this, Error(Error::Code::kRemotingNotSupported, std::move(message))); + return; + } + + ConfiguredSenders senders = SpawnSenders(current_negotiation_->answer); // If we didn't select any senders, the negotiation was unsuccessful. if (senders.audio_sender == nullptr && senders.video_sender == nullptr) { + config_.client->OnError(this, + Error(Error::Code::kNegotiationFailure, + "Failed to negotiate a remoting session.")); return; } - client_->OnMirroringNegotiated( - this, std::move(senders), - capture_recommendations::GetRecommendations(answer)); + + config_.client->OnRemotingNegotiated( + this, RemotingNegotiation{std::move(senders), ToCapabilities(caps)}); +} + +void SenderSession::OnRpcMessage(ReceiverMessage message) { + if (!message.valid) { + HandleErrorMessage( + message, + "Bad RPC message. This may or may not represent a serious problem"); + return; + } + + const auto& body = absl::get<std::vector<uint8_t>>(message.body); + rpc_messenger_.ProcessMessageFromRemote(body.data(), body.size()); +} + +void SenderSession::HandleErrorMessage(ReceiverMessage message, + const std::string& text) { + OSP_DCHECK(!message.valid); + if (absl::holds_alternative<ReceiverError>(message.body)) { + const ReceiverError& error = absl::get<ReceiverError>(message.body); + std::string error_text = + StringPrintf("%s. Error code: %d, description: %s", text.c_str(), + error.code, error.description.c_str()); + config_.client->OnError( + this, Error(Error::Code::kParameterInvalid, std::move(error_text))); + } else { + config_.client->OnError(this, Error(Error::Code::kJsonParseError, text)); + } } std::unique_ptr<Sender> SenderSession::CreateSender(Ssrc receiver_ssrc, @@ -231,7 +413,7 @@ std::unique_ptr<Sender> SenderSession::CreateSender(Ssrc receiver_ssrc, stream.aes_iv_mask, /* is_pli_enabled*/ true}; - return std::make_unique<Sender>(environment_, &packet_router_, + return std::make_unique<Sender>(config_.environment, &packet_router_, std::move(config), type); } @@ -241,7 +423,8 @@ void SenderSession::SpawnAudioSender(ConfiguredSenders* senders, int config_index) { const AudioCaptureConfig& config = current_negotiation_->audio_configs[config_index]; - const RtpPayloadType payload_type = GetPayloadType(config.codec); + const RtpPayloadType payload_type = + GetPayloadType(config.codec, config_.use_android_rtp_hack); for (const AudioStream& stream : current_negotiation_->offer.audio_streams) { if (stream.stream.index == send_index) { current_audio_sender_ = @@ -259,7 +442,8 @@ void SenderSession::SpawnVideoSender(ConfiguredSenders* senders, int config_index) { const VideoCaptureConfig& config = current_negotiation_->video_configs[config_index]; - const RtpPayloadType payload_type = GetPayloadType(config.codec); + const RtpPayloadType payload_type = + GetPayloadType(config.codec, config_.use_android_rtp_hack); for (const VideoStream& stream : current_negotiation_->offer.video_streams) { if (stream.stream.index == send_index) { current_video_sender_ = @@ -278,9 +462,10 @@ SenderSession::ConfiguredSenders SenderSession::SpawnSenders( // Although we already have a message port set up with the TLS // address of the receiver, we don't know where to send the separate UDP // stream until we get the ANSWER message here. - environment_->set_remote_endpoint( - IPEndpoint{remote_address_, static_cast<uint16_t>(answer.udp_port)}); - OSP_LOG_INFO << "Streaming to " << environment_->remote_endpoint() << "..."; + config_.environment->set_remote_endpoint(IPEndpoint{ + config_.remote_address, static_cast<uint16_t>(answer.udp_port)}); + OSP_LOG_INFO << "Streaming to " << config_.environment->remote_endpoint() + << "..."; ConfiguredSenders senders; for (size_t i = 0; i < answer.send_indexes.size(); ++i) { @@ -299,5 +484,15 @@ SenderSession::ConfiguredSenders SenderSession::SpawnSenders( return senders; } +void SenderSession::SendRpcMessage(std::vector<uint8_t> message_body) { + Error error = this->messenger_.SendOutboundMessage(SenderMessage{ + SenderMessage::Type::kRpc, ++(this->current_sequence_number_), true, + std::move(message_body)}); + + if (!error.ok()) { + OSP_LOG_WARN << "Failed to send RPC message: " << error; + } +} + } // namespace cast } // namespace openscreen diff --git a/cast/streaming/sender_session.h b/cast/streaming/sender_session.h index cba16209..ef68df73 100644 --- a/cast/streaming/sender_session.h +++ b/cast/streaming/sender_session.h @@ -13,21 +13,20 @@ #include "cast/common/public/message_port.h" #include "cast/streaming/answer_messages.h" #include "cast/streaming/capture_configs.h" +#include "cast/streaming/capture_recommendations.h" #include "cast/streaming/offer_messages.h" +#include "cast/streaming/remoting_capabilities.h" +#include "cast/streaming/rpc_messenger.h" #include "cast/streaming/sender.h" #include "cast/streaming/sender_packet_router.h" #include "cast/streaming/session_config.h" -#include "cast/streaming/session_messager.h" +#include "cast/streaming/session_messenger.h" #include "json/value.h" #include "util/json/json_serialization.h" namespace openscreen { namespace cast { -namespace capture_recommendations { -struct Recommendations; -} - class Environment; class Sender; @@ -42,37 +41,85 @@ class SenderSession final { // If the sender is audio- or video-only, either of the senders // may be nullptr. However, in the majority of cases they will be populated. - Sender* audio_sender; + Sender* audio_sender = nullptr; AudioCaptureConfig audio_config; - Sender* video_sender; + Sender* video_sender = nullptr; VideoCaptureConfig video_config; }; - // The embedder should provide a client for handling the negotiation. - // When the negotiation is complete, the OnMirroringNegotiated callback is - // called. + // This struct contains all of the information necessary to begin remoting + // after we receive the capabilities from the receiver. + struct RemotingNegotiation { + ConfiguredSenders senders; + + // The capabilities reported by the connected receiver. NOTE: SenderSession + // reports the capabilities as-is from the Receiver, so clients concerned + // about legacy devices, such as pre-1.27 Earth receivers should do + // a version check when using these capabilities to offer remoting. + RemotingCapabilities capabilities; + }; + + // The embedder should provide a client for handling negotiation events. + // The client is required to implement a mirorring handler, and may choose + // to provide a remoting negotiation if it supports remoting. + // When the negotiation is complete, the appropriate |On*Negotiated| handler + // is called. class Client { public: // Called when a new set of senders has been negotiated. This may be - // called multiple times during a session, once for every time - // NegotiateMirroring() is called on the SenderSession object. The - // negotiation call also includes capture recommendations that can be used - // by the sender to provide an optimal video stream for the receiver. - virtual void OnMirroringNegotiated( + // called multiple times during a session, once for every time Negotiate() + // is called on the SenderSession object. The negotiation call also includes + // capture recommendations that can be used by the sender to provide + // an optimal video stream for the receiver. + virtual void OnNegotiated( const SenderSession* session, ConfiguredSenders senders, capture_recommendations::Recommendations capture_recommendations) = 0; + // Called when a new set of remoting senders has been negotiated. Since + // remoting is an optional feature, the default behavior here is to leave + // this method unhandled. + virtual void OnRemotingNegotiated(const SenderSession* session, + RemotingNegotiation negotiation) {} + // Called whenever an error occurs. Ends the ongoing session, and the caller - // must call NegotiateMirroring() again if they wish to re-establish - // streaming. + // must call Negotiate() again if they wish to re-establish streaming. virtual void OnError(const SenderSession* session, Error error) = 0; protected: virtual ~Client(); }; + // The configuration information required to set up the session. + struct Configuration { + // The remote address of the receiver to connect to. NOTE: we do eventually + // set the remote endpoint on the |environment| object, but only after + // getting the port information from a successful ANSWER message. + IPAddress remote_address; + + // The client for notifying of successful negotiations and errors. Required. + Client* const client; + + // The cast environment used to access operating system resources, such + // as the UDP socket for RTP/RTCP messaging. Required. + Environment* environment; + + // The message port used to send streaming control protocol messages. + MessagePort* message_port; + + // The message source identifier (e.g. this sender). + std::string message_source_id; + + // The message destination identifier (e.g. the receiver we are connected + // to). + std::string message_destination_id; + + // Whether or not the android RTP value hack should be used (for legacy + // android devices). For more information, see https://crbug.com/631828. + bool use_android_rtp_hack = true; + }; + // The SenderSession assumes that the passed in client, environment, and // message port persist for at least the lifetime of the SenderSession. If // one of these classes needs to be reset, a new SenderSession should be @@ -82,42 +129,85 @@ class SenderSession final { // ID, respectively, to use when sending or receiving control messages (e.g., // OFFERs or ANSWERs) over the |message_port|. |message_port|'s SetClient() // method will be called. - SenderSession(IPAddress remote_address, - Client* const client, - Environment* environment, - MessagePort* message_port, - std::string message_source_id, - std::string message_destination_id); + explicit SenderSession(Configuration config); SenderSession(const SenderSession&) = delete; SenderSession(SenderSession&&) noexcept = delete; SenderSession& operator=(const SenderSession&) = delete; SenderSession& operator=(SenderSession&&) = delete; ~SenderSession(); - // Starts an OFFER/ANSWER exchange with the already configured receiver - // over the message port. The caller should assume any configured senders - // become invalid when calling this method. - Error NegotiateMirroring(std::vector<AudioCaptureConfig> audio_configs, - std::vector<VideoCaptureConfig> video_configs); + // Starts a mirroring OFFER/ANSWER exchange with the already configured + // receiver over the message port. The caller should assume any configured + // senders become invalid when calling this method. + Error Negotiate(std::vector<AudioCaptureConfig> audio_configs, + std::vector<VideoCaptureConfig> video_configs); + + // Remoting negotiation is actually very similar to mirroring negotiation-- + // an OFFER/ANSWER exchange still occurs, however only one audio and video + // codec should be presented based on the encoding of the media element that + // should be remoted. Note: the codec fields in |audio_config| and + // |video_config| are ignored in favor of |kRemote|. + Error NegotiateRemoting(AudioCaptureConfig audio_config, + VideoCaptureConfig video_config); // Get the current network usage (in bits per second). This includes all // senders managed by this session, and is a best guess based on receiver // feedback. Embedders may use this information to throttle capture devices. int GetEstimatedNetworkBandwidth() const; + // The RPC messenger for this session. NOTE: RPC messages may come at + // any time from the receiver, so subscriptions to RPC remoting messages + // should be done before calling |NegotiateRemoting|. + RpcMessenger* rpc_messenger() { return &rpc_messenger_; } + private: // We store the current negotiation, so that when we get an answer from the // receiver we can line up the selected streams with the original // configuration. - struct Negotiation { + struct InProcessNegotiation { + // The offer, which should always be valid if we have an in process + // negotiation. Offer offer; + // The configs used to derive the offer. std::vector<AudioCaptureConfig> audio_configs; std::vector<VideoCaptureConfig> video_configs; + + // The answer message for this negotiation, which may be invalid if we + // haven't received an answer yet. + Answer answer; + }; + + // The state of the session. + enum class State { + // Not sending content--may be in the middle of negotiation, or just + // waiting. + kIdle, + + // Currently mirroring content to a receiver. + kStreaming, + + // Currently remoting content to a receiver. + kRemoting }; + // Reset the state and tear down the current negotiation/negotiated mirroring + // or remoting session. After reset, the SenderSession is still connected to + // the same |remote_address_|, and the |packet_router_| and sequence number + // will be unchanged. + void ResetState(); + + // Uses the passed in configs and offer to send an OFFER/ANSWER negotiation + // and cache the new InProcessNavigation. + Error StartNegotiation(std::vector<AudioCaptureConfig> audio_configs, + std::vector<VideoCaptureConfig> video_configs, + Offer offer); + // Specific message type handler methods. void OnAnswer(ReceiverMessage message); + void OnCapabilitiesResponse(ReceiverMessage message); + void OnRpcMessage(ReceiverMessage message); + void HandleErrorMessage(ReceiverMessage message, const std::string& text); // Used by SpawnSenders to generate a sender for a specific stream. std::unique_ptr<Sender> CreateSender(Ssrc receiver_ssrc, @@ -137,18 +227,22 @@ class SenderSession final { // Spawn a set of configured senders from the currently stored negotiation. ConfiguredSenders SpawnSenders(const Answer& answer); - // The remote address of the receiver we are communicating with. Used - // for both TLS and UDP traffic. - const IPAddress remote_address_; + // Used by the RPC messenger to send outbound messages. + void SendRpcMessage(std::vector<uint8_t> message_body); - // The embedder is expected to provide us a client for notifications about - // negotiations and errors, a valid cast environment, and a messaging - // port for communicating to the Receiver over TLS. - Client* const client_; - Environment* const environment_; - SenderSessionMessager messager_; + // This session's configuration. + Configuration config_; - // The packet router used for messaging across all senders. + // The session messenger, which uses the message port for sending control + // messages. For message formats, see + // cast/protocol/castv2/streaming_schema.json. + SenderSessionMessenger messenger_; + + // The RPC messenger, which uses the session messager for sending RPC messages + // and handles subscriptions to RPC messages. + RpcMessenger rpc_messenger_; + + // The packet router used for RTP/RTCP messaging across all senders. SenderPacketRouter packet_router_; // Each negotiation has its own sequence number, and the receiver replies @@ -158,7 +252,12 @@ class SenderSession final { // The current negotiation. If present, we are expected an ANSWER from // the receiver. If not present, any provided ANSWERS are rejected. - std::unique_ptr<Negotiation> current_negotiation_; + std::unique_ptr<InProcessNegotiation> current_negotiation_; + + // The current state of the session. Note that the state is intentionally + // limited. |kStreaming| or |kRemoting| means that we are either starting + // a negotiation or actively sending to a receiver. + State state_ = State::kIdle; // If the negotiation has succeeded, we store the current audio and video // senders used for this session. Either or both may be nullptr. diff --git a/cast/streaming/sender_session_unittest.cc b/cast/streaming/sender_session_unittest.cc index 4aaf9e46..227e68e7 100644 --- a/cast/streaming/sender_session_unittest.cc +++ b/cast/streaming/sender_session_unittest.cc @@ -96,42 +96,68 @@ constexpr char kErrorAnswerMessage[] = R"({ } })"; +constexpr char kCapabilitiesResponse[] = R"({ + "seqNum": 2, + "result": "ok", + "type": "CAPABILITIES_RESPONSE", + "capabilities": { + "mediaCaps": ["video", "vp8", "audio", "aac"] + } +})"; + const AudioCaptureConfig kAudioCaptureConfigInvalidChannels{ AudioCodec::kAac, -1 /* channels */, 44000 /* bit_rate */, 96000 /* sample_rate */ }; const AudioCaptureConfig kAudioCaptureConfigValid{ - AudioCodec::kOpus, 5 /* channels */, 32000 /* bit_rate */, - 44000 /* sample_rate */ -}; + AudioCodec::kAac, + 5 /* channels */, + 32000 /* bit_rate */, + 44000 /* sample_rate */, + std::chrono::milliseconds(300), + "mp4a.40.5"}; const VideoCaptureConfig kVideoCaptureConfigMissingResolutions{ - VideoCodec::kHevc, FrameRate{60, 1}, 300000 /* max_bit_rate */, - std::vector<DisplayResolution>{}}; + VideoCodec::kHevc, + {60, 1}, + 300000 /* max_bit_rate */, + std::vector<Resolution>{}, + std::chrono::milliseconds(500), + "hev1.1.6.L150.B0"}; const VideoCaptureConfig kVideoCaptureConfigInvalid{ - VideoCodec::kHevc, FrameRate{60, 1}, -300000 /* max_bit_rate */, - std::vector<DisplayResolution>{DisplayResolution{1920, 1080}, - DisplayResolution{1280, 720}}}; + VideoCodec::kHevc, + {60, 1}, + -300000 /* max_bit_rate */, + std::vector<Resolution>{Resolution{1920, 1080}, Resolution{1280, 720}}}; const VideoCaptureConfig kVideoCaptureConfigValid{ - VideoCodec::kHevc, FrameRate{60, 1}, 300000 /* max_bit_rate */, - std::vector<DisplayResolution>{DisplayResolution{1280, 720}, - DisplayResolution{1920, 1080}}}; + VideoCodec::kHevc, + {60, 1}, + 300000 /* max_bit_rate */, + std::vector<Resolution>{Resolution{1280, 720}, Resolution{1920, 1080}}, + std::chrono::milliseconds(250), + "hev1.1.6.L150.B0"}; const VideoCaptureConfig kVideoCaptureConfigValidSimplest{ - VideoCodec::kHevc, FrameRate{60, 1}, 300000 /* max_bit_rate */, - std::vector<DisplayResolution>{DisplayResolution{1920, 1080}}}; + VideoCodec::kHevc, + {60, 1}, + 300000 /* max_bit_rate */, + std::vector<Resolution>{Resolution{1920, 1080}}}; class FakeClient : public SenderSession::Client { public: MOCK_METHOD(void, - OnMirroringNegotiated, + OnNegotiated, (const SenderSession*, SenderSession::ConfiguredSenders, capture_recommendations::Recommendations), (override)); + MOCK_METHOD(void, + OnRemotingNegotiated, + (const SenderSession*, SenderSession::RemotingNegotiation), + (override)); MOCK_METHOD(void, OnError, (const SenderSession*, Error error), (override)); }; @@ -153,19 +179,32 @@ class SenderSessionTest : public ::testing::Test { void SetUp() { message_port_ = std::make_unique<SimpleMessagePort>("receiver-12345"); environment_ = MakeEnvironment(); - session_ = std::make_unique<SenderSession>( - IPAddress::kV4LoopbackAddress(), &client_, environment_.get(), - message_port_.get(), "sender-12345", "receiver-12345"); + + SenderSession::Configuration config{IPAddress::kV4LoopbackAddress(), + &client_, + environment_.get(), + message_port_.get(), + "sender-12345", + "receiver-12345", + /* use_android_rtp_hack */ true}; + session_ = std::make_unique<SenderSession>(std::move(config)); } - std::string NegotiateOfferAndConstructAnswer() { - const Error error = session_->NegotiateMirroring( + void NegotiateMirroringWithValidConfigs() { + const Error error = session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); - if (!error.ok()) { - return {}; - } + ASSERT_TRUE(error.ok()); + } + void NegotiateRemotingWithValidConfigs() { + const Error error = session_->NegotiateRemoting(kAudioCaptureConfigValid, + kVideoCaptureConfigValid); + ASSERT_TRUE(error.ok()); + } + + // Answers require specific fields from the original offer to be valid. + std::string ConstructAnswerFromOffer(CastMode mode) { const auto& messages = message_port_->posted_messages(); if (messages.size() != 1) { return {}; @@ -200,14 +239,16 @@ class SenderSessionTest : public ::testing::Test { "seqNum": %d, "result": "ok", "answer": { - "castMode": "mirroring", + "castMode": "%s", "udpPort": 1234, "sendIndexes": [%d, %d], "ssrcs": [%d, %d] } })"; - return StringPrintf(kAnswerTemplate, offer["seqNum"].asInt(), audio_index, - video_index, audio_ssrc + 1, video_ssrc + 1); + return StringPrintf(kAnswerTemplate, offer["seqNum"].asInt(), + mode == CastMode::kMirroring ? "mirroring" : "remoting", + audio_index, video_index, audio_ssrc + 1, + video_ssrc + 1); } protected: @@ -220,8 +261,8 @@ class SenderSessionTest : public ::testing::Test { }; TEST_F(SenderSessionTest, ComplainsIfNoConfigsToOffer) { - const Error error = session_->NegotiateMirroring( - std::vector<AudioCaptureConfig>{}, std::vector<VideoCaptureConfig>{}); + const Error error = session_->Negotiate(std::vector<AudioCaptureConfig>{}, + std::vector<VideoCaptureConfig>{}); EXPECT_EQ(error, Error(Error::Code::kParameterInvalid, @@ -229,7 +270,7 @@ TEST_F(SenderSessionTest, ComplainsIfNoConfigsToOffer) { } TEST_F(SenderSessionTest, ComplainsIfInvalidAudioCaptureConfig) { - const Error error = session_->NegotiateMirroring( + const Error error = session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigInvalidChannels}, std::vector<VideoCaptureConfig>{}); @@ -238,7 +279,7 @@ TEST_F(SenderSessionTest, ComplainsIfInvalidAudioCaptureConfig) { } TEST_F(SenderSessionTest, ComplainsIfInvalidVideoCaptureConfig) { - const Error error = session_->NegotiateMirroring( + const Error error = session_->Negotiate( std::vector<AudioCaptureConfig>{}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigInvalid}); EXPECT_EQ(error, @@ -246,7 +287,7 @@ TEST_F(SenderSessionTest, ComplainsIfInvalidVideoCaptureConfig) { } TEST_F(SenderSessionTest, ComplainsIfMissingResolutions) { - const Error error = session_->NegotiateMirroring( + const Error error = session_->Negotiate( std::vector<AudioCaptureConfig>{}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigMissingResolutions}); EXPECT_EQ(error, @@ -259,9 +300,9 @@ TEST_F(SenderSessionTest, SendsOfferWithZeroBitrateOptions) { AudioCaptureConfig audio_config = kAudioCaptureConfigValid; audio_config.bit_rate = 0; - const Error error = session_->NegotiateMirroring( - std::vector<AudioCaptureConfig>{audio_config}, - std::vector<VideoCaptureConfig>{video_config}); + const Error error = + session_->Negotiate(std::vector<AudioCaptureConfig>{audio_config}, + std::vector<VideoCaptureConfig>{video_config}); EXPECT_TRUE(error.ok()); const auto& messages = message_port_->posted_messages(); @@ -273,7 +314,7 @@ TEST_F(SenderSessionTest, SendsOfferWithZeroBitrateOptions) { } TEST_F(SenderSessionTest, SendsOfferWithSimpleVideoOnly) { - const Error error = session_->NegotiateMirroring( + const Error error = session_->Negotiate( std::vector<AudioCaptureConfig>{}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); EXPECT_TRUE(error.ok()); @@ -287,7 +328,7 @@ TEST_F(SenderSessionTest, SendsOfferWithSimpleVideoOnly) { } TEST_F(SenderSessionTest, SendsOfferAudioOnly) { - const Error error = session_->NegotiateMirroring( + const Error error = session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{}); EXPECT_TRUE(error.ok()); @@ -301,7 +342,7 @@ TEST_F(SenderSessionTest, SendsOfferAudioOnly) { } TEST_F(SenderSessionTest, SendsOfferMessage) { - session_->NegotiateMirroring( + session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -318,20 +359,20 @@ TEST_F(SenderSessionTest, SendsOfferMessage) { ASSERT_FALSE(offer_body.isNull()); ASSERT_TRUE(offer_body.isObject()); EXPECT_EQ("mirroring", offer_body["castMode"].asString()); - EXPECT_EQ(false, offer_body["receiverGetStatus"].asBool()); const Json::Value& streams = offer_body["supportedStreams"]; EXPECT_TRUE(streams.isArray()); EXPECT_EQ(2u, streams.size()); const Json::Value& audio_stream = streams[0]; - EXPECT_EQ("opus", audio_stream["codecName"].asString()); + EXPECT_EQ("aac", audio_stream["codecName"].asString()); EXPECT_EQ(0, audio_stream["index"].asInt()); EXPECT_EQ(32u, audio_stream["aesKey"].asString().length()); EXPECT_EQ(32u, audio_stream["aesIvMask"].asString().length()); EXPECT_EQ(5, audio_stream["channels"].asInt()); EXPECT_LT(0u, audio_stream["ssrc"].asUInt()); EXPECT_EQ(127, audio_stream["rtpPayloadType"].asInt()); + EXPECT_EQ("mp4a.40.5", audio_stream["codecParameter"].asString()); const Json::Value& video_stream = streams[1]; EXPECT_EQ("hevc", video_stream["codecName"].asString()); @@ -341,22 +382,25 @@ TEST_F(SenderSessionTest, SendsOfferMessage) { EXPECT_EQ(1, video_stream["channels"].asInt()); EXPECT_LT(0u, video_stream["ssrc"].asUInt()); EXPECT_EQ(96, video_stream["rtpPayloadType"].asInt()); + EXPECT_EQ("hev1.1.6.L150.B0", video_stream["codecParameter"].asString()); } TEST_F(SenderSessionTest, HandlesValidAnswer) { - std::string answer = NegotiateOfferAndConstructAnswer(); + NegotiateMirroringWithValidConfigs(); + std::string answer = ConstructAnswerFromOffer(CastMode::kMirroring); - EXPECT_CALL(client_, OnMirroringNegotiated(session_.get(), _, _)); + EXPECT_CALL(client_, OnNegotiated(session_.get(), _, _)); message_port_->ReceiveMessage(answer); } TEST_F(SenderSessionTest, HandlesInvalidNamespace) { - std::string answer = NegotiateOfferAndConstructAnswer(); + NegotiateMirroringWithValidConfigs(); + std::string answer = ConstructAnswerFromOffer(CastMode::kMirroring); message_port_->ReceiveMessage("random-namespace", answer); } TEST_F(SenderSessionTest, HandlesMalformedAnswer) { - session_->NegotiateMirroring( + session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -368,7 +412,7 @@ TEST_F(SenderSessionTest, HandlesMalformedAnswer) { } TEST_F(SenderSessionTest, HandlesImproperlyFormattedAnswer) { - session_->NegotiateMirroring( + session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -377,7 +421,7 @@ TEST_F(SenderSessionTest, HandlesImproperlyFormattedAnswer) { } TEST_F(SenderSessionTest, HandlesInvalidAnswer) { - const Error error = session_->NegotiateMirroring( + const Error error = session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -386,7 +430,7 @@ TEST_F(SenderSessionTest, HandlesInvalidAnswer) { } TEST_F(SenderSessionTest, HandlesNullAnswer) { - const Error error = session_->NegotiateMirroring( + const Error error = session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -396,7 +440,7 @@ TEST_F(SenderSessionTest, HandlesNullAnswer) { } TEST_F(SenderSessionTest, HandlesInvalidSequenceNumber) { - const Error error = session_->NegotiateMirroring( + const Error error = session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -405,7 +449,7 @@ TEST_F(SenderSessionTest, HandlesInvalidSequenceNumber) { } TEST_F(SenderSessionTest, HandlesUnknownTypeMessageWithValidSeqNum) { - session_->NegotiateMirroring( + session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -416,7 +460,7 @@ TEST_F(SenderSessionTest, HandlesUnknownTypeMessageWithValidSeqNum) { } TEST_F(SenderSessionTest, HandlesInvalidTypeMessageWithValidSeqNum) { - session_->NegotiateMirroring( + session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -427,7 +471,7 @@ TEST_F(SenderSessionTest, HandlesInvalidTypeMessageWithValidSeqNum) { } TEST_F(SenderSessionTest, HandlesInvalidTypeMessage) { - session_->NegotiateMirroring( + session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -437,7 +481,7 @@ TEST_F(SenderSessionTest, HandlesInvalidTypeMessage) { } TEST_F(SenderSessionTest, HandlesErrorMessage) { - session_->NegotiateMirroring( + session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -447,7 +491,7 @@ TEST_F(SenderSessionTest, HandlesErrorMessage) { } TEST_F(SenderSessionTest, DoesNotCrashOnMessagePortError) { - session_->NegotiateMirroring( + session_->Negotiate( std::vector<AudioCaptureConfig>{kAudioCaptureConfigValid}, std::vector<VideoCaptureConfig>{kVideoCaptureConfigValid}); @@ -461,5 +505,60 @@ TEST_F(SenderSessionTest, ReportsZeroBandwidthWhenNoPacketsSent) { EXPECT_EQ(0, session_->GetEstimatedNetworkBandwidth()); } +TEST_F(SenderSessionTest, ComplainsIfInvalidAudioCaptureConfigRemoting) { + const Error error = session_->NegotiateRemoting( + kAudioCaptureConfigInvalidChannels, kVideoCaptureConfigValid); + + EXPECT_EQ(error.code(), Error::Code::kParameterInvalid); +} + +TEST_F(SenderSessionTest, ComplainsIfInvalidVideoCaptureConfigRemoting) { + const Error error = session_->NegotiateRemoting(kAudioCaptureConfigValid, + kVideoCaptureConfigInvalid); + EXPECT_EQ(error.code(), Error::Code::kParameterInvalid); +} + +TEST_F(SenderSessionTest, ComplainsIfMissingResolutionsRemoting) { + const Error error = session_->NegotiateRemoting( + kAudioCaptureConfigValid, kVideoCaptureConfigMissingResolutions); + EXPECT_EQ(error.code(), Error::Code::kParameterInvalid); +} + +TEST_F(SenderSessionTest, HandlesValidAnswerRemoting) { + NegotiateRemotingWithValidConfigs(); + std::string answer = ConstructAnswerFromOffer(CastMode::kRemoting); + + EXPECT_CALL(client_, OnRemotingNegotiated(session_.get(), _)); + message_port_->ReceiveMessage(answer); + message_port_->ReceiveMessage(kCapabilitiesResponse); +} + +TEST_F(SenderSessionTest, SuccessfulRemotingNegotiationYieldsValidObject) { + NegotiateRemotingWithValidConfigs(); + std::string answer = ConstructAnswerFromOffer(CastMode::kRemoting); + + SenderSession::RemotingNegotiation negotiation; + EXPECT_CALL(client_, OnRemotingNegotiated(session_.get(), _)) + .WillOnce(testing::SaveArg<1>(&negotiation)); + message_port_->ReceiveMessage(answer); + message_port_->ReceiveMessage(kCapabilitiesResponse); + + // The capabilities should match the values in |kCapabilitiesResponse|. + EXPECT_THAT(negotiation.capabilities.audio, + testing::ElementsAre(AudioCapability::kBaselineSet, + AudioCapability::kAac)); + + // The "video" capability is ignored since it means nothing. + EXPECT_THAT(negotiation.capabilities.video, + testing::ElementsAre(VideoCapability::kVp8)); + + // The messenger is tested elsewhere, but we can sanity check that we got a valid + // one here. + EXPECT_TRUE(session_->rpc_messenger()); + const RpcMessenger::Handle handle = + session_->rpc_messenger()->GetUniqueHandle(); + EXPECT_NE(RpcMessenger::kInvalidHandle, handle); +} + } // namespace cast } // namespace openscreen diff --git a/cast/streaming/session_messager.cc b/cast/streaming/session_messenger.cc index 31e634d9..a612b68d 100644 --- a/cast/streaming/session_messager.cc +++ b/cast/streaming/session_messenger.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "cast/streaming/session_messager.h" +#include "cast/streaming/session_messenger.h" #include "absl/strings/ascii.h" #include "cast/common/public/message_port.h" @@ -19,12 +19,11 @@ namespace { void ReplyIfTimedOut( int sequence_number, ReceiverMessage::Type reply_type, - std::vector<std::pair<int, SenderSessionMessager::ReplyCallback>>* + std::vector<std::pair<int, SenderSessionMessenger::ReplyCallback>>* replies) { - auto it = replies->begin(); - for (; it != replies->end(); ++it) { + for (auto it = replies->begin(); it != replies->end(); ++it) { if (it->first == sequence_number) { - OSP_DVLOG + OSP_VLOG << "Replying with empty message due to timeout for sequence number: " << sequence_number; it->second(ReceiverMessage{reply_type, sequence_number}); @@ -36,69 +35,69 @@ void ReplyIfTimedOut( } // namespace -SessionMessager::SessionMessager(MessagePort* message_port, - std::string source_id, - ErrorCallback cb) +SessionMessenger::SessionMessenger(MessagePort* message_port, + std::string source_id, + ErrorCallback cb) : message_port_(message_port), error_callback_(std::move(cb)) { OSP_DCHECK(message_port_); OSP_DCHECK(!source_id.empty()); message_port_->SetClient(this, source_id); } -SessionMessager::~SessionMessager() { +SessionMessenger::~SessionMessenger() { message_port_->ResetClient(); } -Error SessionMessager::SendMessage(const std::string& destination_id, - const std::string& namespace_, - const Json::Value& message_root) { +Error SessionMessenger::SendMessage(const std::string& destination_id, + const std::string& namespace_, + const Json::Value& message_root) { OSP_DCHECK(namespace_ == kCastRemotingNamespace || namespace_ == kCastWebrtcNamespace); auto body_or_error = json::Stringify(message_root); if (body_or_error.is_error()) { return std::move(body_or_error.error()); } - OSP_DVLOG << "Sending message: DESTINATION[" << destination_id - << "], NAMESPACE[" << namespace_ << "], BODY:\n" - << body_or_error.value(); + OSP_VLOG << "Sending message: DESTINATION[" << destination_id + << "], NAMESPACE[" << namespace_ << "], BODY:\n" + << body_or_error.value(); message_port_->PostMessage(destination_id, namespace_, body_or_error.value()); return Error::None(); } -void SessionMessager::ReportError(Error error) { +void SessionMessenger::ReportError(Error error) { error_callback_(std::move(error)); } -SenderSessionMessager::SenderSessionMessager(MessagePort* message_port, - std::string source_id, - std::string receiver_id, - ErrorCallback cb, - TaskRunner* task_runner) - : SessionMessager(message_port, std::move(source_id), std::move(cb)), +SenderSessionMessenger::SenderSessionMessenger(MessagePort* message_port, + std::string source_id, + std::string receiver_id, + ErrorCallback cb, + TaskRunner* task_runner) + : SessionMessenger(message_port, std::move(source_id), std::move(cb)), task_runner_(task_runner), receiver_id_(std::move(receiver_id)) {} -void SenderSessionMessager::SetHandler(ReceiverMessage::Type type, - ReplyCallback cb) { +void SenderSessionMessenger::SetHandler(ReceiverMessage::Type type, + ReplyCallback cb) { // Currently the only handler allowed is for RPC messages. OSP_DCHECK(type == ReceiverMessage::Type::kRpc); rpc_callback_ = std::move(cb); } -Error SenderSessionMessager::SendOutboundMessage(SenderMessage message) { +Error SenderSessionMessenger::SendOutboundMessage(SenderMessage message) { const auto namespace_ = (message.type == SenderMessage::Type::kRpc) ? kCastRemotingNamespace : kCastWebrtcNamespace; ErrorOr<Json::Value> jsonified = message.ToJson(); OSP_CHECK(jsonified.is_value()) << "Tried to send an invalid message"; - return SessionMessager::SendMessage(receiver_id_, namespace_, - jsonified.value()); + return SessionMessenger::SendMessage(receiver_id_, namespace_, + jsonified.value()); } -Error SenderSessionMessager::SendRequest(SenderMessage message, - ReceiverMessage::Type reply_type, - ReplyCallback cb) { +Error SenderSessionMessenger::SendRequest(SenderMessage message, + ReceiverMessage::Type reply_type, + ReplyCallback cb) { static constexpr std::chrono::milliseconds kReplyTimeout{4000}; // RPC messages are not meant to be request/reply. OSP_DCHECK(reply_type != ReceiverMessage::Type::kRpc); @@ -108,6 +107,8 @@ Error SenderSessionMessager::SendRequest(SenderMessage message, return error; } + OSP_DCHECK(awaiting_replies_.find(message.sequence_number) == + awaiting_replies_.end()); awaiting_replies_.emplace_back(message.sequence_number, std::move(cb)); task_runner_->PostTaskWithDelay( [self = weak_factory_.GetWeakPtr(), reply_type, @@ -121,9 +122,9 @@ Error SenderSessionMessager::SendRequest(SenderMessage message, return Error::None(); } -void SenderSessionMessager::OnMessage(const std::string& source_id, - const std::string& message_namespace, - const std::string& message) { +void SenderSessionMessenger::OnMessage(const std::string& source_id, + const std::string& message_namespace, + const std::string& message) { if (source_id != receiver_id_) { OSP_DLOG_WARN << "Received message from unknown/incorrect Cast Receiver, " "expected id \"" @@ -145,13 +146,6 @@ void SenderSessionMessager::OnMessage(const std::string& source_id, return; } - int sequence_number; - if (!json::ParseAndValidateInt(message_body.value()[kSequenceNumber], - &sequence_number)) { - OSP_DLOG_WARN << "Received a message without a sequence number"; - return; - } - // If the message is valid JSON and we don't understand it, there are two // options: (1) it's an unknown type, or (2) the receiver filled out the // message incorrectly. In the first case we can drop it, it's likely just @@ -172,6 +166,13 @@ void SenderSessionMessager::OnMessage(const std::string& source_id, OSP_DLOG_INFO << "Received RTP message but no callback, dropping"; } } else { + int sequence_number; + if (!json::TryParseInt(message_body.value()[kSequenceNumber], + &sequence_number)) { + OSP_DLOG_WARN << "Received a message without a sequence number"; + return; + } + auto it = awaiting_replies_.find(sequence_number); if (it == awaiting_replies_.end()) { OSP_DLOG_WARN << "Received a reply I wasn't waiting for: " @@ -179,30 +180,34 @@ void SenderSessionMessager::OnMessage(const std::string& source_id, return; } - it->second(receiver_message.value({})); - awaiting_replies_.erase(it); + it->second(std::move(receiver_message.value({}))); + + // Calling the function callback may result in the checksum of the pointed + // to object to change, so calling erase() on the iterator after executing + // second() may result in a segfault. + awaiting_replies_.erase_key(sequence_number); } } -void SenderSessionMessager::OnError(Error error) { - OSP_DLOG_WARN << "Received an error in the session messager: " << error; +void SenderSessionMessenger::OnError(Error error) { + OSP_DLOG_WARN << "Received an error in the session messenger: " << error; } -ReceiverSessionMessager::ReceiverSessionMessager(MessagePort* message_port, - std::string source_id, - ErrorCallback cb) - : SessionMessager(message_port, std::move(source_id), std::move(cb)) {} +ReceiverSessionMessenger::ReceiverSessionMessenger(MessagePort* message_port, + std::string source_id, + ErrorCallback cb) + : SessionMessenger(message_port, std::move(source_id), std::move(cb)) {} -void ReceiverSessionMessager::SetHandler(SenderMessage::Type type, - RequestCallback cb) { +void ReceiverSessionMessenger::SetHandler(SenderMessage::Type type, + RequestCallback cb) { OSP_DCHECK(callbacks_.find(type) == callbacks_.end()); callbacks_.emplace_back(type, std::move(cb)); } -Error ReceiverSessionMessager::SendMessage(ReceiverMessage message) { +Error ReceiverSessionMessenger::SendMessage(ReceiverMessage message) { if (sender_session_id_.empty()) { return Error(Error::Code::kInitializationFailure, - "Tried to send a message without receving one first"); + "Tried to send a message without receiving one first"); } const auto namespace_ = (message.type == ReceiverMessage::Type::kRpc) @@ -211,13 +216,13 @@ Error ReceiverSessionMessager::SendMessage(ReceiverMessage message) { ErrorOr<Json::Value> message_json = message.ToJson(); OSP_CHECK(message_json.is_value()) << "Tried to send an invalid message"; - return SessionMessager::SendMessage(sender_session_id_, namespace_, - message_json.value()); + return SessionMessenger::SendMessage(sender_session_id_, namespace_, + message_json.value()); } -void ReceiverSessionMessager::OnMessage(const std::string& source_id, - const std::string& message_namespace, - const std::string& message) { +void ReceiverSessionMessenger::OnMessage(const std::string& source_id, + const std::string& message_namespace, + const std::string& message) { // We assume we are connected to the first sender_id we receive. if (sender_session_id_.empty()) { sender_session_id_ = source_id; @@ -265,8 +270,8 @@ void ReceiverSessionMessager::OnMessage(const std::string& source_id, } } -void ReceiverSessionMessager::OnError(Error error) { - OSP_DLOG_WARN << "Received an error in the session messager: " << error; +void ReceiverSessionMessenger::OnError(Error error) { + OSP_DLOG_WARN << "Received an error in the session messenger: " << error; } } // namespace cast diff --git a/cast/streaming/session_messager.h b/cast/streaming/session_messenger.h index 99b458ff..97a2564a 100644 --- a/cast/streaming/session_messager.h +++ b/cast/streaming/session_messenger.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef CAST_STREAMING_SESSION_MESSAGER_H_ -#define CAST_STREAMING_SESSION_MESSAGER_H_ +#ifndef CAST_STREAMING_SESSION_MESSENGER_H_ +#define CAST_STREAMING_SESSION_MESSENGER_H_ #include <functional> #include <string> @@ -27,20 +27,20 @@ namespace cast { // A message port interface designed specifically for use by the Receiver // and Sender session classes. -class SessionMessager : public MessagePort::Client { +class SessionMessenger : public MessagePort::Client { public: using ErrorCallback = std::function<void(Error)>; - SessionMessager(MessagePort* message_port, - std::string source_id, - ErrorCallback cb); - ~SessionMessager() override; + SessionMessenger(MessagePort* message_port, + std::string source_id, + ErrorCallback cb); + ~SessionMessenger() override; protected: // Barebones message sending method shared by both children. - Error SendMessage(const std::string& destination_id, - const std::string& namespace_, - const Json::Value& message_root); + [[nodiscard]] Error SendMessage(const std::string& destination_id, + const std::string& namespace_, + const Json::Value& message_root); // Used to report errors in subclasses. void ReportError(Error error); @@ -50,15 +50,15 @@ class SessionMessager : public MessagePort::Client { ErrorCallback error_callback_; }; -class SenderSessionMessager final : public SessionMessager { +class SenderSessionMessenger final : public SessionMessenger { public: using ReplyCallback = std::function<void(ReceiverMessage)>; - SenderSessionMessager(MessagePort* message_port, - std::string source_id, - std::string receiver_id, - ErrorCallback cb, - TaskRunner* task_runner); + SenderSessionMessenger(MessagePort* message_port, + std::string source_id, + std::string receiver_id, + ErrorCallback cb, + TaskRunner* task_runner); // Set receiver message handler. Note that this should only be // applied for messages that don't have sequence numbers, like RPC @@ -80,7 +80,7 @@ class SenderSessionMessager final : public SessionMessager { private: TaskRunner* const task_runner_; - // This messager should only be connected to one receiver, so |receiver_id_| + // This messenger should only be connected to one receiver, so |receiver_id_| // should not change. const std::string receiver_id_; @@ -93,15 +93,15 @@ class SenderSessionMessager final : public SessionMessager { // a flatmap here. ReplyCallback rpc_callback_; - WeakPtrFactory<SenderSessionMessager> weak_factory_{this}; + WeakPtrFactory<SenderSessionMessenger> weak_factory_{this}; }; -class ReceiverSessionMessager final : public SessionMessager { +class ReceiverSessionMessenger final : public SessionMessenger { public: using RequestCallback = std::function<void(SenderMessage)>; - ReceiverSessionMessager(MessagePort* message_port, - std::string source_id, - ErrorCallback cb); + ReceiverSessionMessenger(MessagePort* message_port, + std::string source_id, + ErrorCallback cb); // Set sender message handler. void SetHandler(SenderMessage::Type type, RequestCallback cb); @@ -125,4 +125,4 @@ class ReceiverSessionMessager final : public SessionMessager { } // namespace cast } // namespace openscreen -#endif // CAST_STREAMING_SESSION_MESSAGER_H_ +#endif // CAST_STREAMING_SESSION_MESSENGER_H_ diff --git a/cast/streaming/session_messager_unittest.cc b/cast/streaming/session_messenger_unittest.cc index 6f5dd3ea..ce2f1655 100644 --- a/cast/streaming/session_messager_unittest.cc +++ b/cast/streaming/session_messenger_unittest.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "cast/streaming/session_messager.h" +#include "cast/streaming/session_messenger.h" #include "cast/streaming/testing/message_pipe.h" #include "cast/streaming/testing/simple_message_port.h" @@ -24,7 +24,6 @@ constexpr char kReceiverId[] = "receiver-12345"; // simply because it is massive. Offer kExampleOffer{ CastMode::kMirroring, - false, {AudioStream{Stream{0, Stream::Type::kAudioSource, 2, @@ -59,19 +58,19 @@ Offer kExampleOffer{ struct SessionMessageStore { public: - SenderSessionMessager::ReplyCallback GetReplyCallback() { + SenderSessionMessenger::ReplyCallback GetReplyCallback() { return [this](ReceiverMessage message) { receiver_messages.push_back(std::move(message)); }; } - ReceiverSessionMessager::RequestCallback GetRequestCallback() { + ReceiverSessionMessenger::RequestCallback GetRequestCallback() { return [this](SenderMessage message) { sender_messages.push_back(std::move(message)); }; } - SessionMessager::ErrorCallback GetErrorCallback() { + SessionMessenger::ErrorCallback GetErrorCallback() { return [this](Error error) { errors.push_back(std::move(error)); }; } @@ -81,35 +80,33 @@ struct SessionMessageStore { }; } // namespace -class SessionMessagerTest : public ::testing::Test { +class SessionMessengerTest : public ::testing::Test { public: - SessionMessagerTest() + SessionMessengerTest() : clock_{Clock::now()}, task_runner_(&clock_), message_store_(), pipe_(kSenderId, kReceiverId), - receiver_messager_(pipe_.right(), - kReceiverId, - message_store_.GetErrorCallback()), - sender_messager_(pipe_.left(), - kSenderId, - kReceiverId, - message_store_.GetErrorCallback(), - &task_runner_) + receiver_messenger_(pipe_.right(), + kReceiverId, + message_store_.GetErrorCallback()), + sender_messenger_(pipe_.left(), + kSenderId, + kReceiverId, + message_store_.GetErrorCallback(), + &task_runner_) {} void SetUp() override { - sender_messager_.SetHandler(ReceiverMessage::Type::kRpc, - message_store_.GetReplyCallback()); - receiver_messager_.SetHandler(SenderMessage::Type::kOffer, - message_store_.GetRequestCallback()); - receiver_messager_.SetHandler(SenderMessage::Type::kGetStatus, - message_store_.GetRequestCallback()); - receiver_messager_.SetHandler(SenderMessage::Type::kGetCapabilities, - message_store_.GetRequestCallback()); - receiver_messager_.SetHandler(SenderMessage::Type::kRpc, - message_store_.GetRequestCallback()); + sender_messenger_.SetHandler(ReceiverMessage::Type::kRpc, + message_store_.GetReplyCallback()); + receiver_messenger_.SetHandler(SenderMessage::Type::kOffer, + message_store_.GetRequestCallback()); + receiver_messenger_.SetHandler(SenderMessage::Type::kGetCapabilities, + message_store_.GetRequestCallback()); + receiver_messenger_.SetHandler(SenderMessage::Type::kRpc, + message_store_.GetRequestCallback()); } protected: @@ -117,32 +114,34 @@ class SessionMessagerTest : public ::testing::Test { FakeTaskRunner task_runner_; SessionMessageStore message_store_; MessagePipe pipe_; - ReceiverSessionMessager receiver_messager_; - SenderSessionMessager sender_messager_; + ReceiverSessionMessenger receiver_messenger_; + SenderSessionMessenger sender_messenger_; std::vector<Error> receiver_errors_; std::vector<Error> sender_errors_; }; -TEST_F(SessionMessagerTest, RpcMessaging) { - ASSERT_TRUE(sender_messager_ - .SendOutboundMessage(SenderMessage{ - SenderMessage::Type::kRpc, 123, true /* valid */, - std::string("all your base are belong to us")}) - .ok()); +TEST_F(SessionMessengerTest, RpcMessaging) { + static const std::vector<uint8_t> kSenderMessage{1, 2, 3, 4, 5}; + static const std::vector<uint8_t> kReceiverResponse{6, 7, 8, 9}; + ASSERT_TRUE( + sender_messenger_ + .SendOutboundMessage(SenderMessage{SenderMessage::Type::kRpc, 123, + true /* valid */, kSenderMessage}) + .ok()); ASSERT_EQ(1u, message_store_.sender_messages.size()); ASSERT_TRUE(message_store_.receiver_messages.empty()); EXPECT_EQ(SenderMessage::Type::kRpc, message_store_.sender_messages[0].type); ASSERT_TRUE(message_store_.sender_messages[0].valid); - EXPECT_EQ("all your base are belong to us", - absl::get<std::string>(message_store_.sender_messages[0].body)); + EXPECT_EQ(kSenderMessage, absl::get<std::vector<uint8_t>>( + message_store_.sender_messages[0].body)); message_store_.sender_messages.clear(); ASSERT_TRUE( - receiver_messager_ + receiver_messenger_ .SendMessage(ReceiverMessage{ReceiverMessage::Type::kRpc, 123, - true /* valid */, std::string("nuh uh")}) + true /* valid */, kReceiverResponse}) .ok()); ASSERT_TRUE(message_store_.sender_messages.empty()); @@ -150,47 +149,13 @@ TEST_F(SessionMessagerTest, RpcMessaging) { EXPECT_EQ(ReceiverMessage::Type::kRpc, message_store_.receiver_messages[0].type); EXPECT_TRUE(message_store_.receiver_messages[0].valid); - EXPECT_EQ("nuh uh", - absl::get<std::string>(message_store_.receiver_messages[0].body)); -} - -TEST_F(SessionMessagerTest, StatusMessaging) { - ASSERT_TRUE(sender_messager_ - .SendRequest(SenderMessage{SenderMessage::Type::kGetStatus, - 3123, true /* valid */}, - ReceiverMessage::Type::kStatusResponse, - message_store_.GetReplyCallback()) - .ok()); - - ASSERT_EQ(1u, message_store_.sender_messages.size()); - ASSERT_TRUE(message_store_.receiver_messages.empty()); - EXPECT_EQ(SenderMessage::Type::kGetStatus, - message_store_.sender_messages[0].type); - EXPECT_TRUE(message_store_.sender_messages[0].valid); - - message_store_.sender_messages.clear(); - ASSERT_TRUE( - receiver_messager_ - .SendMessage(ReceiverMessage{ - ReceiverMessage::Type::kStatusResponse, 3123, true /* valid */, - ReceiverWifiStatus{-5.7, std::vector<int32_t>{1200, 1300, 1250}}}) - .ok()); - - ASSERT_TRUE(message_store_.sender_messages.empty()); - ASSERT_EQ(1u, message_store_.receiver_messages.size()); - EXPECT_EQ(ReceiverMessage::Type::kStatusResponse, - message_store_.receiver_messages[0].type); - EXPECT_TRUE(message_store_.receiver_messages[0].valid); - - const auto& status = - absl::get<ReceiverWifiStatus>(message_store_.receiver_messages[0].body); - EXPECT_DOUBLE_EQ(-5.7, status.wifi_snr); - EXPECT_THAT(status.wifi_speed, ElementsAre(1200, 1300, 1250)); + EXPECT_EQ(kReceiverResponse, absl::get<std::vector<uint8_t>>( + message_store_.receiver_messages[0].body)); } -TEST_F(SessionMessagerTest, CapabilitiesMessaging) { +TEST_F(SessionMessengerTest, CapabilitiesMessaging) { ASSERT_TRUE( - sender_messager_ + sender_messenger_ .SendRequest(SenderMessage{SenderMessage::Type::kGetCapabilities, 1337, true /* valid */}, ReceiverMessage::Type::kCapabilitiesResponse, @@ -204,10 +169,12 @@ TEST_F(SessionMessagerTest, CapabilitiesMessaging) { EXPECT_TRUE(message_store_.sender_messages[0].valid); message_store_.sender_messages.clear(); - ASSERT_TRUE(receiver_messager_ + ASSERT_TRUE(receiver_messenger_ .SendMessage(ReceiverMessage{ ReceiverMessage::Type::kCapabilitiesResponse, 1337, - true /* valid */, ReceiverCapability{47, {"ac3", "4k"}}}) + true /* valid */, + ReceiverCapability{ + 47, {MediaCapability::kAac, MediaCapability::k4k}}}) .ok()); ASSERT_TRUE(message_store_.sender_messages.empty()); @@ -219,11 +186,12 @@ TEST_F(SessionMessagerTest, CapabilitiesMessaging) { const auto& capability = absl::get<ReceiverCapability>(message_store_.receiver_messages[0].body); EXPECT_EQ(47, capability.remoting_version); - EXPECT_THAT(capability.media_capabilities, ElementsAre("ac3", "4k")); + EXPECT_THAT(capability.media_capabilities, + ElementsAre(MediaCapability::kAac, MediaCapability::k4k)); } -TEST_F(SessionMessagerTest, OfferAnswerMessaging) { - ASSERT_TRUE(sender_messager_ +TEST_F(SessionMessengerTest, OfferAnswerMessaging) { + ASSERT_TRUE(sender_messenger_ .SendRequest(SenderMessage{SenderMessage::Type::kOffer, 42, true /* valid */, kExampleOffer}, ReceiverMessage::Type::kAnswer, @@ -237,7 +205,7 @@ TEST_F(SessionMessagerTest, OfferAnswerMessaging) { EXPECT_TRUE(message_store_.sender_messages[0].valid); message_store_.sender_messages.clear(); - EXPECT_TRUE(receiver_messager_ + EXPECT_TRUE(receiver_messenger_ .SendMessage(ReceiverMessage{ ReceiverMessage::Type::kAnswer, 41, true /* valid */, Answer{1234, {0, 1}, {12344443, 12344445}}}) @@ -246,7 +214,7 @@ TEST_F(SessionMessagerTest, OfferAnswerMessaging) { ASSERT_TRUE(message_store_.sender_messages.empty()); ASSERT_TRUE(message_store_.receiver_messages.empty()); - ASSERT_TRUE(receiver_messager_ + ASSERT_TRUE(receiver_messenger_ .SendMessage(ReceiverMessage{ ReceiverMessage::Type::kAnswer, 42, true /* valid */, Answer{1234, {0, 1}, {12344443, 12344445}}}) @@ -265,8 +233,8 @@ TEST_F(SessionMessagerTest, OfferAnswerMessaging) { EXPECT_THAT(answer.ssrcs, ElementsAre(12344443, 12344445)); } -TEST_F(SessionMessagerTest, OfferAndReceiverError) { - ASSERT_TRUE(sender_messager_ +TEST_F(SessionMessengerTest, OfferAndReceiverError) { + ASSERT_TRUE(sender_messenger_ .SendRequest(SenderMessage{SenderMessage::Type::kOffer, 42, true /* valid */, kExampleOffer}, ReceiverMessage::Type::kAnswer, @@ -280,7 +248,7 @@ TEST_F(SessionMessagerTest, OfferAndReceiverError) { EXPECT_TRUE(message_store_.sender_messages[0].valid); message_store_.sender_messages.clear(); - EXPECT_TRUE(receiver_messager_ + EXPECT_TRUE(receiver_messenger_ .SendMessage(ReceiverMessage{ ReceiverMessage::Type::kAnswer, 42, false /* valid */, ReceiverError{123, "Something real bad happened"}}) @@ -298,43 +266,43 @@ TEST_F(SessionMessagerTest, OfferAndReceiverError) { EXPECT_EQ("Something real bad happened", error.description); } -TEST_F(SessionMessagerTest, UnexpectedMessagesAreIgnored) { - EXPECT_FALSE( - receiver_messager_ - .SendMessage(ReceiverMessage{ - ReceiverMessage::Type::kStatusResponse, 3123, true /* valid */, - ReceiverWifiStatus{-5.7, std::vector<int32_t>{1200, 1300, 1250}}}) - .ok()); +TEST_F(SessionMessengerTest, UnexpectedMessagesAreIgnored) { + EXPECT_FALSE(receiver_messenger_ + .SendMessage(ReceiverMessage{ + ReceiverMessage::Type::kCapabilitiesResponse, 3123, + true /* valid */, + ReceiverCapability{2, {MediaCapability::kH264}}}) + .ok()); // The message gets dropped and thus won't be in the store. EXPECT_TRUE(message_store_.sender_messages.empty()); EXPECT_TRUE(message_store_.receiver_messages.empty()); } -TEST_F(SessionMessagerTest, UnknownSenderMessageTypesDontGetSent) { - EXPECT_DEATH(sender_messager_ +TEST_F(SessionMessengerTest, UnknownSenderMessageTypesDontGetSent) { + EXPECT_DEATH(sender_messenger_ .SendOutboundMessage(SenderMessage{ SenderMessage::Type::kUnknown, 123, true /* valid */}) .ok(), ".*Trying to send an unknown message is a developer error.*"); } -TEST_F(SessionMessagerTest, UnknownReceiverMessageTypesDontGetSent) { - ASSERT_TRUE(sender_messager_ +TEST_F(SessionMessengerTest, UnknownReceiverMessageTypesDontGetSent) { + ASSERT_TRUE(sender_messenger_ .SendRequest(SenderMessage{SenderMessage::Type::kOffer, 42, true /* valid */, kExampleOffer}, ReceiverMessage::Type::kAnswer, message_store_.GetReplyCallback()) .ok()); - EXPECT_DEATH(receiver_messager_ + EXPECT_DEATH(receiver_messenger_ .SendMessage(ReceiverMessage{ReceiverMessage::Type::kUnknown, 3123, true /* valid */}) .ok(), ".*Trying to send an unknown message is a developer error.*"); } -TEST_F(SessionMessagerTest, ReceiverHandlesUnknownMessageType) { +TEST_F(SessionMessengerTest, ReceiverHandlesUnknownMessageType) { pipe_.right()->ReceiveMessage(kCastWebrtcNamespace, R"({ "type": "GET_VIRTUAL_REALITY", "seqNum": 31337 @@ -342,12 +310,12 @@ TEST_F(SessionMessagerTest, ReceiverHandlesUnknownMessageType) { ASSERT_TRUE(message_store_.errors.empty()); } -TEST_F(SessionMessagerTest, SenderHandlesUnknownMessageType) { +TEST_F(SessionMessengerTest, SenderHandlesUnknownMessageType) { // The behavior on the sender side is a little more interesting: we // test elsewhere that messages with the wrong sequence number are ignored, // here if the type is unknown but the message contains a valid sequence // number we just treat it as a bad response/same as a timeout. - ASSERT_TRUE(sender_messager_ + ASSERT_TRUE(sender_messenger_ .SendRequest(SenderMessage{SenderMessage::Type::kOffer, 42, true /* valid */, kExampleOffer}, ReceiverMessage::Type::kAnswer, @@ -365,9 +333,9 @@ TEST_F(SessionMessagerTest, SenderHandlesUnknownMessageType) { ASSERT_EQ(false, message_store_.receiver_messages[0].valid); } -TEST_F(SessionMessagerTest, SenderHandlesMessageMissingSequenceNumber) { +TEST_F(SessionMessengerTest, SenderHandlesMessageMissingSequenceNumber) { ASSERT_TRUE( - sender_messager_ + sender_messenger_ .SendRequest(SenderMessage{SenderMessage::Type::kGetCapabilities, 42, true /* valid */}, ReceiverMessage::Type::kCapabilitiesResponse, @@ -386,21 +354,22 @@ TEST_F(SessionMessagerTest, SenderHandlesMessageMissingSequenceNumber) { ASSERT_TRUE(message_store_.receiver_messages.empty()); } -TEST_F(SessionMessagerTest, ReceiverCannotSendFirst) { - const Error error = receiver_messager_.SendMessage(ReceiverMessage{ - ReceiverMessage::Type::kStatusResponse, 3123, true /* valid */, - ReceiverWifiStatus{-5.7, std::vector<int32_t>{1200, 1300, 1250}}}); +TEST_F(SessionMessengerTest, ReceiverCannotSendFirst) { + const Error error = receiver_messenger_.SendMessage(ReceiverMessage{ + ReceiverMessage::Type::kCapabilitiesResponse, 3123, true /* valid */, + ReceiverCapability{2, {MediaCapability::kAudio}}}); EXPECT_EQ(Error::Code::kInitializationFailure, error.code()); } -TEST_F(SessionMessagerTest, ErrorMessageLoggedIfTimeout) { - ASSERT_TRUE(sender_messager_ - .SendRequest(SenderMessage{SenderMessage::Type::kGetStatus, - 3123, true /* valid */}, - ReceiverMessage::Type::kStatusResponse, - message_store_.GetReplyCallback()) - .ok()); +TEST_F(SessionMessengerTest, ErrorMessageLoggedIfTimeout) { + ASSERT_TRUE( + sender_messenger_ + .SendRequest(SenderMessage{SenderMessage::Type::kGetCapabilities, + 3123, true /* valid */}, + ReceiverMessage::Type::kCapabilitiesResponse, + message_store_.GetReplyCallback()) + .ok()); ASSERT_EQ(1u, message_store_.sender_messages.size()); ASSERT_TRUE(message_store_.receiver_messages.empty()); @@ -409,24 +378,23 @@ TEST_F(SessionMessagerTest, ErrorMessageLoggedIfTimeout) { ASSERT_EQ(1u, message_store_.sender_messages.size()); ASSERT_EQ(1u, message_store_.receiver_messages.size()); EXPECT_EQ(3123, message_store_.receiver_messages[0].sequence_number); - EXPECT_EQ(ReceiverMessage::Type::kStatusResponse, + EXPECT_EQ(ReceiverMessage::Type::kCapabilitiesResponse, message_store_.receiver_messages[0].type); EXPECT_FALSE(message_store_.receiver_messages[0].valid); } -TEST_F(SessionMessagerTest, ReceiverRejectsMessageFromWrongSender) { +TEST_F(SessionMessengerTest, ReceiverRejectsMessageFromWrongSender) { SimpleMessagePort port(kReceiverId); - ReceiverSessionMessager messager(&port, kReceiverId, - message_store_.GetErrorCallback()); - messager.SetHandler(SenderMessage::Type::kGetStatus, - message_store_.GetRequestCallback()); + ReceiverSessionMessenger messenger(&port, kReceiverId, + message_store_.GetErrorCallback()); + messenger.SetHandler(SenderMessage::Type::kGetCapabilities, + message_store_.GetRequestCallback()); // The first message should be accepted since we don't have a set sender_id // yet. port.ReceiveMessage("sender-31337", kCastWebrtcNamespace, R"({ - "get_status": ["wifiSnr", "wifiSpeed"], "seqNum": 820263769, - "type": "GET_STATUS" + "type": "GET_CAPABILITIES" })"); ASSERT_TRUE(message_store_.errors.empty()); ASSERT_EQ(1u, message_store_.sender_messages.size()); @@ -434,9 +402,8 @@ TEST_F(SessionMessagerTest, ReceiverRejectsMessageFromWrongSender) { // The second message should just be ignored. port.ReceiveMessage("sender-42", kCastWebrtcNamespace, R"({ - "get_status": ["wifiSnr"], "seqNum": 1234, - "type": "GET_STATUS" + "type": "GET_CAPABILITIES" })"); ASSERT_TRUE(message_store_.errors.empty()); ASSERT_TRUE(message_store_.sender_messages.empty()); @@ -444,19 +411,18 @@ TEST_F(SessionMessagerTest, ReceiverRejectsMessageFromWrongSender) { // But the third message should be accepted again since it's from the // first sender. port.ReceiveMessage("sender-31337", kCastWebrtcNamespace, R"({ - "get_status": ["wifiSnr", "wifiSpeed"], "seqNum": 820263769, - "type": "GET_STATUS" + "type": "GET_CAPABILITIES" })"); ASSERT_TRUE(message_store_.errors.empty()); ASSERT_EQ(1u, message_store_.sender_messages.size()); } -TEST_F(SessionMessagerTest, SenderRejectsMessageFromWrongSender) { +TEST_F(SessionMessengerTest, SenderRejectsMessageFromWrongSender) { SimpleMessagePort port(kReceiverId); - SenderSessionMessager messager(&port, kSenderId, kReceiverId, - message_store_.GetErrorCallback(), - &task_runner_); + SenderSessionMessenger messenger(&port, kSenderId, kReceiverId, + message_store_.GetErrorCallback(), + &task_runner_); port.ReceiveMessage("receiver-31337", kCastWebrtcNamespace, R"({ "seqNum": 12345, @@ -472,19 +438,18 @@ TEST_F(SessionMessagerTest, SenderRejectsMessageFromWrongSender) { ASSERT_TRUE(message_store_.receiver_messages.empty()); } -TEST_F(SessionMessagerTest, ReceiverRejectsMessagesWithoutHandler) { +TEST_F(SessionMessengerTest, ReceiverRejectsMessagesWithoutHandler) { SimpleMessagePort port(kReceiverId); - ReceiverSessionMessager messager(&port, kReceiverId, - message_store_.GetErrorCallback()); - messager.SetHandler(SenderMessage::Type::kGetStatus, - message_store_.GetRequestCallback()); + ReceiverSessionMessenger messenger(&port, kReceiverId, + message_store_.GetErrorCallback()); + messenger.SetHandler(SenderMessage::Type::kGetCapabilities, + message_store_.GetRequestCallback()); // The first message should be accepted since we don't have a set sender_id // yet. port.ReceiveMessage("sender-31337", kCastWebrtcNamespace, R"({ - "get_status": ["wifiSnr", "wifiSpeed"], "seqNum": 820263769, - "type": "GET_STATUS" + "type": "GET_CAPABILITIES" })"); ASSERT_TRUE(message_store_.errors.empty()); ASSERT_EQ(1u, message_store_.sender_messages.size()); @@ -493,17 +458,17 @@ TEST_F(SessionMessagerTest, ReceiverRejectsMessagesWithoutHandler) { // The second message should be rejected since it doesn't have a handler. port.ReceiveMessage("sender-31337", kCastWebrtcNamespace, R"({ "seqNum": 820263770, - "type": "GET_CAPABILITIES" + "type": "RPC" })"); ASSERT_TRUE(message_store_.errors.empty()); ASSERT_TRUE(message_store_.sender_messages.empty()); } -TEST_F(SessionMessagerTest, SenderRejectsMessagesWithoutHandler) { +TEST_F(SessionMessengerTest, SenderRejectsMessagesWithoutHandler) { SimpleMessagePort port(kReceiverId); - SenderSessionMessager messager(&port, kSenderId, kReceiverId, - message_store_.GetErrorCallback(), - &task_runner_); + SenderSessionMessenger messenger(&port, kSenderId, kReceiverId, + message_store_.GetErrorCallback(), + &task_runner_); port.ReceiveMessage(kReceiverId, kCastWebrtcNamespace, R"({ "seqNum": 12345, @@ -519,7 +484,7 @@ TEST_F(SessionMessagerTest, SenderRejectsMessagesWithoutHandler) { ASSERT_TRUE(message_store_.receiver_messages.empty()); } -TEST_F(SessionMessagerTest, UnknownNamespaceMessagesGetDropped) { +TEST_F(SessionMessengerTest, UnknownNamespaceMessagesGetDropped) { pipe_.right()->ReceiveMessage("urn:x-cast:com.google.cast.virtualreality", R"({ "seqNum": 12345, diff --git a/cast/test/BUILD.gn b/cast/test/BUILD.gn index d37e6af2..c7569236 100644 --- a/cast/test/BUILD.gn +++ b/cast/test/BUILD.gn @@ -33,6 +33,7 @@ if (is_posix && !build_with_chromium) { deps = [ "../../platform", + "../../platform:standalone_impl", "../../testing/util", "../../third_party/abseil", "../../third_party/boringssl", diff --git a/discovery/BUILD.gn b/discovery/BUILD.gn index 11598c8f..dbf0ace8 100644 --- a/discovery/BUILD.gn +++ b/discovery/BUILD.gn @@ -5,20 +5,32 @@ import("//build_overrides/build.gni") import("../testing/libfuzzer/fuzzer_test.gni") -source_set("common") { +source_set("public") { sources = [ "common/config.h", "common/reporting_client.h", + "dnssd/public/dns_sd_instance.cc", + "dnssd/public/dns_sd_instance.h", + "dnssd/public/dns_sd_instance_endpoint.cc", + "dnssd/public/dns_sd_instance_endpoint.h", + "dnssd/public/dns_sd_publisher.h", + "dnssd/public/dns_sd_querier.h", + "dnssd/public/dns_sd_service.h", + "dnssd/public/dns_sd_txt_record.cc", + "dnssd/public/dns_sd_txt_record.h", + "mdns/public/mdns_constants.h", + "mdns/public/mdns_service.cc", + "mdns/public/mdns_service.h", + "public/dns_sd_service_factory.h", + "public/dns_sd_service_publisher.h", + "public/dns_sd_service_watcher.h", ] - + public_deps = [ "../platform" ] deps = [ "../util" ] - - public_deps = [ - "../platform", - "../third_party/abseil", - ] } +# TODO(https://issuetracker.google.com/issues/194234872): +# Move implementation files to impl/ source_set("mdns") { sources = [ "mdns/mdns_domain_confirmed_provider.h", @@ -47,21 +59,16 @@ source_set("mdns") { "mdns/mdns_trackers.h", "mdns/mdns_writer.cc", "mdns/mdns_writer.h", - "mdns/public/mdns_constants.h", - "mdns/public/mdns_service.cc", - "mdns/public/mdns_service.h", ] - deps = [ "../util" ] - - public_deps = [ - ":common", + public_deps = [ "../third_party/abseil" ] + deps = [ + ":public", "../platform", - "../third_party/abseil", + "../util", ] } -# TODO(issuetracker.google.com/179705382): Separate out a public target. source_set("dnssd") { sources = [ "dnssd/impl/conversion_layer.cc", @@ -82,34 +89,12 @@ source_set("dnssd") { "dnssd/impl/service_instance.h", "dnssd/impl/service_key.cc", "dnssd/impl/service_key.h", - "dnssd/public/dns_sd_instance.cc", - "dnssd/public/dns_sd_instance.h", - "dnssd/public/dns_sd_instance_endpoint.cc", - "dnssd/public/dns_sd_instance_endpoint.h", - "dnssd/public/dns_sd_publisher.h", - "dnssd/public/dns_sd_querier.h", - "dnssd/public/dns_sd_service.h", - "dnssd/public/dns_sd_txt_record.cc", - "dnssd/public/dns_sd_txt_record.h", ] - public_deps = [ - ":common", + deps = [ ":mdns", - "../util", - ] -} - -source_set("public") { - sources = [ - "public/dns_sd_service_factory.h", - "public/dns_sd_service_publisher.h", - "public/dns_sd_service_watcher.h", - ] - - public_deps = [ - ":common", - ":dnssd", + ":public", + "../third_party/abseil", "../util", ] } @@ -133,8 +118,9 @@ source_set("testing") { sources += [ "mdns/testing/hash_test_util_abseil.h" ] } - public_deps = [ + deps = [ ":mdns", + ":public", "../third_party/abseil", "../third_party/googletest:gmock", "../third_party/googletest:gtest", @@ -185,7 +171,10 @@ source_set("unittests") { openscreen_fuzzer_test("mdns_fuzzer") { sources = [ "mdns/mdns_reader_fuzztest.cc" ] - deps = [ ":mdns" ] + deps = [ + ":mdns", + ":public", + ] seed_corpus = "mdns/fuzzer_seeds" diff --git a/discovery/DEPS b/discovery/DEPS index de7afcec..4d758dcd 100644 --- a/discovery/DEPS +++ b/discovery/DEPS @@ -4,6 +4,6 @@ include_rules = [ # Intra-discovery dependencies must be explicit. '-discovery', - # All discovery code can use discovery/common + # All discovery code can use discovery/common. '+discovery/common', ] diff --git a/discovery/common/config.h b/discovery/common/config.h index b1ef731a..940001ce 100644 --- a/discovery/common/config.h +++ b/discovery/common/config.h @@ -14,28 +14,13 @@ namespace discovery { // This struct provides parameters needed to initialize the discovery pipeline. struct Config { - struct NetworkInfo { - enum AddressFamilies : uint8_t { - kNoAddressFamily = 0, - kUseIpV4 = 0x01 << 0, - kUseIpV6 = 0x01 << 1 - }; - - // Network Interface on which discovery should be run. - InterfaceInfo interface; - - // IP Address Families supported by this network interface and on which the - // mDNS Service should listen for and/or publish records. - AddressFamilies supported_address_families; - }; - /***************************************** * Common Settings *****************************************/ // Interfaces on which services should be published, and on which discovery // should listen for announced service instances. - std::vector<NetworkInfo> network_info; + std::vector<InterfaceInfo> network_info; // Maximum allowed size in bytes for the rdata in an incoming record. All // received records with rdata size exceeding this size will be dropped. @@ -98,32 +83,6 @@ struct Config { bool ignore_nsec_responses = false; }; -inline Config::NetworkInfo::AddressFamilies operator&( - Config::NetworkInfo::AddressFamilies lhs, - Config::NetworkInfo::AddressFamilies rhs) { - return static_cast<Config::NetworkInfo::AddressFamilies>( - static_cast<uint8_t>(lhs) & static_cast<uint8_t>(rhs)); -} - -inline Config::NetworkInfo::AddressFamilies operator|( - Config::NetworkInfo::AddressFamilies lhs, - Config::NetworkInfo::AddressFamilies rhs) { - return static_cast<Config::NetworkInfo::AddressFamilies>( - static_cast<uint8_t>(lhs) | static_cast<uint8_t>(rhs)); -} - -inline Config::NetworkInfo::AddressFamilies operator|=( - Config::NetworkInfo::AddressFamilies& lhs, - Config::NetworkInfo::AddressFamilies rhs) { - return lhs = lhs | rhs; -} - -inline Config::NetworkInfo::AddressFamilies operator&=( - Config::NetworkInfo::AddressFamilies& lhs, - Config::NetworkInfo::AddressFamilies rhs) { - return lhs = lhs & rhs; -} - } // namespace discovery } // namespace openscreen diff --git a/discovery/dnssd/impl/DEPS b/discovery/dnssd/impl/DEPS index 243d363d..57d73c16 100644 --- a/discovery/dnssd/impl/DEPS +++ b/discovery/dnssd/impl/DEPS @@ -2,5 +2,13 @@ include_rules = [ '+discovery/dnssd/public', - '+discovery/mdns', + '+discovery/mdns/public', + + # TODO(https://issuetracker.google.com/issues/194234872): + # Move these to discovery/mdns/public + '+discovery/mdns/mdns_domain_confirmed_provider.h', + '+discovery/mdns/mdns_record_changed_callback.h', + '+discovery/mdns/mdns_records.h', + + '+discovery/mdns/testing/mdns_test_util.h', ] diff --git a/discovery/dnssd/impl/service_instance.cc b/discovery/dnssd/impl/service_instance.cc index 7d4b014f..d923eef9 100644 --- a/discovery/dnssd/impl/service_instance.cc +++ b/discovery/dnssd/impl/service_instance.cc @@ -16,29 +16,15 @@ namespace discovery { ServiceInstance::ServiceInstance(TaskRunner* task_runner, ReportingClient* reporting_client, const Config& config, - const Config::NetworkInfo& network_info) + const InterfaceInfo& network_info) : task_runner_(task_runner), mdns_service_(MdnsService::Create(task_runner, reporting_client, config, network_info)), - network_config_(network_info.interface.index, - (network_info.supported_address_families & - Config::NetworkInfo::kUseIpV4) - ? network_info.interface.GetIpAddressV4() - : IPAddress{}, - (network_info.supported_address_families & - Config::NetworkInfo::kUseIpV6) - ? network_info.interface.GetIpAddressV6() - : IPAddress{}) { - const Config::NetworkInfo::AddressFamilies supported_address_families = - network_info.supported_address_families; - - OSP_DCHECK(!(supported_address_families & Config::NetworkInfo::kUseIpV4) || - network_config_.HasAddressV4()); - OSP_DCHECK(!(supported_address_families & Config::NetworkInfo::kUseIpV6) || - network_config_.HasAddressV6()); - + network_config_(network_info.index, + network_info.GetIpAddressV4(), + network_info.GetIpAddressV6()) { if (config.enable_querying) { querier_ = std::make_unique<QuerierImpl>( mdns_service_.get(), task_runner_, reporting_client, &network_config_); diff --git a/discovery/dnssd/impl/service_instance.h b/discovery/dnssd/impl/service_instance.h index e06ca564..798a17b0 100644 --- a/discovery/dnssd/impl/service_instance.h +++ b/discovery/dnssd/impl/service_instance.h @@ -26,9 +26,9 @@ class ServiceInstance final : public DnsSdService { ServiceInstance(TaskRunner* task_runner, ReportingClient* reporting_client, const Config& config, - const Config::NetworkInfo& network_info); + const InterfaceInfo& network_info); ServiceInstance(const ServiceInstance& other) = delete; - ServiceInstance(ServiceInstance&& other) = delete; + ServiceInstance(ServiceInstance&& other) noexcept = delete; ~ServiceInstance() override; ServiceInstance& operator=(const ServiceInstance& other) = delete; diff --git a/discovery/dnssd/public/DEPS b/discovery/dnssd/public/DEPS new file mode 100644 index 00000000..e8ae0cbe --- /dev/null +++ b/discovery/dnssd/public/DEPS @@ -0,0 +1,6 @@ +# -*- Mode: Python; -*- + +include_rules = [ + # Layering rule. + '-discovery/dnssd/impl', +] diff --git a/discovery/dnssd/public/dns_sd_publisher.h b/discovery/dnssd/public/dns_sd_publisher.h index 3c139b4e..10eb03ad 100644 --- a/discovery/dnssd/public/dns_sd_publisher.h +++ b/discovery/dnssd/public/dns_sd_publisher.h @@ -19,7 +19,6 @@ class DnsSdPublisher { public: class Client { public: - virtual ~Client() = default; // Callback called when an endpoint is successfully claimed and published // via the Register() method. These values are expected to only differ in @@ -29,6 +28,9 @@ class DnsSdPublisher { virtual void OnEndpointClaimed( const DnsSdInstance& requested_instance, const DnsSdInstanceEndpoint& claimed_endpoint) = 0; + + protected: + virtual ~Client() = default; }; virtual ~DnsSdPublisher() = default; diff --git a/discovery/mdns/DEPS b/discovery/mdns/DEPS index 309d03f4..c0348a87 100644 --- a/discovery/mdns/DEPS +++ b/discovery/mdns/DEPS @@ -2,4 +2,6 @@ include_rules = [ '+discovery/mdns/public', + # DNS-SD is layered on top of mDNS. + '-discovery/dnssd', ] diff --git a/discovery/mdns/mdns_service_impl.cc b/discovery/mdns/mdns_service_impl.cc index 6d94c3c7..bbbf0815 100644 --- a/discovery/mdns/mdns_service_impl.cc +++ b/discovery/mdns/mdns_service_impl.cc @@ -20,7 +20,7 @@ std::unique_ptr<MdnsService> MdnsService::Create( TaskRunner* task_runner, ReportingClient* reporting_client, const Config& config, - const Config::NetworkInfo& network_info) { + const InterfaceInfo& network_info) { return std::make_unique<MdnsServiceImpl>( task_runner, Clock::now, reporting_client, config, network_info); } @@ -29,22 +29,21 @@ MdnsServiceImpl::MdnsServiceImpl(TaskRunner* task_runner, ClockNowFunctionPtr now_function, ReportingClient* reporting_client, const Config& config, - const Config::NetworkInfo& network_info) + const InterfaceInfo& network_info) : task_runner_(task_runner), now_function_(now_function), reporting_client_(reporting_client), receiver_(config), - interface_(network_info.interface.index) { + interface_(network_info.index) { OSP_DCHECK(task_runner_); OSP_DCHECK(reporting_client_); - OSP_DCHECK(network_info.supported_address_families); // Create all UDP sockets needed for this object. They should not yet be bound // so that they do not send or receive data until the objects on which their // callback depends is initialized. // NOTE: we bind to the Any addresses here because traffic is filtered by // the multicast join calls. - if (network_info.supported_address_families & Config::NetworkInfo::kUseIpV4) { + if (network_info.GetIpAddressV4()) { ErrorOr<std::unique_ptr<UdpSocket>> socket = UdpSocket::Create( task_runner, this, IPEndpoint{IPAddress::kAnyV4(), kDefaultMulticastPort}); @@ -55,7 +54,7 @@ MdnsServiceImpl::MdnsServiceImpl(TaskRunner* task_runner, socket_v4_ = std::move(socket.value()); } - if (network_info.supported_address_families & Config::NetworkInfo::kUseIpV6) { + if (network_info.GetIpAddressV6()) { ErrorOr<std::unique_ptr<UdpSocket>> socket = UdpSocket::Create( task_runner, this, IPEndpoint{IPAddress::kAnyV6(), kDefaultMulticastPort}); diff --git a/discovery/mdns/mdns_service_impl.h b/discovery/mdns/mdns_service_impl.h index e1c15226..523f078f 100644 --- a/discovery/mdns/mdns_service_impl.h +++ b/discovery/mdns/mdns_service_impl.h @@ -40,7 +40,7 @@ class MdnsServiceImpl : public MdnsService, public UdpSocket::Client { ClockNowFunctionPtr now_function, ReportingClient* reporting_client, const Config& config, - const Config::NetworkInfo& network_info); + const InterfaceInfo& network_info); ~MdnsServiceImpl() override; // MdnsService Overrides. diff --git a/discovery/mdns/public/DEPS b/discovery/mdns/public/DEPS new file mode 100644 index 00000000..5b65c0e4 --- /dev/null +++ b/discovery/mdns/public/DEPS @@ -0,0 +1,8 @@ +# -*- Mode: Python; -*- +include_rules = [ + # Layering rule. + '-discovery/mdns', + # Except ourselves. + '+discovery/mdns/public', +] + diff --git a/discovery/mdns/public/mdns_service.h b/discovery/mdns/public/mdns_service.h index 03e58008..76a8f05d 100644 --- a/discovery/mdns/public/mdns_service.h +++ b/discovery/mdns/public/mdns_service.h @@ -34,11 +34,10 @@ class MdnsService { // Creates a new MdnsService instance, to be owned by the caller. On failure, // returns nullptr. |task_runner|, |reporting_client|, and |config| must exist // for the duration of the resulting instance's life. - static std::unique_ptr<MdnsService> Create( - TaskRunner* task_runner, - ReportingClient* reporting_client, - const Config& config, - const Config::NetworkInfo& network_info); + static std::unique_ptr<MdnsService> Create(TaskRunner* task_runner, + ReportingClient* reporting_client, + const Config& config, + const InterfaceInfo& network_info); // Starts an mDNS query with the given properties. Updated records are passed // to |callback|. The caller must ensure |callback| remains alive while it is diff --git a/docs/advanced_gerrit.md b/docs/advanced_gerrit.md index 790ee359..7b14d922 100644 --- a/docs/advanced_gerrit.md +++ b/docs/advanced_gerrit.md @@ -26,58 +26,18 @@ following command: chmod a+x .git/hooks/commit-msg ``` -### Uploading a new patch for review +### Uploading a new patch for review -You should run `PRESUBMIT.sh` in the root of the repository before pushing for +You should run `git cl presubmit --upload` in the root of the repository before pushing for review (which primarily checks formatting). -There is official [Gerrit -documentation](https://gerrit-documentation.storage.googleapis.com/Documentation/2.14.7/user-upload.html#push_create) -for this which essentially amounts to: +After verifying that presubmission works correctly, you can then execute: +`git cl upload`, which will prompt you to verify the commit message and check +for owners. -``` bash - git push origin HEAD:refs/for/master -``` - -Gerrit keeps track of changes using a [Change-Id -line](https://gerrit-documentation.storage.googleapis.com/Documentation/2.14.7/user-changeid.html) -in each commit. - -When there is no `Change-Id` line, Gerrit creates a new `Change-Id` for the -commit, and therefore a new change. Gerrit's documentation for -[replacing a change](https://gerrit-documentation.storage.googleapis.com/Documentation/2.14.7/user-upload.html#push_replace) -describes this. So if you want to upload a new patchset to an existing review, -it should contain the matching `Change-Id` line in the commit message. - -### Adding a new patchset to an existing change - -By default, each commit to your local branch will get its own Gerrit change when -pushed, unless it has a `Change-Id` corresponding to an existing review. - -If you need to modify commits on your local branch to ensure they have the -correct `Change-Id`, you can do one of two things: - -After committing to the local branch, run: - -```bash - git commit --amend - git show -``` - -to attach the current `Change-Id` to the most recent commit. Check that the -correct one was inserted by comparing it with the one shown on -`chromium-review.googlesource.com` for the existing review. - -If you have made multiple local commits, you can squash them all into a single -commit with the correct Change-Id: - -```bash - git rebase -i HEAD~4 - git show -``` - -where '4' means that you want to squash three additional commits onto an -existing commit that has been uploaded for review. +The first time you upload an issue, the issue number is associated with the +current branch. If you upload again, it uploads on the same issue (which is tied +to the branch, not the commit). See the [git-cl](https://chromium.googlesource.com/chromium/tools/depot_tools.git/+/HEAD/README.git-cl.md) documentation for more information. ## Uploading a new dependent change diff --git a/docs/code_coverage.md b/docs/code_coverage.md new file mode 100644 index 00000000..1c181ee8 --- /dev/null +++ b/docs/code_coverage.md @@ -0,0 +1,41 @@ +# Code Coverage + +Code coverage can be checked using clang's source-based coverage tools. You +must use the GN argument `use_coverage=true`. It's recommended to do this in a +separate output directory since the added instrumentation will affect +performance and generate an output file every time a binary is run. You can +read more about this in [clang's documentation]( +http://clang.llvm.org/docs/SourceBasedCodeCoverage.html) but the +bare minimum steps are also outlined below. You will also need to download the +pre-built clang coverage tools, which are not downloaded by default. The +easiest way to do this is to set a custom variable in your `.gclient` file. +Under the "openscreen" solution, add: +```python + "custom_vars": { + "checkout_clang_coverage_tools": True, + }, +``` +then run `gclient runhooks`. You can also run the python command from the +`clang_coverage_tools` hook in `//DEPS` yourself or even download the tools +manually +([link](https://storage.googleapis.com/chromium-browser-clang-staging/)). + +Once you have your GN directory (we'll call it `out/coverage`) and have +downloaded the tools, do the following to generate an HTML coverage report: +```bash +out/coverage/openscreen_unittests +third_party/llvm-build/Release+Asserts/bin/llvm-profdata merge -sparse default.profraw -o foo.profdata +third_party/llvm-build/Release+Asserts/bin/llvm-cov show out/coverage/openscreen_unittests -instr-profile=foo.profdata -format=html -output-dir=<out dir> [filter paths] +``` +There are a few things to note here: + - `default.profraw` is generated by running the instrumented code, but + `foo.profdata` can be any path you want. + - `<out dir>` should be an empty directory for placing the generated HTML + files. You can view the report at `<out dir>/index.html`. + - `[filter paths]` is a list of paths to which you want to limit the coverage + report. For example, you may want to limit it to cast/ or even + cast/streaming/. If this list is empty, all data will be in the report. + +The same process can be used to check the coverage of a fuzzer's corpus. Just +add `-runs=0` to the fuzzer arguments to make sure it only runs the existing +corpus then exits. diff --git a/docs/continuous_build.md b/docs/continuous_build.md new file mode 100644 index 00000000..53c68e21 --- /dev/null +++ b/docs/continuous_build.md @@ -0,0 +1,28 @@ +# Continuous build and try jobs + +Open Screen uses [LUCI builders](https://ci.chromium.org/p/openscreen/builders) +to monitor the build and test health of the library. + +Current builders include: + +| Name | Arch | OS | Toolchain | Build | Notes | +|------------------------|--------|------------------------|-----------|---------|------------------------| +| linux64_debug | x86-64 | Ubuntu Linux 18.04 | clang | debug | ASAN enabled | +| linux_arm64_debug | arm64 | Ubuntu Linux 20.04 [*] | clang | debug | | +| linux64_gcc_debug | x86-64 | Ubuntu Linux 18.04 | gcc-7 | debug | | +| linux64_tsan | x86-64 | Ubuntu Linux 18.04 | clang | release | TSAN enabled | +| linux64_coverage_debug | x86-64 | Ubuntu Linux 18.04 | clang | debug | used for code coverage | +| linux64_cast_e2e | x86-64 | Ubuntu Linux 18.04 | clang | debug | Builds cast standalone | +| mac_debug | x86-64 | Mac OS X/Xcode | clang | debug | | +| chromium_linux64_debug | x86-64 | Ubuntu Linux 18.04 | clang | debug | built with chromium | +| chromium_mac_debug | x86-64 | Mac OS X 10.15 | clang | debug | built with chromium | +<br /> + +[*] Tests run on Ubuntu 20.04, but are cross-compiled to arm64 with a debian stretch sysroot. + +The chromium_ builders compile against Chromium top-of-tree to ensure that +changes can be autorolled into Chromium. + +You can run a patch through all builders using `git cl try` or the Gerrit Web +interface. All builders are run as part of the commit queue and are also run +continuously in our CI. diff --git a/docs/fuzzing.md b/docs/fuzzing.md new file mode 100644 index 00000000..427165b7 --- /dev/null +++ b/docs/fuzzing.md @@ -0,0 +1,19 @@ +# Building and running fuzzers + +In order to build fuzzers, you need the GN arg `use_libfuzzer=true`. It's also +recommended to build with `is_asan=true` to catch additional problems. Building +and running then might look like: +```bash + gn gen out/libfuzzer --args="use_libfuzzer=true is_asan=true is_debug=false" + ninja -C out/libfuzzer some_fuzz_target + out/libfuzzer/some_fuzz_target <args> <corpus_dir> [additional corpus dirs] +``` + +The arguments to the fuzzer binary should be whatever is listed in the GN target +description (e.g. `-max_len=1500`). These arguments may be automatically +scraped by Chromium's ClusterFuzz tool when it runs fuzzers, but they are not +built into the target. You can also look at the file +`out/libfuzzer/some_fuzz_target.options` for what arguments should be used. The +`corpus_dir` is listed as `seed_corpus` in the GN definition of the fuzzer +target. + diff --git a/docs/raspberry_pi.md b/docs/raspberry_pi.md new file mode 100644 index 00000000..b61315cd --- /dev/null +++ b/docs/raspberry_pi.md @@ -0,0 +1,25 @@ +# Working with ARM/ARM64/the Raspberry PI + +Open Screen Library supports cross compilation for both arm32 and arm64 +platforms, by using the `gn args` parameter `target_cpu="arm"` or +`target_cpu="arm64"` respectively. Note that quotes are required around the +target arch value. + +Setting an arm(64) target_cpu causes GN to pull down a sysroot from openscreen's +public cloud storage bucket. Google employees may update the sysroots stored +by requesting access to the Open Screen pantheon project and uploading a new +tar.xz to the openscreen-sysroots bucket. + +NOTE: The "arm" image is taken from Chromium's debian arm image, however it has +been manually patched to include support for libavcodec and libsdl2. To update +this image, the new image must be manually patched to include the necessary +header and library dependencies. Note that if the versions of libavcodec and +libsdl2 are too out of sync from the copies in the sysroot, compilation will +succeed, but you may experience issues decoding content. + +To install the last known good version of the libavcodec and libsdl packages +on a Raspberry Pi, you can run the following command: + +```bash +sudo ./cast/standalone_receiver/install_demo_deps_raspian.sh +``` diff --git a/docs/style_guide.md b/docs/style_guide.md index dfc6ab1a..2174411e 100644 --- a/docs/style_guide.md +++ b/docs/style_guide.md @@ -1,6 +1,6 @@ # Open Screen Library Style Guide -The Open Screen Library follows the [Chromium C++ coding style](https://chromium.googlesource.com/chromium/src/+/master/styleguide/c++/c++.md) +The Open Screen Library follows the [Chromium C++ coding style](https://chromium.googlesource.com/chromium/src/+/main/styleguide/c++/c++.md) which, in turn, defers to the [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). We also follow the [Chromium C++ Do's and Don'ts](https://sites.google.com/a/chromium.org/dev/developers/coding-style/cpp-dos-and-donts). @@ -188,4 +188,4 @@ not explicitly sprinkle "#if OSP_DCHECK_IS_ON()" guards all around any functions, variables, etc. that will be unused in "DCHECK off" builds. Use OSP_DCHECK and OSP_CHECK in accordance with the -[Chromium guidance for DCHECK/CHECK](https://chromium.googlesource.com/chromium/src/+/master/styleguide/c++/c++.md#check_dcheck_and-notreached). +[Chromium guidance for DCHECK/CHECK](https://chromium.googlesource.com/chromium/src/+/main/styleguide/c++/c++.md#check_dcheck_and-notreached). diff --git a/infra/config/global/commit-queue.cfg b/infra/config/global/commit-queue.cfg index 22a5b113..4f19b682 100644 --- a/infra/config/global/commit-queue.cfg +++ b/infra/config/global/commit-queue.cfg @@ -53,6 +53,10 @@ config_groups { name: "openscreen/try/linux64_coverage_debug" experiment_percentage: 100 } + builders { + name: "openscreen/try/linux64_cast_e2e" + experiment_percentage: 100 + } retry_config { single_quota: 1 global_quota: 2 diff --git a/infra/config/global/cr-buildbucket.cfg b/infra/config/global/cr-buildbucket.cfg index 96063359..5ff8b9ad 100644 --- a/infra/config/global/cr-buildbucket.cfg +++ b/infra/config/global/cr-buildbucket.cfg @@ -77,12 +77,18 @@ builder_mixins { } builder_mixins { - name: "linux" - dimensions: "os:Ubuntu-16.04" + name: "cast_standalone" + recipe { + properties_j: "have_ffmpeg:true" + properties_j: "have_libsdl2:true" + properties_j: "have_libopus:true" + properties_j: "have_libvpx:true" + properties_j: "cast_allow_developer_certificate:true" + } } builder_mixins { - name: "linux1804" + name: "linux" dimensions: "os:Ubuntu-18.04" } @@ -193,16 +199,6 @@ buckets { } EOF properties_j: <<EOF - $recipe_engine/isolated: { - "server": "https://isolateserver.appspot.com" - } - EOF - properties_j: <<EOF - $recipe_engine/cas: { - "instance": "chromium-swarm" - } - EOF - properties_j: <<EOF $recipe_engine/swarming: { "server": "https://chromium-swarm.appspot.com" } @@ -223,7 +219,7 @@ buckets { builders { name: "linux64_gcc_debug" - mixins: "linux1804" + mixins: "linux" mixins: "debug" mixins: "x64" mixins: "gcc" @@ -290,6 +286,15 @@ buckets { mixins: "ci" mixins: "goma_rbe_ats" } + + builders { + name: "linux64_cast_e2e" + mixins: "linux" + mixins: "debug" + mixins: "x64" + mixins: "cast_standalone" + mixins: "goma_rbe_ats" + } } } @@ -311,16 +316,6 @@ buckets: { } EOF properties_j: <<EOF - $recipe_engine/isolated: { - "server": "https://isolateserver.appspot.com" - } - EOF - properties_j: <<EOF - $recipe_engine/cas: { - "instance": "chromium-swarm" - } - EOF - properties_j: <<EOF $recipe_engine/swarming: { "server": "https://chromium-swarm.appspot.com" } @@ -340,7 +335,7 @@ buckets: { builders { name: "linux64_gcc_debug" - mixins: "linux1804" + mixins: "linux" mixins: "debug" mixins: "x64" mixins: "gcc" @@ -411,6 +406,14 @@ buckets: { mixins: "code_coverage" mixins: "goma_rbe_ats" } + + builders { + name: "linux64_cast_e2e" + mixins: "linux" + mixins: "debug" + mixins: "x64" + mixins: "cast_standalone" + mixins: "goma_rbe_ats" + } } } - diff --git a/infra/config/global/luci-milo.cfg b/infra/config/global/luci-milo.cfg index 70df9ce0..c0445bb6 100644 --- a/infra/config/global/luci-milo.cfg +++ b/infra/config/global/luci-milo.cfg @@ -54,6 +54,12 @@ consoles { category: "linux|x64" short_name: "coverage" } + + builders { + name: "buildbucket/luci.openscreen.ci/linux64_cast_e2e" + category: "linux|x64" + short_name: "cast" + } } consoles { @@ -110,4 +116,10 @@ consoles { category: "linux|x64" short_name: "coverage" } + + builders { + name: "buildbucket/luci.openscreen.ci/linux64_cast_e2e" + category: "linux|x64" + short_name: "cast" + } } diff --git a/infra/config/global/luci-scheduler.cfg b/infra/config/global/luci-scheduler.cfg index 13612af8..0ec867b9 100644 --- a/infra/config/global/luci-scheduler.cfg +++ b/infra/config/global/luci-scheduler.cfg @@ -28,6 +28,7 @@ trigger { triggers: "linux_arm64_debug" triggers: "mac_debug" triggers: "linux64_coverage_debug" + triggers: "linux64_cast_e2e" } trigger { @@ -121,3 +122,13 @@ job { builder: "linux64_coverage_debug" } } + +job { + id: "linux64_cast_e2e" + acl_sets: "default" + buildbucket: { + server: "cr-buildbucket.appspot.com" + bucket: "luci.openscreen.ci" + builder: "linux64_cast_e2e" + } +} diff --git a/osp/BUILD.gn b/osp/BUILD.gn index 14bb1ae1..5a2feb44 100644 --- a/osp/BUILD.gn +++ b/osp/BUILD.gn @@ -48,20 +48,13 @@ source_set("unittests") { "public", "public:test_support", ] - - if (use_mdns_responder) { - sources += [ "impl/mdns_responder_service_unittest.cc" ] - - deps += [ "impl/testing" ] - } } -if (use_chromium_quic && use_mdns_responder) { +if (use_chromium_quic) { executable("osp_demo") { sources = [ "demo/osp_demo.cc" ] deps = [ ":osp_with_chromium_quic", - "//osp/impl/discovery/mdns", "//platform", "//util", ] diff --git a/osp/build/config/services.gni b/osp/build/config/services.gni index 1d3d3466..808c123b 100644 --- a/osp/build/config/services.gni +++ b/osp/build/config/services.gni @@ -5,9 +5,7 @@ import("//build_overrides/build.gni") use_chromium_quic = true -use_mdns_responder = true if (build_with_chromium) { use_chromium_quic = false - use_mdns_responder = false } diff --git a/osp/demo/osp_demo.cc b/osp/demo/osp_demo.cc index 952a925f..93f593fa 100644 --- a/osp/demo/osp_demo.cc +++ b/osp/demo/osp_demo.cc @@ -16,7 +16,6 @@ #include "absl/strings/string_view.h" #include "osp/msgs/osp_messages.h" #include "osp/public/mdns_service_listener_factory.h" -#include "osp/public/mdns_service_publisher_factory.h" #include "osp/public/message_demuxer.h" #include "osp/public/network_service_manager.h" #include "osp/public/presentation/presentation_controller.h" @@ -27,6 +26,7 @@ #include "osp/public/protocol_connection_server_factory.h" #include "osp/public/service_listener.h" #include "osp/public/service_publisher.h" +#include "osp/public/service_publisher_factory.h" #include "platform/api/network_interface.h" #include "platform/api/time.h" #include "platform/impl/logging.h" @@ -152,7 +152,9 @@ class DemoPublisherObserver final : public ServicePublisher::Observer { void OnStopped() override { OSP_LOG_INFO << "publisher stopped!"; } void OnSuspended() override { OSP_LOG_INFO << "publisher suspended!"; } - void OnError(ServicePublisherError) override {} + void OnError(Error error) override { + OSP_LOG_ERROR << "publisher error: " << error; + } void OnMetrics(ServicePublisher::Metrics) override {} }; @@ -457,7 +459,10 @@ void HandleReceiverCommand(absl::string_view command, DemoReceiverDelegate& delegate, NetworkServiceManager* manager) { if (command == "avail") { - ServicePublisher* publisher = manager->GetMdnsServicePublisher(); + ServicePublisher* publisher = manager->GetServicePublisher(); + + OSP_LOG_INFO << "publisher->state() == " + << static_cast<int>(publisher->state()); if (publisher->state() == ServicePublisher::State::kSuspended) { publisher->Resume(); @@ -497,7 +502,7 @@ void RunReceiverPollLoop(pollfd& file_descriptor, void CleanupPublisherDemo(NetworkServiceManager* manager) { Receiver::Get()->SetReceiverDelegate(nullptr); Receiver::Get()->Deinit(); - manager->GetMdnsServicePublisher()->Stop(); + manager->GetServicePublisher()->Stop(); manager->GetProtocolConnectionServer()->Stop(); NetworkServiceManager::Dispose(); @@ -508,7 +513,6 @@ void PublisherDemo(absl::string_view friendly_name) { constexpr uint16_t server_port = 6667; - DemoPublisherObserver publisher_observer; // TODO(btolsch): aggregate initialization probably better? ServicePublisher::Config publisher_config; publisher_config.friendly_name = std::string(friendly_name); @@ -516,21 +520,23 @@ void PublisherDemo(absl::string_view friendly_name) { publisher_config.service_instance_name = "deadbeef"; publisher_config.connection_server_port = server_port; - auto mdns_publisher = MdnsServicePublisherFactory::Create( - publisher_config, &publisher_observer, - PlatformClientPosix::GetInstance()->GetTaskRunner()); - ServerConfig server_config; for (const InterfaceInfo& interface : GetNetworkInterfaces()) { OSP_VLOG << "Found interface: " << interface; if (!interface.addresses.empty()) { server_config.connection_endpoints.push_back( IPEndpoint{interface.addresses[0].address, server_port}); + publisher_config.network_interfaces.push_back(interface); } } OSP_LOG_IF(WARN, server_config.connection_endpoints.empty()) << "No network interfaces had usable addresses for mDNS publishing."; + DemoPublisherObserver publisher_observer; + auto service_publisher = ServicePublisherFactory::Create( + publisher_config, &publisher_observer, + PlatformClientPosix::GetInstance()->GetTaskRunner()); + MessageDemuxer demuxer(Clock::now, MessageDemuxer::kDefaultBufferLimit); DemoConnectionServerObserver server_observer; auto connection_server = ProtocolConnectionServerFactory::Create( @@ -538,13 +544,13 @@ void PublisherDemo(absl::string_view friendly_name) { PlatformClientPosix::GetInstance()->GetTaskRunner()); auto* network_service = - NetworkServiceManager::Create(nullptr, std::move(mdns_publisher), nullptr, - std::move(connection_server)); + NetworkServiceManager::Create(nullptr, std::move(service_publisher), + nullptr, std::move(connection_server)); DemoReceiverDelegate receiver_delegate; Receiver::Get()->Init(); Receiver::Get()->SetReceiverDelegate(&receiver_delegate); - network_service->GetMdnsServicePublisher()->Start(); + network_service->GetServicePublisher()->Start(); network_service->GetProtocolConnectionServer()->Start(); pollfd stdin_pollfd{STDIN_FILENO, POLLIN}; diff --git a/osp/go/README b/osp/go/README.md index f94ba080..96864ef8 100644 --- a/osp/go/README +++ b/osp/go/README.md @@ -1,7 +1,10 @@ -Run command line app: +To run the command line app: + +```bash $ go run cmd/osp.go server TV $ go run cmd/osp.go browse $ go run cmd/osp.go fling TV http://youtube.com -(may require apt-get install libwebkit2gtk-4.0 on linux) +``` +(may require `apt-get install libwebkit2gtk-4.0` on linux) diff --git a/osp/go/client.go b/osp/go/client.go index 61b52148..ead684b0 100644 --- a/osp/go/client.go +++ b/osp/go/client.go @@ -4,7 +4,7 @@ package osp -// TODO(pthatcher): +// TODO(jophba): // - Read messages as well, and more than one import ( diff --git a/osp/go/cmd/osp.go b/osp/go/cmd/osp.go index ad8d7f0d..f9c27bb9 100644 --- a/osp/go/cmd/osp.go +++ b/osp/go/cmd/osp.go @@ -4,7 +4,8 @@ package main -// TODO(pthatcher): Add response messages from receiver +// TODO(jophba): +// Add response messages from receiver // Inject JS into viewURL to using .Eval and .Bind to send and receiver presentation connection messages @@ -15,8 +16,9 @@ import ( "fmt" "log" - "osp" + "osp" + mdns "github.com/grandcat/zeroconf" "github.com/zserge/webview" ) @@ -30,7 +32,7 @@ func runServer(ctx context.Context, mdnsInstanceName string, port int) { func browseMdns(ctx context.Context) { entries, err := osp.BrowseMdns(ctx) - if (err != nil) { + if err != nil { log.Fatalf("Failed to browse mDNS: %v\n", err) } for entry := range entries { @@ -38,21 +40,35 @@ func browseMdns(ctx context.Context) { } } +func getMdnsHost(entry *mdns.ServiceEntry) string { + for _, ipv6 := range entry.AddrIPv6 { + log.Printf("Choosing IPv6 address [%s]\n", ipv6) + return fmt.Sprintf("[%s]", ipv6) + } + for _, ipv4 := range entry.AddrIPv4 { + log.Printf("Choosing IPv4 address %s\n", ipv4) + return fmt.Sprintf("%s", ipv4) + } + + // This shouldn't happen + log.Printf("No IP address found. Falling back to hostname %s\n", entry.HostName) + return entry.HostName +} + func flingUrl(ctx context.Context, target string, url string) { log.Printf("Search for %s\n", target) - entries, err := osp.BrowseMdns(ctx) - if (err != nil) { + entries, err := osp.LookupMdns(ctx, target) + if err != nil { log.Fatalf("Failed to browse mDNS: %v\n", err) } for entry := range entries { - if entry.Instance == target { - log.Printf("Fling %s to %s:%d\n", url, entry.HostName, entry.Port) - err := osp.StartPresentation(ctx, entry.HostName, entry.Port, url); - if err != nil { - log.Fatalln("Failed to start presentation."); - } - break + log.Printf("Fling %s to %s:%d\n", url, entry.HostName, entry.Port) + host := getMdnsHost(entry) + err := osp.StartPresentation(ctx, host, entry.Port, url) + if err != nil { + log.Fatalln("Failed to start presentation.") } + break } } @@ -91,10 +107,10 @@ func main() { log.Fatalln("Usage: osp server name") } mdnsInstanceName := args[1] - runServer(ctx, mdnsInstanceName, *port) + runServer(ctx, mdnsInstanceName, *port) case "browse": - browseMdns(ctx) + browseMdns(ctx) case "fling": if len(args) < 3 { @@ -103,7 +119,7 @@ func main() { target := args[1] url := args[2] - flingUrl(ctx, target, url) + flingUrl(ctx, target, url) case "view": if len(args) < 2 { diff --git a/osp/go/cmd/test.go b/osp/go/cmd/test.go index ad6142b8..ac068f42 100644 --- a/osp/go/cmd/test.go +++ b/osp/go/cmd/test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -// TODO(pthatcher): Use proper testing framework +// TODO(jophba): Use proper testing framework package main @@ -17,7 +17,7 @@ import ( ) func testMdns() { - // TODO(pthatcher): log error if it fails + // TODO(jophba): log error if it fails ctx := context.Background() instance := "TV" port := 10000 diff --git a/osp/go/controller.go b/osp/go/controller.go index 005ff52d..6fa25358 100644 --- a/osp/go/controller.go +++ b/osp/go/controller.go @@ -4,7 +4,7 @@ package osp -// TODO(pthatcher): +// TODO(jophba): // - Read and check the response message // - Make a nice object API with methods that can do more than one thing per connection // - Make it possible to have a presentation controller that is a server diff --git a/osp/go/go.mod b/osp/go/go.mod index 4928182e..a4c8fbfe 100644 --- a/osp/go/go.mod +++ b/osp/go/go.mod @@ -11,7 +11,7 @@ require ( github.com/lucas-clemente/quic-go-certificates v0.0.0-20160823095156-d2f86524cced // indirect github.com/miekg/dns v1.1.2 // indirect github.com/ugorji/go/codec v0.0.0-20181209151446-772ced7fd4c2 - github.com/zserge/webview v0.0.0-20181018084947-f390a2df9ec5 + github.com/zserge/webview v0.0.0-20200121135717-9c1b0a888aa4 golang.org/x/crypto v0.0.0-20190103213133-ff983b9c42bc // indirect golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e // indirect golang.org/x/sys v0.0.0-20190109145017-48ac38b7c8cb // indirect diff --git a/osp/go/go.sum b/osp/go/go.sum index 1a164b0a..4d12cc27 100644 --- a/osp/go/go.sum +++ b/osp/go/go.sum @@ -20,6 +20,8 @@ github.com/ugorji/go/codec v0.0.0-20181209151446-772ced7fd4c2 h1:EICbibRW4JNKMcY github.com/ugorji/go/codec v0.0.0-20181209151446-772ced7fd4c2/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/zserge/webview v0.0.0-20181018084947-f390a2df9ec5 h1:1zYVGLwZR4gPRQdEiOBf9s63ZHGfCkQ/p99d1zHuZBQ= github.com/zserge/webview v0.0.0-20181018084947-f390a2df9ec5/go.mod h1:a1CV8KR4Dd1eP2g+mEijGOp+HKczwdKHWyx0aPHKvo4= +github.com/zserge/webview v0.0.0-20200121135717-9c1b0a888aa4 h1:UjGpx0KjJegeVC/TZEL/dSCTUXajewpIA1NTF8snadg= +github.com/zserge/webview v0.0.0-20200121135717-9c1b0a888aa4/go.mod h1:a1CV8KR4Dd1eP2g+mEijGOp+HKczwdKHWyx0aPHKvo4= golang.org/x/crypto v0.0.0-20190103213133-ff983b9c42bc h1:F5tKCVGp+MUAHhKp5MZtGqAlGX3+oCsiL1Q629FL90M= golang.org/x/crypto v0.0.0-20190103213133-ff983b9c42bc/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= diff --git a/osp/go/mdns.go b/osp/go/mdns.go index 591ef6e7..7bda7ec6 100644 --- a/osp/go/mdns.go +++ b/osp/go/mdns.go @@ -4,9 +4,9 @@ package osp -// TODO(pthatcher): +// TODO(jophba): // - Make our own abstraction that has -// .InstanceName, .HostName, .MetadataVersion, .FingerPrint +// .InstanceName, .HostName, .MetadataVersion, .FingerPrint // rather than using mdns.ServiceEntry // - Advertise TXT (text below) with "fp" and "mv" @@ -21,8 +21,9 @@ const ( MdnsDomain = "local" ) -// Returns a channel of mDNS entries -// The critical parts are entry.Target (name) entry.HostName (address) +// Returns a channel of mDNS entries. The critical parts are +// entry.Target (service name) entry.HostName, entry.AddrIPv4, and +// entry.AddrIPv6. func BrowseMdns(ctx context.Context) (<-chan *mdns.ServiceEntry, error) { entries := make(chan *mdns.ServiceEntry) @@ -35,6 +36,20 @@ func BrowseMdns(ctx context.Context) (<-chan *mdns.ServiceEntry, error) { return entries, err } +// Returns a channel of mDNS entries. The critical parts are, +// entry.HostName, entry.AddrIPv4, and entry.AddrIPv6. +func LookupMdns(ctx context.Context, target string) (<-chan *mdns.ServiceEntry, error) { + entries := make(chan *mdns.ServiceEntry) + + resolver, err := mdns.NewResolver(nil) + if err != nil { + return entries, err + } + + err = resolver.Lookup(ctx, target, MdnsServiceType, MdnsDomain, entries) + return entries, err +} + func RunMdnsServer(ctx context.Context, instance string, port int) error { var text []string server, err := mdns.Register(instance, MdnsServiceType, MdnsDomain, port, text, nil /* ifaces */) diff --git a/osp/go/messages.go b/osp/go/messages.go index ffc85633..174ef300 100644 --- a/osp/go/messages.go +++ b/osp/go/messages.go @@ -4,7 +4,7 @@ package osp -// TODO(pthatcher): +// TODO(jophba): // - Read and write size prefixes import ( diff --git a/osp/go/quic.go b/osp/go/quic.go index 8fdff9aa..1ec2224c 100644 --- a/osp/go/quic.go +++ b/osp/go/quic.go @@ -4,7 +4,7 @@ package osp -// TODO(pthatcher): +// TODO(jophba): // - avoid NetworkIdleTimeout // - make a client object that can send and receive more than one stream // - make a server object that can send and receive more than one stream @@ -56,7 +56,7 @@ func readAllStreams(ctx context.Context, session quic.Session, streams chan<- io // Returns a quic.Session object with a .OpenStreamSync method to send streams func DialAsQuicClient(ctx context.Context, hostname string, port int) (quic.Session, error) { - // TODO(pthatcher): Change InsecureSkipVerify + // TODO(jophba): Change InsecureSkipVerify tlsConfig := &tls.Config{InsecureSkipVerify: true} addr := fmt.Sprintf("%s:%d", hostname, port) session, err := quic.DialAddrContext(ctx, addr, tlsConfig, nil) diff --git a/osp/go/receiver.go b/osp/go/receiver.go index de1f3727..ace8b606 100644 --- a/osp/go/receiver.go +++ b/osp/go/receiver.go @@ -4,7 +4,7 @@ package osp -// TODO(pthatcher): +// TODO(jophba): // - Send a response message // - Make a nice object API with methods // - Make it possible to have a presentation receiver that is a client diff --git a/osp/go/server.go b/osp/go/server.go index e7770576..8f325fd4 100644 --- a/osp/go/server.go +++ b/osp/go/server.go @@ -4,7 +4,7 @@ package osp -// TODO(pthatcher): +// TODO(jophba): // - Write messages as well import ( @@ -14,7 +14,7 @@ import ( ) func ReadMessagesAsServer(ctx context.Context, instanceName string, port int, cert tls.Certificate, messages chan<- interface{}) error { - // TODO(pthatcher): log error if it fails + // TODO(jophba): log error if it fails go RunMdnsServer(ctx, instanceName, port) streams := make(chan io.ReadWriteCloser) go RunQuicServer(ctx, port, cert, streams) diff --git a/osp/impl/BUILD.gn b/osp/impl/BUILD.gn index 83326f57..54404a61 100644 --- a/osp/impl/BUILD.gn +++ b/osp/impl/BUILD.gn @@ -6,8 +6,9 @@ import("../../osp/build/config/services.gni") source_set("impl") { sources = [ - "mdns_platform_service.cc", - "mdns_platform_service.h", + "dns_sd_publisher_client.cc", + "dns_sd_publisher_client.h", + "dns_sd_service_publisher_factory.cc", "message_demuxer.cc", "network_service_manager.cc", "presentation/presentation_common.cc", @@ -26,23 +27,13 @@ source_set("impl") { "with_destruction_callback.cc", "with_destruction_callback.h", ] - - if (use_mdns_responder) { - sources += [ - "internal_services.cc", - "internal_services.h", - "mdns_responder_service.cc", - "mdns_responder_service.h", - "mdns_service_listener_factory.cc", - "mdns_service_publisher_factory.cc", - ] - } - public_deps = [ "../msgs", "../public", ] deps = [ + "../../discovery:dnssd", + "../../discovery:public", "../../platform", "../../third_party/abseil", "../../util", diff --git a/osp/impl/DEPS b/osp/impl/DEPS index 63752366..cd004f79 100644 --- a/osp/impl/DEPS +++ b/osp/impl/DEPS @@ -1,7 +1,9 @@ -# Copyright (c) 2019 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. +# -*- Mode: Python; -*- include_rules = [ - '+osp/impl/discovery/mdns', + # Allowed to use discovery module. + '+discovery/public', + '+discovery/dnssd/public', + # Also necessary to implement discovery APIs. + '+discovery/common', ] diff --git a/osp/impl/discovery/mdns/BUILD.gn b/osp/impl/discovery/mdns/BUILD.gn deleted file mode 100644 index cd051d17..00000000 --- a/osp/impl/discovery/mdns/BUILD.gn +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2018 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -import("../../../build/config/services.gni") -assert(use_mdns_responder) - -source_set("mdns_interface") { - sources = [ - "domain_name.cc", - "domain_name.h", - "mdns_responder_adapter.cc", - "mdns_responder_adapter.h", - ] - - public_deps = [ - "../../../../platform", - "../../../../third_party/abseil", - "../../../../util", - ] -} - -source_set("unittests") { - testonly = true - - sources = [ - "domain_name_unittest.cc", - ] - - deps = [ - ":mdns_interface", - "../../../../third_party/googletest:gmock", - "../../../../third_party/googletest:gtest", - ] - - sources += [ "mdns_responder_adapter_impl_unittest.cc" ] - deps += [ ":mdns" ] -} - -executable("mdns_demo") { - sources = [ - "mdns_demo.cc", - ] - - deps = [ - ":mdns", - ] -} - -source_set("mdns") { - sources = [ - "mdns_responder_adapter_impl.cc", - "mdns_responder_adapter_impl.h", - "mdns_responder_platform.cc", - "mdns_responder_platform.h", - ] - - public_deps = [ - ":mdns_interface", - "../../../../platform", - "../../../../util", - ] - - deps = [ - "../../../../third_party/mDNSResponder:core", - ] -} diff --git a/osp/impl/discovery/mdns/DEPS b/osp/impl/discovery/mdns/DEPS deleted file mode 100644 index 96a7209c..00000000 --- a/osp/impl/discovery/mdns/DEPS +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright 2019 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -include_rules = [ - '+platform/impl', # Needed by embedder_demo.cc -] diff --git a/osp/impl/discovery/mdns/domain_name.cc b/osp/impl/discovery/mdns/domain_name.cc deleted file mode 100644 index c574793a..00000000 --- a/osp/impl/discovery/mdns/domain_name.cc +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/discovery/mdns/domain_name.h" - -#include <algorithm> -#include <iterator> - -#include "util/stringprintf.h" - -namespace openscreen { -namespace osp { - -// static -DomainName DomainName::GetLocalDomain() { - return DomainName{{5, 'l', 'o', 'c', 'a', 'l', 0}}; -} - -// static -ErrorOr<DomainName> DomainName::Append(const DomainName& first, - const DomainName& second) { - OSP_CHECK(first.domain_name_.size()); - OSP_CHECK(second.domain_name_.size()); - - // Both vectors should represent null terminated domain names. - OSP_DCHECK_EQ(first.domain_name_.back(), '\0'); - OSP_DCHECK_EQ(second.domain_name_.back(), '\0'); - if ((first.domain_name_.size() + second.domain_name_.size() - 1) > - kDomainNameMaxLength) { - return Error::Code::kDomainNameTooLong; - } - - DomainName result; - result.domain_name_.clear(); - result.domain_name_.insert(result.domain_name_.begin(), - first.domain_name_.begin(), - first.domain_name_.end()); - result.domain_name_.insert(result.domain_name_.end() - 1, - second.domain_name_.begin(), - second.domain_name_.end() - 1); - return result; -} - -DomainName::DomainName() : domain_name_{0u} {} -DomainName::DomainName(std::vector<uint8_t>&& domain_name) - : domain_name_(std::move(domain_name)) { - OSP_CHECK_LE(domain_name_.size(), kDomainNameMaxLength); -} -DomainName::DomainName(const DomainName&) = default; -DomainName::DomainName(DomainName&&) noexcept = default; -DomainName::~DomainName() = default; -DomainName& DomainName::operator=(const DomainName&) = default; -DomainName& DomainName::operator=(DomainName&&) noexcept = default; - -bool DomainName::operator==(const DomainName& other) const { - if (domain_name_.size() != other.domain_name_.size()) { - return false; - } - for (size_t i = 0; i < domain_name_.size(); ++i) { - if (tolower(domain_name_[i]) != tolower(other.domain_name_[i])) { - return false; - } - } - return true; -} - -bool DomainName::operator!=(const DomainName& other) const { - return !(*this == other); -} - -bool DomainName::EndsWithLocalDomain() const { - const DomainName local_domain = GetLocalDomain(); - if (domain_name_.size() < local_domain.domain_name_.size()) - return false; - - return std::equal(local_domain.domain_name_.begin(), - local_domain.domain_name_.end(), - domain_name_.end() - local_domain.domain_name_.size()); -} - -Error DomainName::Append(const DomainName& after) { - OSP_CHECK(after.domain_name_.size()); - OSP_DCHECK_EQ(after.domain_name_.back(), 0u); - - if ((domain_name_.size() + after.domain_name_.size() - 1) > - kDomainNameMaxLength) { - return Error::Code::kDomainNameTooLong; - } - - domain_name_.insert(domain_name_.end() - 1, after.domain_name_.begin(), - after.domain_name_.end() - 1); - return Error::None(); -} - -std::vector<absl::string_view> DomainName::GetLabels() const { - OSP_DCHECK_GT(domain_name_.size(), 0u); - OSP_DCHECK_LT(domain_name_.size(), kDomainNameMaxLength); - - std::vector<absl::string_view> result; - const uint8_t* data = domain_name_.data(); - while (*data != 0) { - const size_t label_length = *data; - OSP_DCHECK_LT(label_length, kDomainNameMaxLabelLength); - - ++data; - result.emplace_back(reinterpret_cast<const char*>(data), label_length); - data += label_length; - } - return result; -} - -bool DomainNameComparator::operator()(const DomainName& a, - const DomainName& b) const { - return a.domain_name() < b.domain_name(); -} - -std::ostream& operator<<(std::ostream& os, const DomainName& domain_name) { - const auto& data = domain_name.domain_name(); - OSP_DCHECK_GT(data.size(), 0u); - auto it = data.begin(); - while (*it != 0) { - size_t length = *it++; - PrettyPrintAsciiHex(os, it, it + length); - it += length; - os << "."; - } - return os; -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/discovery/mdns/domain_name.h b/osp/impl/discovery/mdns/domain_name.h deleted file mode 100644 index c29ef9db..00000000 --- a/osp/impl/discovery/mdns/domain_name.h +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef OSP_IMPL_DISCOVERY_MDNS_DOMAIN_NAME_H_ -#define OSP_IMPL_DISCOVERY_MDNS_DOMAIN_NAME_H_ - -#include <cstdint> -#include <ostream> -#include <string> -#include <vector> - -#include "absl/strings/string_view.h" -#include "platform/base/error.h" -#include "util/osp_logging.h" - -namespace openscreen { -namespace osp { - -struct DomainName { - static ErrorOr<DomainName> Append(const DomainName& first, - const DomainName& second); - - template <typename It> - static ErrorOr<DomainName> FromLabels(It first, It last) { - size_t total_length = 1; - for (auto label = first; label != last; ++label) { - if (label->size() > kDomainNameMaxLabelLength) - return Error::Code::kDomainNameLabelTooLong; - - total_length += label->size() + 1; - } - if (total_length > kDomainNameMaxLength) - return Error::Code::kDomainNameTooLong; - - DomainName result; - result.domain_name_.resize(total_length); - auto result_it = result.domain_name_.begin(); - for (auto label = first; label != last; ++label) { - *result_it++ = static_cast<uint8_t>(label->size()); - result_it = std::copy(label->begin(), label->end(), result_it); - } - *result_it = 0; - return std::move(result); - } - - static DomainName GetLocalDomain(); - - static constexpr uint8_t kDomainNameMaxLabelLength = 63u; - static constexpr uint16_t kDomainNameMaxLength = 256u; - - DomainName(); - explicit DomainName(std::vector<uint8_t>&& domain_name); - DomainName(const DomainName&); - DomainName(DomainName&&) noexcept; - ~DomainName(); - DomainName& operator=(const DomainName&); - DomainName& operator=(DomainName&&) noexcept; - - bool operator==(const DomainName& other) const; - bool operator!=(const DomainName& other) const; - - bool EndsWithLocalDomain() const; - bool IsEmpty() const { return domain_name_.size() == 1 && !domain_name_[0]; } - - Error Append(const DomainName& after); - std::vector<absl::string_view> GetLabels() const; - - const std::vector<uint8_t>& domain_name() const { return domain_name_; } - - private: - // RFC 1035 domain name format: sequence of 1 octet label length followed by - // label data, ending with a 0 octet. May not exceed 256 bytes (including - // terminating 0). - // For example, openscreen.org would be encoded as: - // {10, 'o', 'p', 'e', 'n', 's', 'c', 'r', 'e', 'e', 'n', - // 3, 'o', 'r', 'g', 0} - std::vector<uint8_t> domain_name_; -}; - -class DomainNameComparator { - public: - bool operator()(const DomainName& a, const DomainName& b) const; -}; - -std::ostream& operator<<(std::ostream& os, const DomainName& domain_name); - -} // namespace osp -} // namespace openscreen - -#endif // OSP_IMPL_DISCOVERY_MDNS_DOMAIN_NAME_H_ diff --git a/osp/impl/discovery/mdns/domain_name_unittest.cc b/osp/impl/discovery/mdns/domain_name_unittest.cc deleted file mode 100644 index 76f31003..00000000 --- a/osp/impl/discovery/mdns/domain_name_unittest.cc +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/discovery/mdns/domain_name.h" - -#include <sstream> - -#include "gtest/gtest.h" -#include "platform/base/error.h" - -namespace openscreen { -namespace osp { - -namespace { - -ErrorOr<DomainName> FromLabels(const std::vector<std::string>& labels) { - return DomainName::FromLabels(labels.begin(), labels.end()); -} - -template <typename T> -T UnpackErrorOr(ErrorOr<T> error_or) { - EXPECT_TRUE(error_or); - return std::move(error_or.value()); -} - -} // namespace - -TEST(DomainNameTest, Constructors) { - DomainName empty; - - ASSERT_EQ(1u, empty.domain_name().size()); - EXPECT_EQ(0, empty.domain_name()[0]); - - DomainName original({10, 'o', 'p', 'e', 'n', 's', 'c', 'r', 'e', 'e', 'n', 3, - 'o', 'r', 'g', 0}); - ASSERT_EQ(16u, original.domain_name().size()); - - auto data_copy = original.domain_name(); - DomainName direct_ctor(std::move(data_copy)); - EXPECT_EQ(direct_ctor.domain_name(), original.domain_name()); - - DomainName copy_ctor(original); - EXPECT_EQ(copy_ctor.domain_name(), original.domain_name()); - - DomainName move_ctor(std::move(copy_ctor)); - EXPECT_EQ(move_ctor.domain_name(), original.domain_name()); - - DomainName copy_assign; - copy_assign = move_ctor; - EXPECT_EQ(copy_assign.domain_name(), original.domain_name()); - - DomainName move_assign; - move_assign = std::move(move_ctor); - EXPECT_EQ(move_assign.domain_name(), original.domain_name()); -} - -TEST(DomainNameTest, FromLabels) { - const auto typical = - std::vector<uint8_t>{10, 'o', 'p', 'e', 'n', 's', 'c', 'r', - 'e', 'e', 'n', 3, 'o', 'r', 'g', 0}; - DomainName result = UnpackErrorOr(FromLabels({"openscreen", "org"})); - EXPECT_EQ(result.domain_name(), typical); - - const auto includes_dot = - std::vector<uint8_t>{11, 'o', 'p', 'e', 'n', '.', 's', 'c', 'r', - 'e', 'e', 'n', 3, 'o', 'r', 'g', 0}; - result = UnpackErrorOr(FromLabels({"open.screen", "org"})); - EXPECT_EQ(result.domain_name(), includes_dot); - - const auto includes_non_ascii = - std::vector<uint8_t>{11, 'o', 'p', 'e', 'n', 7, 's', 'c', 'r', - 'e', 'e', 'n', 3, 'o', 'r', 'g', 0}; - result = UnpackErrorOr(FromLabels({"open\7screen", "org"})); - EXPECT_EQ(result.domain_name(), includes_non_ascii); - - ASSERT_FALSE( - FromLabels({"extremely-long-label-that-is-actually-too-long-" - "for-rfc-1034-and-will-not-generate"})); - - ASSERT_FALSE(FromLabels({ - "extremely-long-domain-name-that-is-made-of", - "valid-labels", - "however-overall-it-is-too-long-for-rfc-1034", - "so-it-should-fail-to-generate", - "filler-filler-filler-filler-filler", - "filler-filler-filler-filler-filler", - "filler-filler-filler-filler-filler", - "filler-filler-filler-filler-filler", - })); -} - -TEST(DomainNameTest, Equality) { - DomainName alpha = UnpackErrorOr(FromLabels({"alpha", "openscreen", "org"})); - DomainName beta = UnpackErrorOr(FromLabels({"beta", "openscreen", "org"})); - - const DomainName alpha_copy = alpha; - - EXPECT_TRUE(alpha == alpha); - EXPECT_FALSE(alpha != alpha); - EXPECT_TRUE(alpha == alpha_copy); - EXPECT_FALSE(alpha != alpha_copy); - EXPECT_FALSE(alpha == beta); - EXPECT_TRUE(alpha != beta); -} - -TEST(DomainNameTest, EndsWithLocalDomain) { - DomainName alpha; - EXPECT_FALSE(alpha.EndsWithLocalDomain()); - - alpha = UnpackErrorOr(FromLabels({"alpha", "openscreen", "org"})); - DomainName beta = UnpackErrorOr(FromLabels({"beta", "local"})); - - EXPECT_FALSE(alpha.EndsWithLocalDomain()); - EXPECT_TRUE(beta.EndsWithLocalDomain()); -} - -TEST(DomainNameTest, IsEmpty) { - DomainName alpha; - DomainName beta(std::vector<uint8_t>{0}); - - EXPECT_TRUE(alpha.IsEmpty()); - EXPECT_TRUE(beta.IsEmpty()); - - alpha = UnpackErrorOr(FromLabels({"alpha", "openscreen", "org"})); - EXPECT_FALSE(alpha.IsEmpty()); -} - -TEST(DomainNameTest, Append) { - const auto expected_service_name = - std::vector<uint8_t>{5, 'a', 'l', 'p', 'h', 'a', '\0'}; - const auto expected_service_type_initial = std::vector<uint8_t>{ - 11, '_', 'o', 'p', 'e', 'n', 's', 'c', 'r', 'e', 'e', 'n', '\0'}; - const auto expected_protocol = - std::vector<uint8_t>{5, '_', 'q', 'u', 'i', 'c', '\0'}; - const auto expected_service_type = - std::vector<uint8_t>{11, '_', 'o', 'p', 'e', 'n', 's', 'c', 'r', 'e', - 'e', 'n', 5, '_', 'q', 'u', 'i', 'c', '\0'}; - const auto total_expected = std::vector<uint8_t>{ - 5, 'a', 'l', 'p', 'h', 'a', 11, '_', 'o', 'p', 'e', 'n', 's', - 'c', 'r', 'e', 'e', 'n', 5, '_', 'q', 'u', 'i', 'c', '\0'}; - - DomainName service_name = UnpackErrorOr(FromLabels({"alpha"})); - EXPECT_EQ(service_name.domain_name(), expected_service_name); - - DomainName service_type = UnpackErrorOr(FromLabels({"_openscreen"})); - EXPECT_EQ(service_type.domain_name(), expected_service_type_initial); - - DomainName protocol = UnpackErrorOr(FromLabels({"_quic"})); - EXPECT_EQ(protocol.domain_name(), expected_protocol); - - EXPECT_TRUE(service_type.Append(protocol).ok()); - EXPECT_EQ(service_type.domain_name(), expected_service_type); - - DomainName result = - UnpackErrorOr(DomainName::Append(service_name, service_type)); - EXPECT_EQ(result.domain_name(), total_expected); -} - -TEST(DomainNameTest, GetLabels) { - const auto labels = std::vector<std::string>{"alpha", "beta", "gamma", "org"}; - DomainName domain_name = UnpackErrorOr(FromLabels(labels)); - - const auto actual_labels = domain_name.GetLabels(); - for (size_t i = 0; i < labels.size(); ++i) { - EXPECT_EQ(labels[i], actual_labels[i]); - } -} - -TEST(DomainNameTest, StreamEscaping) { - { - std::stringstream ss; - ss << DomainName(std::vector<uint8_t>{1, 0, 0}); - EXPECT_EQ(ss.str(), "\\x00."); - } - { - std::stringstream ss; - ss << DomainName(std::vector<uint8_t>{1, 1, 0}); - EXPECT_EQ(ss.str(), "\\x01."); - } - { - std::stringstream ss; - ss << DomainName(std::vector<uint8_t>{1, 18, 0}); - EXPECT_EQ(ss.str(), "\\x12."); - } - { - std::stringstream ss; - ss << DomainName(std::vector<uint8_t>{1, 255, 0}); - EXPECT_EQ(ss.str(), "\\xff."); - } -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/discovery/mdns/mdns_demo.cc b/osp/impl/discovery/mdns/mdns_demo.cc deleted file mode 100644 index 1fc1513e..00000000 --- a/osp/impl/discovery/mdns/mdns_demo.cc +++ /dev/null @@ -1,374 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include <signal.h> -#include <unistd.h> - -#include <algorithm> -#include <map> -#include <memory> -#include <vector> - -// TODO(rwkeane): Remove references to platform/impl -#include "osp/impl/discovery/mdns/mdns_responder_adapter_impl.h" -#include "platform/api/network_interface.h" -#include "platform/api/time.h" -#include "platform/base/error.h" -#include "platform/impl/logging.h" -#include "platform/impl/platform_client_posix.h" -#include "platform/impl/task_runner.h" -#include "platform/impl/udp_socket_reader_posix.h" - -// This file contains a demo of our mDNSResponder wrapper code. It can both -// listen for mDNS services and advertise an mDNS service. The command-line -// usage is: -// mdns_demo [service_type] [service_instance_name] -// service_type defaults to '_openscreen._udp' and service_instance_name -// defaults to ''. service_type determines services the program listens for and -// when service_instance_name is not empty, a service of -// 'service_instance_name.service_type' is also advertised. -// -// The program will print a list of discovered services when it receives a USR1 -// or INT signal. The pid is printed at the beginning of the program to -// facilitate this. -// -// There are a few known bugs around the handling of record events, so this -// shouldn't be expected to be a source of truth, nor should it be expected to -// be correct after running for a long time. - -namespace openscreen { -namespace osp { -namespace { - -bool g_done = false; -bool g_dump_services = false; - -struct Service { - explicit Service(DomainName service_instance) - : service_instance(std::move(service_instance)) {} - ~Service() = default; - - DomainName service_instance; - DomainName domain_name; - IPAddress address; - uint16_t port; - std::vector<std::string> txt; -}; - -class DemoSocketClient : public UdpSocket::Client { - public: - explicit DemoSocketClient(MdnsResponderAdapterImpl* mdns) : mdns_(mdns) {} - - void OnError(UdpSocket* socket, Error error) override { - // TODO(crbug.com/openscreen/66): Change to OSP_LOG_FATAL. - OSP_LOG_ERROR << "configuration failed for interface " << error.message(); - OSP_CHECK(false); - } - - void OnSendError(UdpSocket* socket, Error error) override { - OSP_UNIMPLEMENTED(); - } - - void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override { - mdns_->OnRead(socket, std::move(packet)); - } - - private: - MdnsResponderAdapterImpl* mdns_; -}; - -using ServiceMap = std::map<DomainName, Service, DomainNameComparator>; -ServiceMap* g_services = nullptr; - -void sigusr1_dump_services(int) { - g_dump_services = true; -} - -void sigint_stop(int) { - OSP_LOG_INFO << "caught SIGINT, exiting..."; - g_done = true; -} - -std::vector<std::string> SplitByDot(const std::string& domain_part) { - std::vector<std::string> result; - auto copy_it = domain_part.begin(); - for (auto it = domain_part.begin(); it != domain_part.end(); ++it) { - if (*it == '.') { - result.emplace_back(copy_it, it); - copy_it = it + 1; - } - } - if (copy_it != domain_part.end()) - result.emplace_back(copy_it, domain_part.end()); - - return result; -} - -void SignalThings() { - struct sigaction usr1_sa; - struct sigaction int_sa; - struct sigaction unused; - - usr1_sa.sa_handler = &sigusr1_dump_services; - sigemptyset(&usr1_sa.sa_mask); - usr1_sa.sa_flags = 0; - - int_sa.sa_handler = &sigint_stop; - sigemptyset(&int_sa.sa_mask); - int_sa.sa_flags = 0; - - sigaction(SIGUSR1, &usr1_sa, &unused); - sigaction(SIGINT, &int_sa, &unused); - - OSP_LOG_INFO << "signal handlers setup" << std::endl << "pid: " << getpid(); -} - -std::vector<std::unique_ptr<UdpSocket>> SetUpMulticastSockets( - TaskRunner* task_runner, - const std::vector<NetworkInterfaceIndex>& index_list, - UdpSocket::Client* client) { - std::vector<std::unique_ptr<UdpSocket>> sockets; - for (const auto ifindex : index_list) { - auto create_result = - UdpSocket::Create(task_runner, client, IPEndpoint{{}, 5353}); - if (!create_result) { - OSP_LOG_ERROR << "failed to create IPv4 socket for interface " << ifindex - << ": " << create_result.error().message(); - continue; - } - std::unique_ptr<UdpSocket> socket = std::move(create_result.value()); - - socket->JoinMulticastGroup(IPAddress{224, 0, 0, 251}, ifindex); - socket->SetMulticastOutboundInterface(ifindex); - socket->Bind(); - - OSP_LOG_INFO << "listening on interface " << ifindex; - sockets.emplace_back(std::move(socket)); - } - return sockets; -} - -void LogService(const Service& s) { - OSP_LOG_INFO << "PTR: (" << s.service_instance << ")" << std::endl - << "SRV: " << s.domain_name << ":" << s.port << std::endl - << "TXT:"; - - for (const auto& l : s.txt) { - OSP_LOG_INFO << " | " << l; - } - OSP_LOG_INFO << "A: " << s.address; -} - -void HandleEvents(MdnsResponderAdapterImpl* mdns_adapter) { - for (auto& ptr_event : mdns_adapter->TakePtrResponses()) { - auto it = g_services->find(ptr_event.service_instance); - switch (ptr_event.header.response_type) { - case QueryEventHeader::Type::kAdded: - case QueryEventHeader::Type::kAddedNoCache: - mdns_adapter->StartSrvQuery(ptr_event.header.socket, - ptr_event.service_instance); - mdns_adapter->StartTxtQuery(ptr_event.header.socket, - ptr_event.service_instance); - if (it == g_services->end()) { - g_services->emplace(ptr_event.service_instance, - Service(ptr_event.service_instance)); - } - break; - case QueryEventHeader::Type::kRemoved: - // PTR may be removed and added without updating related entries (SRV - // and friends) so this simple logic is actually broken, but I don't - // want to do a better design or pointer hell for just a demo. - OSP_LOG_WARN << "ptr-remove: " << ptr_event.service_instance; - if (it != g_services->end()) - g_services->erase(it); - - break; - } - } - for (auto& srv_event : mdns_adapter->TakeSrvResponses()) { - auto it = g_services->find(srv_event.service_instance); - if (it == g_services->end()) - continue; - - switch (srv_event.header.response_type) { - case QueryEventHeader::Type::kAdded: - case QueryEventHeader::Type::kAddedNoCache: - mdns_adapter->StartAQuery(srv_event.header.socket, - srv_event.domain_name); - it->second.domain_name = std::move(srv_event.domain_name); - it->second.port = srv_event.port; - break; - case QueryEventHeader::Type::kRemoved: - OSP_LOG_WARN << "srv-remove: " << srv_event.service_instance; - it->second.domain_name = DomainName(); - it->second.port = 0; - break; - } - } - for (auto& txt_event : mdns_adapter->TakeTxtResponses()) { - auto it = g_services->find(txt_event.service_instance); - if (it == g_services->end()) - continue; - - switch (txt_event.header.response_type) { - case QueryEventHeader::Type::kAdded: - case QueryEventHeader::Type::kAddedNoCache: - it->second.txt = std::move(txt_event.txt_info); - break; - case QueryEventHeader::Type::kRemoved: - OSP_LOG_WARN << "txt-remove: " << txt_event.service_instance; - it->second.txt.clear(); - break; - } - } - for (const auto& a_event : mdns_adapter->TakeAResponses()) { - // TODO(btolsch): If multiple SRV records specify the same domain, the A - // will only update the first. I didn't think this would happen but I - // noticed this happens for cast groups. - auto it = std::find_if(g_services->begin(), g_services->end(), - [&a_event](const std::pair<DomainName, Service>& s) { - return s.second.domain_name == a_event.domain_name; - }); - if (it == g_services->end()) - continue; - - switch (a_event.header.response_type) { - case QueryEventHeader::Type::kAdded: - case QueryEventHeader::Type::kAddedNoCache: - it->second.address = a_event.address; - break; - case QueryEventHeader::Type::kRemoved: - OSP_LOG_WARN << "a-remove: " << a_event.domain_name; - it->second.address = IPAddress(0, 0, 0, 0); - break; - } - } -} - -void BrowseDemo(TaskRunner* task_runner, - const std::string& service_name, - const std::string& service_protocol, - const std::string& service_instance) { - SignalThings(); - - std::vector<std::string> labels{service_name, service_protocol}; - ErrorOr<DomainName> service_type = - DomainName::FromLabels(labels.begin(), labels.end()); - - if (!service_type) { - OSP_LOG_ERROR << "bad domain labels: " << service_name << ", " - << service_protocol; - return; - } - - auto mdns_adapter = std::make_unique<MdnsResponderAdapterImpl>(); - mdns_adapter->Init(); - mdns_adapter->SetHostLabel("gigliorononomicon"); - const std::vector<InterfaceInfo> interfaces = GetNetworkInterfaces(); - std::vector<NetworkInterfaceIndex> index_list; - for (const auto& interface : interfaces) { - OSP_LOG_INFO << "Found interface: " << interface; - if (!interface.addresses.empty()) { - index_list.push_back(interface.index); - } - } - OSP_LOG_IF(WARN, index_list.empty()) - << "No network interfaces had usable addresses for mDNS."; - - DemoSocketClient client(mdns_adapter.get()); - auto sockets = SetUpMulticastSockets(task_runner, index_list, &client); - // The code below assumes the elements in |sockets| is in exact 1:1 - // correspondence with the elements in |index_list|. Crash the demo if any - // sockets are missing (i.e., failed to be set up). - OSP_CHECK_EQ(sockets.size(), index_list.size()); - - // Listen on all interfaces. - auto socket_it = sockets.begin(); - for (NetworkInterfaceIndex index : index_list) { - const auto& interface = - *std::find_if(interfaces.begin(), interfaces.end(), - [index](const openscreen::InterfaceInfo& info) { - return info.index == index; - }); - // Pick any address for the given interface. - mdns_adapter->RegisterInterface(interface, interface.addresses.front(), - socket_it->get()); - ++socket_it; - } - - if (!service_instance.empty()) { - mdns_adapter->RegisterService(service_instance, service_name, - service_protocol, DomainName(), 12345, - {{"k1", "yurtle"}, {"k2", "turtle"}}); - } - - for (const std::unique_ptr<UdpSocket>& socket : sockets) { - mdns_adapter->StartPtrQuery(socket.get(), service_type.value()); - } - - while (!g_done) { - HandleEvents(mdns_adapter.get()); - if (g_dump_services) { - OSP_LOG_INFO << "num services: " << g_services->size(); - for (const auto& s : *g_services) { - LogService(s.second); - } - if (!service_instance.empty()) { - mdns_adapter->UpdateTxtData( - service_instance, service_name, service_protocol, - {{"k1", "oogley"}, {"k2", "moogley"}, {"k3", "googley"}}); - } - g_dump_services = false; - } - mdns_adapter->RunTasks(); - } - OSP_LOG_INFO << "num services: " << g_services->size(); - for (const auto& s : *g_services) { - LogService(s.second); - } - for (const std::unique_ptr<UdpSocket>& socket : sockets) { - mdns_adapter->DeregisterInterface(socket.get()); - } - mdns_adapter->Close(); -} - -} // namespace -} // namespace osp -} // namespace openscreen - -int main(int argc, char** argv) { - using openscreen::Clock; - using openscreen::PlatformClientPosix; - - openscreen::SetLogLevel(openscreen::LogLevel::kVerbose); - - std::string service_instance; - std::string service_type("_openscreen._udp"); - if (argc >= 2) - service_type = argv[1]; - - if (argc >= 3) - service_instance = argv[2]; - - if (service_type.size() && service_type[0] == '.') - return 1; - - auto labels = openscreen::osp::SplitByDot(service_type); - if (labels.size() != 2) - return 1; - - openscreen::osp::ServiceMap services; - openscreen::osp::g_services = &services; - - PlatformClientPosix::Create(std::chrono::milliseconds(50)); - - openscreen::osp::BrowseDemo( - PlatformClientPosix::GetInstance()->GetTaskRunner(), labels[0], labels[1], - service_instance); - - PlatformClientPosix::ShutDown(); - - openscreen::osp::g_services = nullptr; - return 0; -} diff --git a/osp/impl/discovery/mdns/mdns_responder_adapter.cc b/osp/impl/discovery/mdns/mdns_responder_adapter.cc deleted file mode 100644 index edf25503..00000000 --- a/osp/impl/discovery/mdns/mdns_responder_adapter.cc +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/discovery/mdns/mdns_responder_adapter.h" - -namespace openscreen { -namespace osp { - -QueryEventHeader::QueryEventHeader() = default; -QueryEventHeader::QueryEventHeader(QueryEventHeader::Type response_type, - UdpSocket* socket) - : response_type(response_type), socket(socket) {} -QueryEventHeader::QueryEventHeader(QueryEventHeader&&) noexcept = default; -QueryEventHeader::~QueryEventHeader() = default; -QueryEventHeader& QueryEventHeader::operator=(QueryEventHeader&&) noexcept = - default; - -AEvent::AEvent() = default; -AEvent::AEvent(QueryEventHeader header, - DomainName domain_name, - IPAddress address) - : header(std::move(header)), - domain_name(std::move(domain_name)), - address(std::move(address)) {} -AEvent::AEvent(AEvent&&) noexcept = default; -AEvent::~AEvent() = default; -AEvent& AEvent::operator=(AEvent&&) noexcept = default; - -AaaaEvent::AaaaEvent() = default; -AaaaEvent::AaaaEvent(QueryEventHeader header, - DomainName domain_name, - IPAddress address) - : header(std::move(header)), - domain_name(std::move(domain_name)), - address(std::move(address)) {} -AaaaEvent::AaaaEvent(AaaaEvent&&) noexcept = default; -AaaaEvent::~AaaaEvent() = default; -AaaaEvent& AaaaEvent::operator=(AaaaEvent&&) noexcept = default; - -PtrEvent::PtrEvent() = default; -PtrEvent::PtrEvent(QueryEventHeader header, DomainName service_instance) - : header(std::move(header)), - service_instance(std::move(service_instance)) {} -PtrEvent::PtrEvent(PtrEvent&&) noexcept = default; -PtrEvent::~PtrEvent() = default; -PtrEvent& PtrEvent::operator=(PtrEvent&&) noexcept = default; - -SrvEvent::SrvEvent() = default; -SrvEvent::SrvEvent(QueryEventHeader header, - DomainName service_instance, - DomainName domain_name, - uint16_t port) - : header(std::move(header)), - service_instance(std::move(service_instance)), - domain_name(std::move(domain_name)), - port(port) {} -SrvEvent::SrvEvent(SrvEvent&&) noexcept = default; -SrvEvent::~SrvEvent() = default; -SrvEvent& SrvEvent::operator=(SrvEvent&&) noexcept = default; - -TxtEvent::TxtEvent() = default; -TxtEvent::TxtEvent(QueryEventHeader header, - DomainName service_instance, - std::vector<std::string> txt_info) - : header(std::move(header)), - service_instance(std::move(service_instance)), - txt_info(std::move(txt_info)) {} -TxtEvent::TxtEvent(TxtEvent&&) noexcept = default; -TxtEvent::~TxtEvent() = default; -TxtEvent& TxtEvent::operator=(TxtEvent&&) noexcept = default; - -MdnsResponderAdapter::MdnsResponderAdapter() = default; -MdnsResponderAdapter::~MdnsResponderAdapter() = default; - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/discovery/mdns/mdns_responder_adapter.h b/osp/impl/discovery/mdns/mdns_responder_adapter.h deleted file mode 100644 index 66083d57..00000000 --- a/osp/impl/discovery/mdns/mdns_responder_adapter.h +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_ADAPTER_H_ -#define OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_ADAPTER_H_ - -#include <cstdint> -#include <map> -#include <string> -#include <vector> - -#include "osp/impl/discovery/mdns/domain_name.h" -#include "osp/impl/discovery/mdns/mdns_responder_platform.h" -#include "platform/api/network_interface.h" -#include "platform/api/time.h" -#include "platform/api/udp_socket.h" -#include "platform/base/error.h" -#include "platform/base/ip_address.h" - -namespace openscreen { -namespace osp { - -struct QueryEventHeader { - enum class Type { - kAdded = 0, - kAddedNoCache, - kRemoved, - }; - - QueryEventHeader(); - QueryEventHeader(Type response_type, UdpSocket* socket); - QueryEventHeader(QueryEventHeader&&) noexcept; - ~QueryEventHeader(); - QueryEventHeader& operator=(QueryEventHeader&&) noexcept; - - Type response_type; - UdpSocket* socket; -}; - -struct PtrEvent { - PtrEvent(); - PtrEvent(QueryEventHeader header, DomainName service_instance); - PtrEvent(PtrEvent&&) noexcept; - ~PtrEvent(); - PtrEvent& operator=(PtrEvent&&) noexcept; - - QueryEventHeader header; - DomainName service_instance; -}; - -struct SrvEvent { - SrvEvent(); - SrvEvent(QueryEventHeader header, - DomainName service_instance, - DomainName domain_name, - uint16_t port); - SrvEvent(SrvEvent&&) noexcept; - ~SrvEvent(); - SrvEvent& operator=(SrvEvent&&) noexcept; - - QueryEventHeader header; - DomainName service_instance; - DomainName domain_name; - uint16_t port; -}; - -struct TxtEvent { - TxtEvent(); - TxtEvent(QueryEventHeader header, - DomainName service_instance, - std::vector<std::string> txt_info); - TxtEvent(TxtEvent&&) noexcept; - ~TxtEvent(); - TxtEvent& operator=(TxtEvent&&) noexcept; - - QueryEventHeader header; - DomainName service_instance; - - // NOTE: mDNS does not specify a character encoding for the data in TXT - // records. - std::vector<std::string> txt_info; -}; - -struct AEvent { - AEvent(); - AEvent(QueryEventHeader header, DomainName domain_name, IPAddress address); - AEvent(AEvent&&) noexcept; - ~AEvent(); - AEvent& operator=(AEvent&&) noexcept; - - QueryEventHeader header; - DomainName domain_name; - IPAddress address; -}; - -struct AaaaEvent { - AaaaEvent(); - AaaaEvent(QueryEventHeader header, DomainName domain_name, IPAddress address); - AaaaEvent(AaaaEvent&&) noexcept; - ~AaaaEvent(); - AaaaEvent& operator=(AaaaEvent&&) noexcept; - - QueryEventHeader header; - DomainName domain_name; - IPAddress address; -}; - -enum class MdnsResponderErrorCode { - kNoError = 0, - kUnsupportedError, - kDomainOverflowError, - kInvalidParameters, - kUnknownError, -}; - -// This interface wraps all the functionality of mDNSResponder, which includes -// both listening and publishing. As a result, some methods are only used by -// listeners, some are only used by publishers, and some are used by both. -// -// Listening for records might look like this: -// adapter->Init(); -// -// // Once for each interface, the meaning of false is described below. -// adapter->RegisterInterface(..., false); -// -// adapter->StartPtrQuery("_openscreen._udp"); -// adapter->RunTasks(); -// -// // When receiving multicast UDP traffic from port 5353. -// adapter->OnDataReceived(...); -// adapter->RunTasks(); -// -// // Check |ptrs| for responses after pulling. -// auto ptrs = adapter->TakePtrResponses(); -// -// // Eventually... -// adapter->StopPtrQuery("_openscreen._udp"); -// -// Publishing a service might look like this: -// adapter->Init(); -// -// // Once for each interface, the meaning of true is described below. -// adapter->RegisterInterface(..., true); -// -// adapter->SetHostLabel("deadbeef"); -// adapter->RegisterService("living-room", "_openscreen._udp", ...); -// adapter->RunTasks(); -// -// // When receiving multicast UDP traffic from port 5353. -// adapter->OnDataReceived(...); -// adapter->RunTasks(); -// -// // Eventually... -// adapter->DeregisterService("living-room", "_openscreen", "_udp"); -// -// Additionally, it's important to understand that mDNSResponder may defer some -// tasks (e.g. parsing responses, sending queries, etc.) and those deferred -// tasks are only run when RunTasks is called. Therefore, RunTasks should be -// called after any sequence of calls to mDNSResponder. It also returns a -// timeout value, after which it must be called again (e.g. for maintaining its -// cache). -class MdnsResponderAdapter : public UdpSocket::Client { - public: - MdnsResponderAdapter(); - virtual ~MdnsResponderAdapter() = 0; - - // Initializes mDNSResponder. This should be called before any queries or - // service registrations are made. - virtual Error Init() = 0; - - // Stops all open queries and service registrations. If this is not called - // before destruction, any registered services will not send their goodbye - // messages. - virtual void Close() = 0; - - // Called to change the name published by the A and AAAA records for the host - // when any service is active (via RegisterService). Returns true if the - // label was set successfully, false otherwise (e.g. the label did not meet - // DNS name requirements). - virtual Error SetHostLabel(const std::string& host_label) = 0; - - // The following methods register and deregister a network interface with - // mDNSResponder. |socket| will be used to identify which interface received - // the data in OnDataReceived and will be used to send data via the platform - // layer. - virtual Error RegisterInterface(const InterfaceInfo& interface_info, - const IPSubnet& interface_address, - UdpSocket* socket) = 0; - virtual Error DeregisterInterface(UdpSocket* socket) = 0; - - // Returns the time period after which this method must be called again, if - // any. - virtual Clock::duration RunTasks() = 0; - - virtual std::vector<PtrEvent> TakePtrResponses() = 0; - virtual std::vector<SrvEvent> TakeSrvResponses() = 0; - virtual std::vector<TxtEvent> TakeTxtResponses() = 0; - virtual std::vector<AEvent> TakeAResponses() = 0; - virtual std::vector<AaaaEvent> TakeAaaaResponses() = 0; - - virtual MdnsResponderErrorCode StartPtrQuery( - UdpSocket* socket, - const DomainName& service_type) = 0; - virtual MdnsResponderErrorCode StartSrvQuery( - UdpSocket* socket, - const DomainName& service_instance) = 0; - virtual MdnsResponderErrorCode StartTxtQuery( - UdpSocket* socket, - const DomainName& service_instance) = 0; - virtual MdnsResponderErrorCode StartAQuery(UdpSocket* socket, - const DomainName& domain_name) = 0; - virtual MdnsResponderErrorCode StartAaaaQuery( - UdpSocket* socket, - const DomainName& domain_name) = 0; - - virtual MdnsResponderErrorCode StopPtrQuery( - UdpSocket* socket, - const DomainName& service_type) = 0; - virtual MdnsResponderErrorCode StopSrvQuery( - UdpSocket* socket, - const DomainName& service_instance) = 0; - virtual MdnsResponderErrorCode StopTxtQuery( - UdpSocket* socket, - const DomainName& service_instance) = 0; - virtual MdnsResponderErrorCode StopAQuery(UdpSocket* socket, - const DomainName& domain_name) = 0; - virtual MdnsResponderErrorCode StopAaaaQuery( - UdpSocket* socket, - const DomainName& domain_name) = 0; - - // The following methods concern advertising a service via mDNS. The - // arguments correspond to values needed in the PTR, SRV, and TXT records that - // will be published for the service. An A or AAAA record will also be - // published with the service for each active interface known to mDNSResponder - // via RegisterInterface. - virtual MdnsResponderErrorCode RegisterService( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const DomainName& target_host, - uint16_t target_port, - const std::map<std::string, std::string>& txt_data) = 0; - virtual MdnsResponderErrorCode DeregisterService( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol) = 0; - virtual MdnsResponderErrorCode UpdateTxtData( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const std::map<std::string, std::string>& txt_data) = 0; -}; - -} // namespace osp -} // namespace openscreen - -#endif // OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_ADAPTER_H_ diff --git a/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc b/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc deleted file mode 100644 index 205e125b..00000000 --- a/osp/impl/discovery/mdns/mdns_responder_adapter_impl.cc +++ /dev/null @@ -1,1044 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/discovery/mdns/mdns_responder_adapter_impl.h" - -#include <algorithm> -#include <cctype> -#include <cstring> -#include <iostream> -#include <memory> -#include <string> -#include <utility> - -#include "util/osp_logging.h" -#include "util/trace_logging.h" - -namespace openscreen { -namespace osp { -namespace { - -// RFC 1035 specifies a max string length of 256, including the leading length -// octet. -constexpr size_t kMaxDnsStringLength = 255; - -// RFC 6763 recommends a maximum key length of 9 characters. -constexpr size_t kMaxTxtKeyLength = 9; - -constexpr size_t kMaxStaticTxtDataSize = 256; - -static_assert(sizeof(std::declval<RData>().u.txt) == kMaxStaticTxtDataSize, - "mDNSResponder static TXT data size expected to be 256 bytes"); - -static_assert(sizeof(mDNSAddr::ip.v4.b) == 4u, - "mDNSResponder IPv4 address must be 4 bytes"); -static_assert(sizeof(mDNSAddr::ip.v6.b) == 16u, - "mDNSResponder IPv6 address must be 16 bytes"); - -void AssignMdnsPort(mDNSIPPort* mdns_port, uint16_t port) { - mdns_port->b[0] = (port >> 8) & 0xff; - mdns_port->b[1] = port & 0xff; -} - -uint16_t GetNetworkOrderPort(const mDNSOpaque16& port) { - return port.b[0] << 8 | port.b[1]; -} - -bool IsValidServiceName(const std::string& service_name) { - // Service name requirements come from RFC 6335: - // - No more than 16 characters. - // - Begin with '_'. - // - Next is a letter or digit and end with a letter or digit. - // - May contain hyphens, but no consecutive hyphens. - // - Must contain at least one letter. - if (service_name.size() <= 1 || service_name.size() > 16) - return false; - - if (service_name[0] != '_' || !std::isalnum(service_name[1]) || - !std::isalnum(service_name.back())) { - return false; - } - bool has_alpha = false; - bool previous_hyphen = false; - for (auto it = service_name.begin() + 1; it != service_name.end(); ++it) { - if (*it == '-' && previous_hyphen) - return false; - - previous_hyphen = *it == '-'; - has_alpha = has_alpha || std::isalpha(*it); - } - return has_alpha && !previous_hyphen; -} - -bool IsValidServiceProtocol(const std::string& protocol) { - // RFC 6763 requires _tcp be used for TCP services and _udp for all others. - return protocol == "_tcp" || protocol == "_udp"; -} - -void MakeLocalServiceNameParts(const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - domainlabel* instance, - domainlabel* name, - domainlabel* protocol, - domainname* type, - domainname* domain) { - MakeDomainLabelFromLiteralString(instance, service_instance.c_str()); - MakeDomainLabelFromLiteralString(name, service_name.c_str()); - MakeDomainLabelFromLiteralString(protocol, service_protocol.c_str()); - type->c[0] = 0; - AppendDomainLabel(type, name); - AppendDomainLabel(type, protocol); - const DomainName local_domain = DomainName::GetLocalDomain(); - std::copy(local_domain.domain_name().begin(), - local_domain.domain_name().end(), domain->c); -} - -void MakeSubnetMaskFromPrefixLengthV4(uint8_t mask[4], uint8_t prefix_length) { - for (int i = 0; i < 4; prefix_length -= 8, ++i) { - if (prefix_length >= 8) { - mask[i] = 0xff; - } else if (prefix_length > 0) { - mask[i] = 0xff << (8 - prefix_length); - } else { - mask[i] = 0; - } - } -} - -void MakeSubnetMaskFromPrefixLengthV6(uint8_t mask[16], uint8_t prefix_length) { - for (int i = 0; i < 16; prefix_length -= 8, ++i) { - if (prefix_length >= 8) { - mask[i] = 0xff; - } else if (prefix_length > 0) { - mask[i] = 0xff << (8 - prefix_length); - } else { - mask[i] = 0; - } - } -} - -bool IsValidTxtDataKey(const std::string& s) { - if (s.size() > kMaxTxtKeyLength) - return false; - for (unsigned char c : s) - if (c < 0x20 || c > 0x7e || c == '=') - return false; - return true; -} - -std::string MakeTxtData(const std::map<std::string, std::string>& txt_data) { - std::string txt; - txt.reserve(kMaxStaticTxtDataSize); - for (const auto& line : txt_data) { - const auto key_size = line.first.size(); - const auto value_size = line.second.size(); - const auto line_size = value_size ? (key_size + 1 + value_size) : key_size; - if (!IsValidTxtDataKey(line.first) || line_size > kMaxDnsStringLength || - (txt.size() + 1 + line_size) > kMaxStaticTxtDataSize) { - return {}; - } - txt.push_back(line_size); - txt += line.first; - if (value_size) { - txt.push_back('='); - txt += line.second; - } - } - return txt; -} - -MdnsResponderErrorCode MapMdnsError(int err) { - switch (err) { - case mStatus_NoError: - return MdnsResponderErrorCode::kNoError; - case mStatus_UnsupportedErr: - return MdnsResponderErrorCode::kUnsupportedError; - case mStatus_UnknownErr: - return MdnsResponderErrorCode::kUnknownError; - default: - break; - } - OSP_DLOG_WARN << "unmapped mDNSResponder error: " << err; - return MdnsResponderErrorCode::kUnknownError; -} - -std::vector<std::string> ParseTxtResponse( - const uint8_t data[kMaxStaticTxtDataSize], - uint16_t length) { - OSP_DCHECK(length <= kMaxStaticTxtDataSize); - if (length == 0) - return {}; - - std::vector<std::string> lines; - int total_pos = 0; - while (total_pos < length) { - uint8_t line_length = data[total_pos]; - if ((line_length > kMaxDnsStringLength) || - (total_pos + line_length >= length)) { - return {}; - } - lines.emplace_back(&data[total_pos + 1], - &data[total_pos + line_length + 1]); - total_pos += line_length + 1; - } - return lines; -} - -void MdnsStatusCallback(mDNS* mdns, mStatus result) { - OSP_LOG_INFO << "status good? " << (result == mStatus_NoError); -} - -} // namespace - -MdnsResponderAdapterImpl::MdnsResponderAdapterImpl() = default; -MdnsResponderAdapterImpl::~MdnsResponderAdapterImpl() = default; - -Error MdnsResponderAdapterImpl::Init() { - const auto err = - mDNS_Init(&mdns_, &platform_storage_, rr_cache_, kRrCacheSize, - mDNS_Init_DontAdvertiseLocalAddresses, &MdnsStatusCallback, - mDNS_Init_NoInitCallbackContext); - - return (err == mStatus_NoError) ? Error::None() - : Error::Code::kInitializationFailure; -} - -void MdnsResponderAdapterImpl::Close() { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::Close"); - mDNS_StartExit(&mdns_); - // Let all services send goodbyes. - while (!service_records_.empty()) { - RunTasks(); - } - mDNS_FinalExit(&mdns_); - - socket_to_questions_.clear(); - - responder_interface_info_.clear(); - - a_responses_.clear(); - aaaa_responses_.clear(); - ptr_responses_.clear(); - srv_responses_.clear(); - txt_responses_.clear(); - - service_records_.clear(); -} - -Error MdnsResponderAdapterImpl::SetHostLabel(const std::string& host_label) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::SetHostLabel"); - if (host_label.size() > DomainName::kDomainNameMaxLabelLength) - return Error::Code::kDomainNameTooLong; - - MakeDomainLabelFromLiteralString(&mdns_.hostlabel, host_label.c_str()); - mDNS_SetFQDN(&mdns_); - if (!service_records_.empty()) { - DeadvertiseInterfaces(); - AdvertiseInterfaces(); - } - return Error::None(); -} - -Error MdnsResponderAdapterImpl::RegisterInterface( - const InterfaceInfo& interface_info, - const IPSubnet& interface_address, - UdpSocket* socket) { - TRACE_SCOPED(TraceCategory::kMdns, - "MdnsResponderAdapterImpl::RegisterInterface"); - OSP_DCHECK(socket); - - const auto info_it = responder_interface_info_.find(socket); - if (info_it != responder_interface_info_.end()) - return Error::None(); - - NetworkInterfaceInfo& info = responder_interface_info_[socket]; - std::memset(&info, 0, sizeof(NetworkInterfaceInfo)); - info.InterfaceID = reinterpret_cast<decltype(info.InterfaceID)>(socket); - info.Advertise = mDNSfalse; - if (interface_address.address.IsV4()) { - info.ip.type = mDNSAddrType_IPv4; - interface_address.address.CopyToV4(info.ip.ip.v4.b); - info.mask.type = mDNSAddrType_IPv4; - MakeSubnetMaskFromPrefixLengthV4(info.mask.ip.v4.b, - interface_address.prefix_length); - } else { - info.ip.type = mDNSAddrType_IPv6; - interface_address.address.CopyToV6(info.ip.ip.v6.b); - info.mask.type = mDNSAddrType_IPv6; - MakeSubnetMaskFromPrefixLengthV6(info.mask.ip.v6.b, - interface_address.prefix_length); - } - - static_assert(sizeof(info.MAC.b) == sizeof(interface_info.hardware_address), - "MAC address size mismatch."); - memcpy(info.MAC.b, interface_info.hardware_address.data(), - sizeof(info.MAC.b)); - info.McastTxRx = 1; - platform_storage_.sockets.push_back(socket); - auto result = mDNS_RegisterInterface(&mdns_, &info, mDNSfalse); - OSP_LOG_IF(WARN, result != mStatus_NoError) - << "mDNS_RegisterInterface failed: " << result; - - return (result == mStatus_NoError) ? Error::None() - : Error::Code::kMdnsRegisterFailure; -} - -Error MdnsResponderAdapterImpl::DeregisterInterface(UdpSocket* socket) { - TRACE_SCOPED(TraceCategory::kMdns, - "MdnsResponderAdapterImpl::DeregisterInterface"); - const auto info_it = responder_interface_info_.find(socket); - if (info_it == responder_interface_info_.end()) - return Error::Code::kItemNotFound; - - const auto it = std::find(platform_storage_.sockets.begin(), - platform_storage_.sockets.end(), socket); - OSP_DCHECK(it != platform_storage_.sockets.end()); - platform_storage_.sockets.erase(it); - if (info_it->second.RR_A.namestorage.c[0]) { - mDNS_Deregister(&mdns_, &info_it->second.RR_A); - info_it->second.RR_A.namestorage.c[0] = 0; - } - mDNS_DeregisterInterface(&mdns_, &info_it->second, mDNSfalse); - responder_interface_info_.erase(info_it); - return Error::None(); -} -void MdnsResponderAdapterImpl::OnRead(UdpSocket* socket, - ErrorOr<UdpPacket> packet_or_error) { - if (packet_or_error.is_error()) { - return; - } - - UdpPacket packet = std::move(packet_or_error.value()); - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::OnRead"); - mDNSAddr src; - if (packet.source().address.IsV4()) { - src.type = mDNSAddrType_IPv4; - packet.source().address.CopyToV4(src.ip.v4.b); - } else { - src.type = mDNSAddrType_IPv6; - packet.source().address.CopyToV6(src.ip.v6.b); - } - mDNSIPPort srcport; - AssignMdnsPort(&srcport, packet.source().port); - - mDNSAddr dst; - if (packet.source().address.IsV4()) { - dst.type = mDNSAddrType_IPv4; - packet.destination().address.CopyToV4(dst.ip.v4.b); - } else { - dst.type = mDNSAddrType_IPv6; - packet.destination().address.CopyToV6(dst.ip.v6.b); - } - mDNSIPPort dstport; - AssignMdnsPort(&dstport, packet.destination().port); - - auto* packet_data = packet.data(); - mDNSCoreReceive(&mdns_, const_cast<uint8_t*>(packet_data), - packet_data + packet.size(), &src, srcport, &dst, dstport, - reinterpret_cast<mDNSInterfaceID>(packet.socket())); -} - -void MdnsResponderAdapterImpl::OnSendError(UdpSocket* socket, Error error) { - // TODO(crbug.com/openscreen/67): Implement this method. - OSP_UNIMPLEMENTED(); -} - -void MdnsResponderAdapterImpl::OnError(UdpSocket* socket, Error error) { - // TODO(crbug.com/openscreen/67): Implement this method. - OSP_UNIMPLEMENTED(); -} - -void MdnsResponderAdapterImpl::OnBound(UdpSocket* socket) { - // TODO(crbug.com/openscreen/67): Implement this method. - OSP_UNIMPLEMENTED(); -} - -Clock::duration MdnsResponderAdapterImpl::RunTasks() { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::RunTasks"); - - mDNS_Execute(&mdns_); - - // Using mDNS_Execute's response to determine the correct timespan before - // re-running this method doesn't work as expected. In the demo, under some - // cases (about 25% of demo runs), the response is set to an unreasonably - // large number (in the order of multiple days). - // - // From the mDNS documentation: "it is the responsibility [...] to set the - // timer according to the m->NextScheduledEvent value, and then when the timer - // fires, the timer callback function should call mDNS_Execute()" - for more - // details see third_party/mDNSResponder/src/mDNSCore/mDNS.c : 3390 - // - // Together, I understand these to mean that the mdns library code doesn't - // expect we need mDNS_Execute called again by the task runner, only in the - // other special cases it calls out in documentation (which we currently do - // correctly). In our code, when we call mDNS_Execute again outside of the - // task runner, the result is currently discarded. What we would need to do is - // reach into the Task Runner's task and update how long before the task runs - // again. That would require some large refactoring and changes. - // - // Additionally, beyond this, the mDNS code documents that there are cases - // where the return value for mDNS_Execute should be ignored because it may be - // stale. - // - // TODO(rwkeane): More accurately determine when the next run of this method - // should be. - constexpr auto seconds_before_next_run = 1; - - // Return as a duration. - return std::chrono::seconds(seconds_before_next_run); -} - -std::vector<PtrEvent> MdnsResponderAdapterImpl::TakePtrResponses() { - return std::move(ptr_responses_); -} - -std::vector<SrvEvent> MdnsResponderAdapterImpl::TakeSrvResponses() { - return std::move(srv_responses_); -} - -std::vector<TxtEvent> MdnsResponderAdapterImpl::TakeTxtResponses() { - return std::move(txt_responses_); -} - -std::vector<AEvent> MdnsResponderAdapterImpl::TakeAResponses() { - return std::move(a_responses_); -} - -std::vector<AaaaEvent> MdnsResponderAdapterImpl::TakeAaaaResponses() { - return std::move(aaaa_responses_); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::StartPtrQuery( - UdpSocket* socket, - const DomainName& service_type) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartPtrQuery"); - auto& ptr_questions = socket_to_questions_[socket].ptr; - if (ptr_questions.find(service_type) != ptr_questions.end()) - return MdnsResponderErrorCode::kNoError; - - auto& question = ptr_questions[service_type]; - - question.InterfaceID = reinterpret_cast<mDNSInterfaceID>(socket); - question.Target = {0}; - if (service_type.EndsWithLocalDomain()) { - std::copy(service_type.domain_name().begin(), - service_type.domain_name().end(), question.qname.c); - } else { - const DomainName local_domain = DomainName::GetLocalDomain(); - ErrorOr<DomainName> service_type_with_local = - DomainName::Append(service_type, local_domain); - if (!service_type_with_local) { - return MdnsResponderErrorCode::kDomainOverflowError; - } - std::copy(service_type_with_local.value().domain_name().begin(), - service_type_with_local.value().domain_name().end(), - question.qname.c); - } - question.qtype = kDNSType_PTR; - question.qclass = kDNSClass_IN; - question.LongLived = mDNStrue; - question.ExpectUnique = mDNSfalse; - question.ForceMCast = mDNStrue; - question.ReturnIntermed = mDNSfalse; - question.SuppressUnusable = mDNSfalse; - question.RetryWithSearchDomains = mDNSfalse; - question.TimeoutQuestion = 0; - question.WakeOnResolve = 0; - question.SearchListIndex = 0; - question.AppendSearchDomains = 0; - question.AppendLocalSearchDomains = 0; - question.qnameOrig = nullptr; - question.QuestionCallback = &MdnsResponderAdapterImpl::PtrQueryCallback; - question.QuestionContext = this; - const auto err = mDNS_StartQuery(&mdns_, &question); - OSP_LOG_IF(WARN, err != mStatus_NoError) << "mDNS_StartQuery failed: " << err; - return MapMdnsError(err); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::StartSrvQuery( - UdpSocket* socket, - const DomainName& service_instance) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartSrvQuery"); - if (!service_instance.EndsWithLocalDomain()) - return MdnsResponderErrorCode::kInvalidParameters; - - auto& srv_questions = socket_to_questions_[socket].srv; - if (srv_questions.find(service_instance) != srv_questions.end()) - return MdnsResponderErrorCode::kNoError; - - auto& question = srv_questions[service_instance]; - - question.InterfaceID = reinterpret_cast<mDNSInterfaceID>(socket); - question.Target = {0}; - std::copy(service_instance.domain_name().begin(), - service_instance.domain_name().end(), question.qname.c); - question.qtype = kDNSType_SRV; - question.qclass = kDNSClass_IN; - question.LongLived = mDNStrue; - question.ExpectUnique = mDNSfalse; - question.ForceMCast = mDNStrue; - question.ReturnIntermed = mDNSfalse; - question.SuppressUnusable = mDNSfalse; - question.RetryWithSearchDomains = mDNSfalse; - question.TimeoutQuestion = 0; - question.WakeOnResolve = 0; - question.SearchListIndex = 0; - question.AppendSearchDomains = 0; - question.AppendLocalSearchDomains = 0; - question.qnameOrig = nullptr; - question.QuestionCallback = &MdnsResponderAdapterImpl::SrvQueryCallback; - question.QuestionContext = this; - const auto err = mDNS_StartQuery(&mdns_, &question); - OSP_LOG_IF(WARN, err != mStatus_NoError) << "mDNS_StartQuery failed: " << err; - return MapMdnsError(err); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::StartTxtQuery( - UdpSocket* socket, - const DomainName& service_instance) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartTxtQuery"); - if (!service_instance.EndsWithLocalDomain()) - return MdnsResponderErrorCode::kInvalidParameters; - - auto& txt_questions = socket_to_questions_[socket].txt; - if (txt_questions.find(service_instance) != txt_questions.end()) - return MdnsResponderErrorCode::kNoError; - - auto& question = txt_questions[service_instance]; - - question.InterfaceID = reinterpret_cast<mDNSInterfaceID>(socket); - question.Target = {0}; - std::copy(service_instance.domain_name().begin(), - service_instance.domain_name().end(), question.qname.c); - question.qtype = kDNSType_TXT; - question.qclass = kDNSClass_IN; - question.LongLived = mDNStrue; - question.ExpectUnique = mDNSfalse; - question.ForceMCast = mDNStrue; - question.ReturnIntermed = mDNSfalse; - question.SuppressUnusable = mDNSfalse; - question.RetryWithSearchDomains = mDNSfalse; - question.TimeoutQuestion = 0; - question.WakeOnResolve = 0; - question.SearchListIndex = 0; - question.AppendSearchDomains = 0; - question.AppendLocalSearchDomains = 0; - question.qnameOrig = nullptr; - question.QuestionCallback = &MdnsResponderAdapterImpl::TxtQueryCallback; - question.QuestionContext = this; - const auto err = mDNS_StartQuery(&mdns_, &question); - OSP_LOG_IF(WARN, err != mStatus_NoError) << "mDNS_StartQuery failed: " << err; - return MapMdnsError(err); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAQuery( - UdpSocket* socket, - const DomainName& domain_name) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StartAQuery"); - if (!domain_name.EndsWithLocalDomain()) - return MdnsResponderErrorCode::kInvalidParameters; - - auto& a_questions = socket_to_questions_[socket].a; - if (a_questions.find(domain_name) != a_questions.end()) - return MdnsResponderErrorCode::kNoError; - - auto& question = a_questions[domain_name]; - std::copy(domain_name.domain_name().begin(), domain_name.domain_name().end(), - question.qname.c); - - question.InterfaceID = reinterpret_cast<mDNSInterfaceID>(socket); - question.Target = {0}; - question.qtype = kDNSType_A; - question.qclass = kDNSClass_IN; - question.LongLived = mDNStrue; - question.ExpectUnique = mDNSfalse; - question.ForceMCast = mDNStrue; - question.ReturnIntermed = mDNSfalse; - question.SuppressUnusable = mDNSfalse; - question.RetryWithSearchDomains = mDNSfalse; - question.TimeoutQuestion = 0; - question.WakeOnResolve = 0; - question.SearchListIndex = 0; - question.AppendSearchDomains = 0; - question.AppendLocalSearchDomains = 0; - question.qnameOrig = nullptr; - question.QuestionCallback = &MdnsResponderAdapterImpl::AQueryCallback; - question.QuestionContext = this; - const auto err = mDNS_StartQuery(&mdns_, &question); - OSP_LOG_IF(WARN, err != mStatus_NoError) << "mDNS_StartQuery failed: " << err; - return MapMdnsError(err); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::StartAaaaQuery( - UdpSocket* socket, - const DomainName& domain_name) { - TRACE_SCOPED(TraceCategory::kMdns, - "MdnsResponderAdapterImpl::StartAaaaQuery"); - if (!domain_name.EndsWithLocalDomain()) - return MdnsResponderErrorCode::kInvalidParameters; - - auto& aaaa_questions = socket_to_questions_[socket].aaaa; - if (aaaa_questions.find(domain_name) != aaaa_questions.end()) - return MdnsResponderErrorCode::kNoError; - - auto& question = aaaa_questions[domain_name]; - std::copy(domain_name.domain_name().begin(), domain_name.domain_name().end(), - question.qname.c); - - question.InterfaceID = reinterpret_cast<mDNSInterfaceID>(socket); - question.Target = {0}; - question.qtype = kDNSType_AAAA; - question.qclass = kDNSClass_IN; - question.LongLived = mDNStrue; - question.ExpectUnique = mDNSfalse; - question.ForceMCast = mDNStrue; - question.ReturnIntermed = mDNSfalse; - question.SuppressUnusable = mDNSfalse; - question.RetryWithSearchDomains = mDNSfalse; - question.TimeoutQuestion = 0; - question.WakeOnResolve = 0; - question.SearchListIndex = 0; - question.AppendSearchDomains = 0; - question.AppendLocalSearchDomains = 0; - question.qnameOrig = nullptr; - question.QuestionCallback = &MdnsResponderAdapterImpl::AaaaQueryCallback; - question.QuestionContext = this; - const auto err = mDNS_StartQuery(&mdns_, &question); - OSP_LOG_IF(WARN, err != mStatus_NoError) << "mDNS_StartQuery failed: " << err; - return MapMdnsError(err); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::StopPtrQuery( - UdpSocket* socket, - const DomainName& service_type) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopPtrQuery"); - auto interface_entry = socket_to_questions_.find(socket); - if (interface_entry == socket_to_questions_.end()) - return MdnsResponderErrorCode::kNoError; - auto entry = interface_entry->second.ptr.find(service_type); - if (entry == interface_entry->second.ptr.end()) - return MdnsResponderErrorCode::kNoError; - - const auto err = mDNS_StopQuery(&mdns_, &entry->second); - interface_entry->second.ptr.erase(entry); - OSP_LOG_IF(WARN, err != mStatus_NoError) << "mDNS_StopQuery failed: " << err; - RemoveQuestionsIfEmpty(socket); - return MapMdnsError(err); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::StopSrvQuery( - UdpSocket* socket, - const DomainName& service_instance) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopSrvQuery"); - auto interface_entry = socket_to_questions_.find(socket); - if (interface_entry == socket_to_questions_.end()) - return MdnsResponderErrorCode::kNoError; - auto entry = interface_entry->second.srv.find(service_instance); - if (entry == interface_entry->second.srv.end()) - return MdnsResponderErrorCode::kNoError; - - const auto err = mDNS_StopQuery(&mdns_, &entry->second); - interface_entry->second.srv.erase(entry); - OSP_LOG_IF(WARN, err != mStatus_NoError) << "mDNS_StopQuery failed: " << err; - RemoveQuestionsIfEmpty(socket); - return MapMdnsError(err); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::StopTxtQuery( - UdpSocket* socket, - const DomainName& service_instance) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopTxtQuery"); - auto interface_entry = socket_to_questions_.find(socket); - if (interface_entry == socket_to_questions_.end()) - return MdnsResponderErrorCode::kNoError; - auto entry = interface_entry->second.txt.find(service_instance); - if (entry == interface_entry->second.txt.end()) - return MdnsResponderErrorCode::kNoError; - - const auto err = mDNS_StopQuery(&mdns_, &entry->second); - interface_entry->second.txt.erase(entry); - OSP_LOG_IF(WARN, err != mStatus_NoError) << "mDNS_StopQuery failed: " << err; - RemoveQuestionsIfEmpty(socket); - return MapMdnsError(err); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAQuery( - UdpSocket* socket, - const DomainName& domain_name) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopAQuery"); - auto interface_entry = socket_to_questions_.find(socket); - if (interface_entry == socket_to_questions_.end()) - return MdnsResponderErrorCode::kNoError; - auto entry = interface_entry->second.a.find(domain_name); - if (entry == interface_entry->second.a.end()) - return MdnsResponderErrorCode::kNoError; - - const auto err = mDNS_StopQuery(&mdns_, &entry->second); - interface_entry->second.a.erase(entry); - OSP_LOG_IF(WARN, err != mStatus_NoError) << "mDNS_StopQuery failed: " << err; - RemoveQuestionsIfEmpty(socket); - return MapMdnsError(err); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::StopAaaaQuery( - UdpSocket* socket, - const DomainName& domain_name) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::StopAaaaQuery"); - auto interface_entry = socket_to_questions_.find(socket); - if (interface_entry == socket_to_questions_.end()) - return MdnsResponderErrorCode::kNoError; - auto entry = interface_entry->second.aaaa.find(domain_name); - if (entry == interface_entry->second.aaaa.end()) - return MdnsResponderErrorCode::kNoError; - - const auto err = mDNS_StopQuery(&mdns_, &entry->second); - interface_entry->second.aaaa.erase(entry); - OSP_LOG_IF(WARN, err != mStatus_NoError) << "mDNS_StopQuery failed: " << err; - RemoveQuestionsIfEmpty(socket); - return MapMdnsError(err); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::RegisterService( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const DomainName& target_host, - uint16_t target_port, - const std::map<std::string, std::string>& txt_data) { - TRACE_SCOPED(TraceCategory::kMdns, - "MdnsResponderAdapterImpl::RegisterService"); - OSP_DCHECK(IsValidServiceName(service_name)); - OSP_DCHECK(IsValidServiceProtocol(service_protocol)); - service_records_.push_back(std::make_unique<ServiceRecordSet>()); - auto* service_record = service_records_.back().get(); - domainlabel instance; - domainlabel name; - domainlabel protocol; - domainname type; - domainname domain; - domainname host; - mDNSIPPort port; - - MakeLocalServiceNameParts(service_instance, service_name, service_protocol, - &instance, &name, &protocol, &type, &domain); - std::copy(target_host.domain_name().begin(), target_host.domain_name().end(), - host.c); - AssignMdnsPort(&port, target_port); - auto txt = MakeTxtData(txt_data); - if (txt.size() > kMaxStaticTxtDataSize) { - // Not handling oversized TXT records. - return MdnsResponderErrorCode::kUnsupportedError; - } - - if (service_records_.size() == 1) - AdvertiseInterfaces(); - - auto result = mDNS_RegisterService( - &mdns_, service_record, &instance, &type, &domain, &host, port, - reinterpret_cast<const uint8_t*>(txt.data()), txt.size(), nullptr, 0, - mDNSInterface_Any, &MdnsResponderAdapterImpl::ServiceCallback, this, 0); - - if (result != mStatus_NoError) { - service_records_.pop_back(); - if (service_records_.empty()) - DeadvertiseInterfaces(); - } - return MapMdnsError(result); -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::DeregisterService( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol) { - TRACE_SCOPED(TraceCategory::kMdns, - "MdnsResponderAdapterImpl::DeregisterService"); - domainlabel instance; - domainlabel name; - domainlabel protocol; - domainname type; - domainname domain; - domainname full_instance_name; - - MakeLocalServiceNameParts(service_instance, service_name, service_protocol, - &instance, &name, &protocol, &type, &domain); - if (!ConstructServiceName(&full_instance_name, &instance, &type, &domain)) - return MdnsResponderErrorCode::kInvalidParameters; - - for (auto it = service_records_.begin(); it != service_records_.end(); ++it) { - if (SameDomainName(&full_instance_name, &(*it)->RR_SRV.namestorage)) { - // |it| will be removed from |service_records_| in ServiceCallback, when - // mDNSResponder is done with the memory. - mDNS_DeregisterService(&mdns_, it->get()); - return MdnsResponderErrorCode::kNoError; - } - } - return MdnsResponderErrorCode::kNoError; -} - -MdnsResponderErrorCode MdnsResponderAdapterImpl::UpdateTxtData( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const std::map<std::string, std::string>& txt_data) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderAdapterImpl::UpdateTxtData"); - domainlabel instance; - domainlabel name; - domainlabel protocol; - domainname type; - domainname domain; - domainname full_instance_name; - - MakeLocalServiceNameParts(service_instance, service_name, service_protocol, - &instance, &name, &protocol, &type, &domain); - if (!ConstructServiceName(&full_instance_name, &instance, &type, &domain)) - return MdnsResponderErrorCode::kInvalidParameters; - std::string txt = MakeTxtData(txt_data); - if (txt.size() > kMaxStaticTxtDataSize) { - // Not handling oversized TXT records. - return MdnsResponderErrorCode::kUnsupportedError; - } - - for (std::unique_ptr<ServiceRecordSet>& record : service_records_) { - if (SameDomainName(&full_instance_name, &record->RR_SRV.namestorage)) { - std::copy(txt.begin(), txt.end(), record->RR_TXT.rdatastorage.u.txt.c); - mDNS_Update(&mdns_, &record->RR_TXT, 0, txt.size(), - &record->RR_TXT.rdatastorage, nullptr); - return MdnsResponderErrorCode::kNoError; - } - } - return MdnsResponderErrorCode::kNoError; -} - -// static -void MdnsResponderAdapterImpl::AQueryCallback(mDNS* m, - DNSQuestion* question, - const ResourceRecord* answer, - QC_result added) { - TRACE_SCOPED(TraceCategory::kMdns, - "MdnsResponderAdapterImpl::AQueryCallback"); - OSP_DCHECK(question); - OSP_DCHECK(answer); - OSP_DCHECK_EQ(answer->rrtype, kDNSType_A); - DomainName domain(std::vector<uint8_t>( - question->qname.c, - question->qname.c + DomainNameLength(&question->qname))); - IPAddress address(answer->rdata->u.ipv4.b); - - auto* adapter = - reinterpret_cast<MdnsResponderAdapterImpl*>(question->QuestionContext); - OSP_DCHECK(adapter); - auto event_type = QueryEventHeader::Type::kAddedNoCache; - if (added == QC_add) { - event_type = QueryEventHeader::Type::kAdded; - } else if (added == QC_rmv) { - event_type = QueryEventHeader::Type::kRemoved; - } else { - OSP_DCHECK_EQ(added, QC_addnocache); - } - adapter->a_responses_.emplace_back( - QueryEventHeader{event_type, - reinterpret_cast<UdpSocket*>(answer->InterfaceID)}, - std::move(domain), address); -} - -// static -void MdnsResponderAdapterImpl::AaaaQueryCallback(mDNS* m, - DNSQuestion* question, - const ResourceRecord* answer, - QC_result added) { - TRACE_SCOPED(TraceCategory::kMdns, - "MdnsResponderAdapterImpl::AaaaQueryCallback"); - OSP_DCHECK(question); - OSP_DCHECK(answer); - OSP_DCHECK_EQ(answer->rrtype, kDNSType_A); - DomainName domain(std::vector<uint8_t>( - question->qname.c, - question->qname.c + DomainNameLength(&question->qname))); - IPAddress address(IPAddress::Version::kV6, answer->rdata->u.ipv6.b); - - auto* adapter = - reinterpret_cast<MdnsResponderAdapterImpl*>(question->QuestionContext); - OSP_DCHECK(adapter); - auto event_type = QueryEventHeader::Type::kAddedNoCache; - if (added == QC_add) { - event_type = QueryEventHeader::Type::kAdded; - } else if (added == QC_rmv) { - event_type = QueryEventHeader::Type::kRemoved; - } else { - OSP_DCHECK_EQ(added, QC_addnocache); - } - adapter->aaaa_responses_.emplace_back( - QueryEventHeader{event_type, - reinterpret_cast<UdpSocket*>(answer->InterfaceID)}, - std::move(domain), address); -} - -// static -void MdnsResponderAdapterImpl::PtrQueryCallback(mDNS* m, - DNSQuestion* question, - const ResourceRecord* answer, - QC_result added) { - TRACE_SCOPED(TraceCategory::kMdns, - "MdnsResponderAdapterImpl::PtrQueryCallback"); - OSP_DCHECK(question); - OSP_DCHECK(answer); - OSP_DCHECK_EQ(answer->rrtype, kDNSType_PTR); - DomainName result(std::vector<uint8_t>( - answer->rdata->u.name.c, - answer->rdata->u.name.c + DomainNameLength(&answer->rdata->u.name))); - - auto* adapter = - reinterpret_cast<MdnsResponderAdapterImpl*>(question->QuestionContext); - OSP_DCHECK(adapter); - auto event_type = QueryEventHeader::Type::kAddedNoCache; - if (added == QC_add) { - event_type = QueryEventHeader::Type::kAdded; - } else if (added == QC_rmv) { - event_type = QueryEventHeader::Type::kRemoved; - } else { - OSP_DCHECK_EQ(added, QC_addnocache); - } - adapter->ptr_responses_.emplace_back( - QueryEventHeader{event_type, - reinterpret_cast<UdpSocket*>(answer->InterfaceID)}, - std::move(result)); -} - -// static -void MdnsResponderAdapterImpl::SrvQueryCallback(mDNS* m, - DNSQuestion* question, - const ResourceRecord* answer, - QC_result added) { - TRACE_SCOPED(TraceCategory::kMdns, - "MdnsResponderAdapterImpl::SrvQueryCallback"); - OSP_DCHECK(question); - OSP_DCHECK(answer); - OSP_DCHECK_EQ(answer->rrtype, kDNSType_SRV); - DomainName service(std::vector<uint8_t>( - question->qname.c, - question->qname.c + DomainNameLength(&question->qname))); - DomainName result( - std::vector<uint8_t>(answer->rdata->u.srv.target.c, - answer->rdata->u.srv.target.c + - DomainNameLength(&answer->rdata->u.srv.target))); - - auto* adapter = - reinterpret_cast<MdnsResponderAdapterImpl*>(question->QuestionContext); - OSP_DCHECK(adapter); - auto event_type = QueryEventHeader::Type::kAddedNoCache; - if (added == QC_add) { - event_type = QueryEventHeader::Type::kAdded; - } else if (added == QC_rmv) { - event_type = QueryEventHeader::Type::kRemoved; - } else { - OSP_DCHECK_EQ(added, QC_addnocache); - } - adapter->srv_responses_.emplace_back( - QueryEventHeader{event_type, - reinterpret_cast<UdpSocket*>(answer->InterfaceID)}, - std::move(service), std::move(result), - GetNetworkOrderPort(answer->rdata->u.srv.port)); -} - -// static -void MdnsResponderAdapterImpl::TxtQueryCallback(mDNS* m, - DNSQuestion* question, - const ResourceRecord* answer, - QC_result added) { - OSP_DCHECK(question); - OSP_DCHECK(answer); - OSP_DCHECK_EQ(answer->rrtype, kDNSType_TXT); - DomainName service(std::vector<uint8_t>( - question->qname.c, - question->qname.c + DomainNameLength(&question->qname))); - auto lines = ParseTxtResponse(answer->rdata->u.txt.c, answer->rdlength); - - auto* adapter = - reinterpret_cast<MdnsResponderAdapterImpl*>(question->QuestionContext); - OSP_DCHECK(adapter); - auto event_type = QueryEventHeader::Type::kAddedNoCache; - if (added == QC_add) { - event_type = QueryEventHeader::Type::kAdded; - } else if (added == QC_rmv) { - event_type = QueryEventHeader::Type::kRemoved; - } else { - OSP_DCHECK_EQ(added, QC_addnocache); - } - adapter->txt_responses_.emplace_back( - QueryEventHeader{event_type, - reinterpret_cast<UdpSocket*>(answer->InterfaceID)}, - std::move(service), std::move(lines)); -} - -// static -void MdnsResponderAdapterImpl::ServiceCallback(mDNS* m, - ServiceRecordSet* service_record, - mStatus result) { - // TODO(btolsch): Handle mStatus_NameConflict. - if (result == mStatus_MemFree) { - OSP_DLOG_INFO << "free service record"; - auto* adapter = reinterpret_cast<MdnsResponderAdapterImpl*>( - service_record->ServiceContext); - auto& service_records = adapter->service_records_; - service_records.erase( - std::remove_if( - service_records.begin(), service_records.end(), - [service_record](const std::unique_ptr<ServiceRecordSet>& sr) { - return sr.get() == service_record; - }), - service_records.end()); - - if (service_records.empty()) - adapter->DeadvertiseInterfaces(); - } -} - -void MdnsResponderAdapterImpl::AdvertiseInterfaces() { - TRACE_SCOPED(TraceCategory::kMdns, - "MdnsResponderAdapterImpl::AdvertiseInterfaces"); - for (auto& info : responder_interface_info_) { - UdpSocket* socket = info.first; - NetworkInterfaceInfo& interface_info = info.second; - mDNS_SetupResourceRecord(&interface_info.RR_A, /** RDataStorage */ nullptr, - reinterpret_cast<mDNSInterfaceID>(socket), - kDNSType_A, kHostNameTTL, kDNSRecordTypeUnique, - AuthRecordAny, - /** Callback */ nullptr, /** Context */ nullptr); - AssignDomainName(&interface_info.RR_A.namestorage, - &mdns_.MulticastHostname); - if (interface_info.ip.type == mDNSAddrType_IPv4) { - interface_info.RR_A.resrec.rdata->u.ipv4 = interface_info.ip.ip.v4; - } else { - interface_info.RR_A.resrec.rdata->u.ipv6 = interface_info.ip.ip.v6; - } - mDNS_Register(&mdns_, &interface_info.RR_A); - } -} - -void MdnsResponderAdapterImpl::DeadvertiseInterfaces() { - // Both loops below use the A resource record's domain name to determine - // whether the record was advertised. AdvertiseInterfaces sets the domain - // name before registering the A record, and this clears it after - // deregistering. - for (auto& info : responder_interface_info_) { - NetworkInterfaceInfo& interface_info = info.second; - if (interface_info.RR_A.namestorage.c[0]) { - mDNS_Deregister(&mdns_, &interface_info.RR_A); - interface_info.RR_A.namestorage.c[0] = 0; - } - } -} - -void MdnsResponderAdapterImpl::RemoveQuestionsIfEmpty(UdpSocket* socket) { - auto entry = socket_to_questions_.find(socket); - bool empty = entry->second.a.empty() || entry->second.aaaa.empty() || - entry->second.ptr.empty() || entry->second.srv.empty() || - entry->second.txt.empty(); - if (empty) - socket_to_questions_.erase(entry); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h b/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h deleted file mode 100644 index d0dd55a1..00000000 --- a/osp/impl/discovery/mdns/mdns_responder_adapter_impl.h +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_ADAPTER_IMPL_H_ -#define OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_ADAPTER_IMPL_H_ - -#include <map> -#include <memory> -#include <string> -#include <vector> - -#include "osp/impl/discovery/mdns/mdns_responder_adapter.h" -#include "platform/api/udp_socket.h" -#include "platform/base/error.h" -#include "third_party/mDNSResponder/src/mDNSCore/mDNSEmbeddedAPI.h" - -namespace openscreen { -namespace osp { - -class MdnsResponderAdapterImpl final : public MdnsResponderAdapter { - public: - static constexpr int kRrCacheSize = 500; - - MdnsResponderAdapterImpl(); - ~MdnsResponderAdapterImpl() override; - - Error Init() override; - void Close() override; - - Error SetHostLabel(const std::string& host_label) override; - - Error RegisterInterface(const InterfaceInfo& interface_info, - const IPSubnet& interface_address, - UdpSocket* socket) override; - Error DeregisterInterface(UdpSocket* socket) override; - - void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; - void OnSendError(UdpSocket* socket, Error error) override; - void OnError(UdpSocket* socket, Error error) override; - void OnBound(UdpSocket* socket) override; - - Clock::duration RunTasks() override; - - std::vector<PtrEvent> TakePtrResponses() override; - std::vector<SrvEvent> TakeSrvResponses() override; - std::vector<TxtEvent> TakeTxtResponses() override; - std::vector<AEvent> TakeAResponses() override; - std::vector<AaaaEvent> TakeAaaaResponses() override; - - MdnsResponderErrorCode StartPtrQuery(UdpSocket* socket, - const DomainName& service_type) override; - MdnsResponderErrorCode StartSrvQuery( - UdpSocket* socket, - const DomainName& service_instance) override; - MdnsResponderErrorCode StartTxtQuery( - UdpSocket* socket, - const DomainName& service_instance) override; - MdnsResponderErrorCode StartAQuery(UdpSocket* socket, - const DomainName& domain_name) override; - MdnsResponderErrorCode StartAaaaQuery(UdpSocket* socket, - const DomainName& domain_name) override; - MdnsResponderErrorCode StopPtrQuery(UdpSocket* socket, - const DomainName& service_type) override; - MdnsResponderErrorCode StopSrvQuery( - UdpSocket* socket, - const DomainName& service_instance) override; - MdnsResponderErrorCode StopTxtQuery( - UdpSocket* socket, - const DomainName& service_instance) override; - MdnsResponderErrorCode StopAQuery(UdpSocket* socket, - const DomainName& domain_name) override; - MdnsResponderErrorCode StopAaaaQuery(UdpSocket* socket, - const DomainName& domain_name) override; - - MdnsResponderErrorCode RegisterService( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const DomainName& target_host, - uint16_t target_port, - const std::map<std::string, std::string>& txt_data) override; - MdnsResponderErrorCode DeregisterService( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol) override; - MdnsResponderErrorCode UpdateTxtData( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const std::map<std::string, std::string>& txt_data) override; - - private: - struct Questions { - std::map<DomainName, DNSQuestion, DomainNameComparator> a; - std::map<DomainName, DNSQuestion, DomainNameComparator> aaaa; - std::map<DomainName, DNSQuestion, DomainNameComparator> ptr; - std::map<DomainName, DNSQuestion, DomainNameComparator> srv; - std::map<DomainName, DNSQuestion, DomainNameComparator> txt; - }; - - static void AQueryCallback(mDNS* m, - DNSQuestion* question, - const ResourceRecord* answer, - QC_result added); - static void AaaaQueryCallback(mDNS* m, - DNSQuestion* question, - const ResourceRecord* answer, - QC_result added); - static void PtrQueryCallback(mDNS* m, - DNSQuestion* question, - const ResourceRecord* answer, - QC_result added); - static void SrvQueryCallback(mDNS* m, - DNSQuestion* question, - const ResourceRecord* answer, - QC_result added); - static void TxtQueryCallback(mDNS* m, - DNSQuestion* question, - const ResourceRecord* answer, - QC_result added); - static void ServiceCallback(mDNS* m, - ServiceRecordSet* service_record, - mStatus result); - - void AdvertiseInterfaces(); - void DeadvertiseInterfaces(); - void RemoveQuestionsIfEmpty(UdpSocket* socket); - - CacheEntity rr_cache_[kRrCacheSize]; - - // The main context structure for mDNSResponder. - mDNS mdns_; - - // Our own storage that is placed inside |mdns_|. The intent in C is to allow - // us access to our own state during callbacks. Here we just use it to group - // platform sockets. - mDNS_PlatformSupport platform_storage_; - - std::map<UdpSocket*, Questions> socket_to_questions_; - - std::map<UdpSocket*, NetworkInterfaceInfo> responder_interface_info_; - - std::vector<AEvent> a_responses_; - std::vector<AaaaEvent> aaaa_responses_; - std::vector<PtrEvent> ptr_responses_; - std::vector<SrvEvent> srv_responses_; - std::vector<TxtEvent> txt_responses_; - - // A list of services we are advertising. ServiceRecordSet is an - // mDNSResponder structure which holds all the resource record data - // (PTR/SRV/TXT/A and misc.) that is necessary to advertise a service. - std::vector<std::unique_ptr<ServiceRecordSet>> service_records_; -}; - -} // namespace osp -} // namespace openscreen - -#endif // OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_ADAPTER_IMPL_H_ diff --git a/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc b/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc deleted file mode 100644 index 29b76679..00000000 --- a/osp/impl/discovery/mdns/mdns_responder_adapter_impl_unittest.cc +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/discovery/mdns/mdns_responder_adapter_impl.h" - -#include <memory> -#include <string> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace openscreen { -namespace osp { -namespace { - -using ::testing::ElementsAre; -using ::testing::ElementsAreArray; - -// Example response for _openscreen._udp. Contains PTR, SRV, TXT, A records. -uint8_t data[] = { - 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x03, - 0x06, 0x74, 0x75, 0x72, 0x74, 0x6c, 0x65, 0x0b, 0x5f, 0x6f, 0x70, 0x65, - 0x6e, 0x73, 0x63, 0x72, 0x65, 0x65, 0x6e, 0x04, 0x5f, 0x75, 0x64, 0x70, - 0x05, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x00, 0x00, 0x10, 0x80, 0x01, 0x00, - 0x00, 0x11, 0x94, 0x00, 0x0e, 0x06, 0x79, 0x75, 0x72, 0x74, 0x6c, 0x65, - 0x06, 0x74, 0x75, 0x72, 0x74, 0x6c, 0x65, 0x09, 0x5f, 0x73, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x73, 0x07, 0x5f, 0x64, 0x6e, 0x73, 0x2d, 0x73, - 0x64, 0xc0, 0x1f, 0x00, 0x0c, 0x00, 0x01, 0x00, 0x00, 0x11, 0x94, 0x00, - 0x02, 0xc0, 0x13, 0xc0, 0x13, 0x00, 0x0c, 0x00, 0x01, 0x00, 0x00, 0x11, - 0x94, 0x00, 0x02, 0xc0, 0x0c, 0x11, 0x67, 0x69, 0x67, 0x6c, 0x69, 0x6f, - 0x72, 0x6f, 0x6e, 0x6f, 0x6e, 0x6f, 0x6d, 0x69, 0x63, 0x6f, 0x6e, 0xc0, - 0x24, 0x00, 0x01, 0x80, 0x01, 0x00, 0x00, 0x00, 0x78, 0x00, 0x04, 0xac, - 0x11, 0x20, 0x96, 0xc0, 0x0c, 0x00, 0x21, 0x80, 0x01, 0x00, 0x00, 0x00, - 0x78, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x30, 0x39, 0xc0, 0x71, 0xc0, - 0x0c, 0x00, 0x2f, 0x80, 0x01, 0x00, 0x00, 0x11, 0x94, 0x00, 0x09, 0xc0, - 0x0c, 0x00, 0x05, 0x00, 0x00, 0x80, 0x00, 0x40, 0xc0, 0x71, 0x00, 0x2f, - 0x80, 0x01, 0x00, 0x00, 0x00, 0x78, 0x00, 0x05, 0xc0, 0x71, 0x00, 0x01, - 0x40, 0x00, 0x00, 0x29, 0x05, 0xa0, 0x00, 0x00, 0x11, 0x94, 0x00, 0x12, - 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x50, 0x65, 0xf3, 0x41, 0x27, 0x01, -}; - -} // namespace - -TEST(MdnsResponderAdapterImplTest, ExampleData) { - const DomainName openscreen_service{{11, '_', 'o', 'p', 'e', 'n', 's', 'c', - 'r', 'e', 'e', 'n', 4, '_', 'u', 'd', - 'p', 5, 'l', 'o', 'c', 'a', 'l', 0}}; - const IPEndpoint mdns_endpoint{{224, 0, 0, 251}, 5353}; - - UdpPacket packet(std::begin(data), std::end(data)); - packet.set_source({{192, 168, 0, 2}, 6556}); - packet.set_destination(mdns_endpoint); - packet.set_socket(nullptr); - - auto mdns_adapter = - std::unique_ptr<MdnsResponderAdapter>(new MdnsResponderAdapterImpl); - mdns_adapter->Init(); - mdns_adapter->StartPtrQuery(0, openscreen_service); - mdns_adapter->OnRead(nullptr, std::move(packet)); - mdns_adapter->RunTasks(); - - auto ptr = mdns_adapter->TakePtrResponses(); - ASSERT_EQ(1u, ptr.size()); - ASSERT_THAT(ptr[0].service_instance.GetLabels(), - ElementsAre("turtle", "_openscreen", "_udp", "local")); - mdns_adapter->StartSrvQuery(0, ptr[0].service_instance); - mdns_adapter->StartTxtQuery(0, ptr[0].service_instance); - mdns_adapter->RunTasks(); - - auto srv = mdns_adapter->TakeSrvResponses(); - ASSERT_EQ(1u, srv.size()); - ASSERT_THAT(srv[0].domain_name.GetLabels(), - ElementsAre("gigliorononomicon", "local")); - EXPECT_EQ(12345, srv[0].port); - - auto txt = mdns_adapter->TakeTxtResponses(); - ASSERT_EQ(1u, txt.size()); - const std::string expected_txt[] = {"yurtle", "turtle"}; - EXPECT_THAT(txt[0].txt_info, ElementsAreArray(expected_txt)); - - mdns_adapter->StartAQuery(0, srv[0].domain_name); - mdns_adapter->RunTasks(); - - auto a = mdns_adapter->TakeAResponses(); - ASSERT_EQ(1u, a.size()); - EXPECT_EQ((IPAddress{172, 17, 32, 150}), a[0].address); - - mdns_adapter->Close(); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/discovery/mdns/mdns_responder_platform.cc b/osp/impl/discovery/mdns/mdns_responder_platform.cc deleted file mode 100644 index 14204ff5..00000000 --- a/osp/impl/discovery/mdns/mdns_responder_platform.cc +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/discovery/mdns/mdns_responder_platform.h" - -#include <algorithm> -#include <chrono> -#include <cstring> -#include <limits> -#include <vector> - -#include "platform/api/network_interface.h" -#include "platform/api/time.h" -#include "platform/api/udp_socket.h" -#include "platform/base/error.h" -#include "platform/base/ip_address.h" -#include "third_party/mDNSResponder/src/mDNSCore/mDNSEmbeddedAPI.h" -#include "util/osp_logging.h" - -namespace { - -using std::chrono::duration_cast; -using std::chrono::hours; -using std::chrono::milliseconds; -using std::chrono::seconds; - -} // namespace - -extern "C" { - -const char ProgramName[] = "openscreen"; - -mDNSs32 mDNSPlatformOneSecond = 1000; - -mStatus mDNSPlatformInit(mDNS* m) { - mDNSCoreInitComplete(m, mStatus_NoError); - return mStatus_NoError; -} - -void mDNSPlatformClose(mDNS* m) {} - -mStatus mDNSPlatformSendUDP(const mDNS* m, - const void* msg, - const mDNSu8* last, - mDNSInterfaceID InterfaceID, - UDPSocket* src, - const mDNSAddr* dst, - mDNSIPPort dstport) { - auto* const socket = reinterpret_cast<openscreen::UdpSocket*>(InterfaceID); - const auto socket_it = - std::find(m->p->sockets.begin(), m->p->sockets.end(), socket); - if (socket_it == m->p->sockets.end()) - return mStatus_BadInterfaceErr; - - openscreen::IPEndpoint dest{ - openscreen::IPAddress{dst->type == mDNSAddrType_IPv4 - ? openscreen::IPAddress::Version::kV4 - : openscreen::IPAddress::Version::kV6, - dst->ip.v4.b}, - static_cast<uint16_t>((dstport.b[0] << 8) | dstport.b[1])}; - const int64_t length = last - static_cast<const uint8_t*>(msg); - if (length < 0 || length > std::numeric_limits<ssize_t>::max()) { - return mStatus_BadParamErr; - } - - // UDP is inherently lossy, so don't worry about async failures and let the - // underlying protocol handle it. - (*socket_it)->SendMessage(msg, length, dest); - return mStatus_NoError; -} - -void mDNSPlatformLock(const mDNS* m) { - // We're single threaded. -} - -void mDNSPlatformUnlock(const mDNS* m) {} - -void mDNSPlatformStrCopy(void* dst, const void* src) { - // Unfortunately, the caller is responsible for making sure that dst - // if of sufficient length to store the src string. Otherwise we may - // cause an access violation. - std::strcpy(static_cast<char*>(dst), // NOLINT - static_cast<const char*>(src)); -} - -mDNSu32 mDNSPlatformStrLen(const void* src) { - return std::strlen(static_cast<const char*>(src)); -} - -void mDNSPlatformMemCopy(void* dst, const void* src, mDNSu32 len) { - std::memcpy(dst, src, len); -} - -mDNSBool mDNSPlatformMemSame(const void* dst, const void* src, mDNSu32 len) { - return std::memcmp(dst, src, len) == 0 ? mDNStrue : mDNSfalse; -} - -void mDNSPlatformMemZero(void* dst, mDNSu32 len) { - std::memset(dst, 0, len); -} - -void* mDNSPlatformMemAllocate(mDNSu32 len) { - return malloc(len); -} - -void mDNSPlatformMemFree(void* mem) { - free(mem); -} - -mDNSu32 mDNSPlatformRandomSeed() { - return std::chrono::steady_clock::now().time_since_epoch().count(); -} - -mStatus mDNSPlatformTimeInit() { - return mStatus_NoError; -} - -mDNSs32 mDNSPlatformRawTime() { - using openscreen::Clock; - - const Clock::time_point now = Clock::now(); - - // A signed 32-bit integer counting milliseconds only gives ~24.8 days of - // range. Thus, the first time this function is called, record a new origin - // timestamp to subtract from the raw monotonic clock values. The "one hour - // before now" value is used to keep the results well-ahead of zero because - // the mDNS library assumes this is the time since kernel boot and has hacks - // to disable certain things in the first few minutes. :-/ - static const Clock::time_point origin = now - hours(1); - - const int64_t millis_since_origin = - duration_cast<milliseconds>(now - origin).count(); - OSP_CHECK_LE(millis_since_origin, std::numeric_limits<mDNSs32>::max()); - return static_cast<mDNSs32>(millis_since_origin); -} - -mDNSs32 mDNSPlatformUTC() { - const auto seconds_since_epoch = - duration_cast<seconds>(openscreen::GetWallTimeSinceUnixEpoch()).count(); - - // The return type will cause overflow in early 2038. Warn future developers - // a year ahead of time. - constexpr mDNSs32 a_year_before_overflow = - std::numeric_limits<mDNSs32>::max() - - duration_cast<seconds>(365 * hours(24)).count(); - OSP_DCHECK_LE(seconds_since_epoch, a_year_before_overflow); - - return static_cast<mDNSs32>(seconds_since_epoch); -} - -void mDNSPlatformWriteDebugMsg(const char* msg) { - OSP_DVLOG << __func__ << ": " << msg; -} - -void mDNSPlatformWriteLogMsg(const char* ident, - const char* msg, - mDNSLogLevel_t loglevel) { - OSP_VLOG << __func__ << ": " << msg; -} - -TCPSocket* mDNSPlatformTCPSocket(mDNS* const m, - TCPSocketFlags flags, - mDNSIPPort* port) { - OSP_UNIMPLEMENTED(); - return nullptr; -} - -TCPSocket* mDNSPlatformTCPAccept(TCPSocketFlags flags, int sd) { - OSP_UNIMPLEMENTED(); - return nullptr; -} - -int mDNSPlatformTCPGetFD(TCPSocket* sock) { - OSP_UNIMPLEMENTED(); - return 0; -} - -mStatus mDNSPlatformTCPConnect(TCPSocket* sock, - const mDNSAddr* dst, - mDNSOpaque16 dstport, - domainname* hostname, - mDNSInterfaceID InterfaceID, - TCPConnectionCallback callback, - void* context) { - OSP_UNIMPLEMENTED(); - return mStatus_NoError; -} - -void mDNSPlatformTCPCloseConnection(TCPSocket* sock) { - OSP_UNIMPLEMENTED(); -} - -long mDNSPlatformReadTCP(TCPSocket* sock, // NOLINT - void* buf, - unsigned long buflen, // NOLINT - mDNSBool* closed) { - OSP_UNIMPLEMENTED(); - return 0; -} - -long mDNSPlatformWriteTCP(TCPSocket* sock, // NOLINT - const char* msg, - unsigned long len) { // NOLINT - OSP_UNIMPLEMENTED(); - return 0; -} - -UDPSocket* mDNSPlatformUDPSocket(mDNS* const m, - const mDNSIPPort requestedport) { - OSP_UNIMPLEMENTED(); - return nullptr; -} - -void mDNSPlatformUDPClose(UDPSocket* sock) { - OSP_UNIMPLEMENTED(); -} - -void mDNSPlatformReceiveBPF_fd(mDNS* const m, int fd) { - OSP_UNIMPLEMENTED(); -} - -void mDNSPlatformUpdateProxyList(mDNS* const m, - const mDNSInterfaceID InterfaceID) { - OSP_UNIMPLEMENTED(); -} - -void mDNSPlatformSendRawPacket(const void* const msg, - const mDNSu8* const end, - mDNSInterfaceID InterfaceID) { - OSP_UNIMPLEMENTED(); -} - -void mDNSPlatformSetLocalAddressCacheEntry(mDNS* const m, - const mDNSAddr* const tpa, - const mDNSEthAddr* const tha, - mDNSInterfaceID InterfaceID) {} - -void mDNSPlatformSourceAddrForDest(mDNSAddr* const src, - const mDNSAddr* const dst) {} - -mStatus mDNSPlatformTLSSetupCerts(void) { - OSP_UNIMPLEMENTED(); - return mStatus_NoError; -} - -void mDNSPlatformTLSTearDownCerts(void) { - OSP_UNIMPLEMENTED(); -} - -void mDNSPlatformSetDNSConfig(mDNS* const m, - mDNSBool setservers, - mDNSBool setsearch, - domainname* const fqdn, - DNameListElem** RegDomains, - DNameListElem** BrowseDomains) { - if (fqdn) { - std::memset(fqdn, 0, sizeof(*fqdn)); - } -} - -mStatus mDNSPlatformGetPrimaryInterface(mDNS* const m, - mDNSAddr* v4, - mDNSAddr* v6, - mDNSAddr* router) { - return mStatus_NoError; -} - -void mDNSPlatformDynDNSHostNameStatusChanged(const domainname* const dname, - const mStatus status) {} - -void mDNSPlatformSetAllowSleep(mDNS* const m, - mDNSBool allowSleep, - const char* reason) {} - -void mDNSPlatformSendWakeupPacket(mDNS* const m, - mDNSInterfaceID InterfaceID, - char* EthAddr, - char* IPAddr, - int iteration) { - OSP_UNIMPLEMENTED(); -} - -mDNSBool mDNSPlatformValidRecordForInterface(AuthRecord* rr, - const NetworkInterfaceInfo* intf) { - OSP_UNIMPLEMENTED(); - return mDNStrue; -} - -} // extern "C" diff --git a/osp/impl/discovery/mdns/mdns_responder_platform.h b/osp/impl/discovery/mdns/mdns_responder_platform.h deleted file mode 100644 index 342913fe..00000000 --- a/osp/impl/discovery/mdns/mdns_responder_platform.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_PLATFORM_H_ -#define OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_PLATFORM_H_ - -#include <vector> - -#include "platform/api/udp_socket.h" - -struct mDNS_PlatformSupport_struct { - std::vector<openscreen::UdpSocket*> sockets; -}; - -#endif // OSP_IMPL_DISCOVERY_MDNS_MDNS_RESPONDER_PLATFORM_H_ diff --git a/osp/impl/dns_sd_publisher_client.cc b/osp/impl/dns_sd_publisher_client.cc new file mode 100644 index 00000000..322bd4e0 --- /dev/null +++ b/osp/impl/dns_sd_publisher_client.cc @@ -0,0 +1,131 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "osp/impl/dns_sd_publisher_client.h" + +#include <utility> + +#include "discovery/common/config.h" +#include "discovery/dnssd/public/dns_sd_instance.h" +#include "discovery/dnssd/public/dns_sd_txt_record.h" +#include "discovery/public/dns_sd_service_factory.h" +#include "osp/public/service_info.h" +#include "platform/base/macros.h" +#include "util/osp_logging.h" + +namespace openscreen { +namespace osp { + +using State = ServicePublisher::State; + +namespace { + +constexpr char kFriendlyNameTxtKey[] = "fn"; +constexpr char kDnsSdDomainId[] = "local"; + +discovery::DnsSdInstance ServiceConfigToDnsSdInstance( + const ServicePublisher::Config& config) { + discovery::DnsSdTxtRecord txt; + const bool did_set_everything = + txt.SetValue(kFriendlyNameTxtKey, config.friendly_name).ok(); + OSP_DCHECK(did_set_everything); + + // NOTE: Not totally clear how we should be using config.hostname, which in + // principle is already part of config.service_instance_name. + return discovery::DnsSdInstance( + config.service_instance_name, kOpenScreenServiceName, kDnsSdDomainId, + std::move(txt), config.connection_server_port); +} + +} // namespace + +DnsSdPublisherClient::DnsSdPublisherClient(ServicePublisher::Observer* observer, + openscreen::TaskRunner* task_runner) + : observer_(observer), task_runner_(task_runner) { + OSP_DCHECK(observer_); + OSP_DCHECK(task_runner_); +} + +DnsSdPublisherClient::~DnsSdPublisherClient() = default; + +void DnsSdPublisherClient::StartPublisher( + const ServicePublisher::Config& config) { + OSP_LOG_INFO << "StartPublisher with " << config.network_interfaces.size() + << " interfaces"; + StartPublisherInternal(config); + Error result = dns_sd_publisher_->Register(config); + if (result.ok()) { + SetState(State::kRunning); + } else { + OnFatalError(result); + SetState(State::kStopped); + } +} + +void DnsSdPublisherClient::StartAndSuspendPublisher( + const ServicePublisher::Config& config) { + StartPublisherInternal(config); + SetState(State::kSuspended); +} + +void DnsSdPublisherClient::StopPublisher() { + dns_sd_publisher_.reset(); + SetState(State::kStopped); +} + +void DnsSdPublisherClient::SuspendPublisher() { + OSP_DCHECK(dns_sd_publisher_); + dns_sd_publisher_->DeregisterAll(); + SetState(State::kSuspended); +} + +void DnsSdPublisherClient::ResumePublisher( + const ServicePublisher::Config& config) { + OSP_DCHECK(dns_sd_publisher_); + dns_sd_publisher_->Register(config); + SetState(State::kRunning); +} + +void DnsSdPublisherClient::OnFatalError(Error error) { + observer_->OnError(error); +} + +void DnsSdPublisherClient::OnRecoverableError(Error error) { + observer_->OnError(error); +} + +void DnsSdPublisherClient::StartPublisherInternal( + const ServicePublisher::Config& config) { + OSP_DCHECK(!dns_sd_publisher_); + if (!dns_sd_service_) { + dns_sd_service_ = CreateDnsSdServiceInternal(config); + } + dns_sd_publisher_ = std::make_unique<OspDnsSdPublisher>( + dns_sd_service_.get(), kOpenScreenServiceName, + ServiceConfigToDnsSdInstance); +} + +SerialDeletePtr<discovery::DnsSdService> +DnsSdPublisherClient::CreateDnsSdServiceInternal( + const ServicePublisher::Config& config) { + // NOTE: With the current API, the client cannot customize the behavior of + // DNS-SD beyond the interface list. + openscreen::discovery::Config dns_sd_config; + dns_sd_config.enable_querying = false; + dns_sd_config.network_info = config.network_interfaces; + + // NOTE: + // It's desirable for the DNS-SD publisher and the DNS-SD listener for OSP to + // share the underlying mDNS socket and state, to avoid the agent from + // binding 2 sockets per network interface. + // + // This can be accomplished by having the agent use a shared instance of the + // discovery::DnsSdService, e.g. through a ref-counting handle, so that the + // OSP publisher and the OSP listener don't have to coordinate through an + // additional object. + return CreateDnsSdService(task_runner_, this, dns_sd_config); +} + +} // namespace osp +} // namespace openscreen diff --git a/osp/impl/dns_sd_publisher_client.h b/osp/impl/dns_sd_publisher_client.h new file mode 100644 index 00000000..9b055ed5 --- /dev/null +++ b/osp/impl/dns_sd_publisher_client.h @@ -0,0 +1,62 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef OSP_IMPL_DNS_SD_PUBLISHER_CLIENT_H_ +#define OSP_IMPL_DNS_SD_PUBLISHER_CLIENT_H_ + +#include <memory> + +#include "discovery/common/reporting_client.h" +#include "discovery/dnssd/public/dns_sd_service.h" +#include "discovery/public/dns_sd_service_publisher.h" +#include "osp/impl/service_publisher_impl.h" +#include "platform/api/serial_delete_ptr.h" + +namespace openscreen { + +class TaskRunner; + +namespace osp { + +class DnsSdPublisherClient final : public ServicePublisherImpl::Delegate, + openscreen::discovery::ReportingClient { + public: + DnsSdPublisherClient(ServicePublisher::Observer* observer, + openscreen::TaskRunner* task_runner); + ~DnsSdPublisherClient() override; + + // ServicePublisherImpl::Delegate overrides. + void StartPublisher(const ServicePublisher::Config& config) override; + void StartAndSuspendPublisher( + const ServicePublisher::Config& config) override; + void StopPublisher() override; + void SuspendPublisher() override; + void ResumePublisher(const ServicePublisher::Config& config) override; + + private: + DnsSdPublisherClient(const DnsSdPublisherClient&) = delete; + DnsSdPublisherClient(DnsSdPublisherClient&&) noexcept = delete; + + // openscreen::discovery::ReportingClient overrides. + void OnFatalError(Error) override; + void OnRecoverableError(Error) override; + + void StartPublisherInternal(const ServicePublisher::Config& config); + SerialDeletePtr<discovery::DnsSdService> CreateDnsSdServiceInternal( + const ServicePublisher::Config& config); + + ServicePublisher::Observer* const observer_; + TaskRunner* const task_runner_; + SerialDeletePtr<discovery::DnsSdService> dns_sd_service_; + + using OspDnsSdPublisher = + discovery::DnsSdServicePublisher<ServicePublisher::Config>; + + std::unique_ptr<OspDnsSdPublisher> dns_sd_publisher_; +}; + +} // namespace osp +} // namespace openscreen + +#endif // OSP_IMPL_DNS_SD_PUBLISHER_CLIENT_H_ diff --git a/osp/impl/dns_sd_service_publisher_factory.cc b/osp/impl/dns_sd_service_publisher_factory.cc new file mode 100644 index 00000000..5c63dc83 --- /dev/null +++ b/osp/impl/dns_sd_service_publisher_factory.cc @@ -0,0 +1,34 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <algorithm> +#include <memory> + +#include "discovery/dnssd/public/dns_sd_publisher.h" +#include "osp/impl/dns_sd_publisher_client.h" +#include "osp/impl/service_publisher_impl.h" +#include "osp/public/service_publisher.h" +#include "osp/public/service_publisher_factory.h" + +namespace openscreen { + +class TaskRunner; + +namespace osp { + +// static +std::unique_ptr<ServicePublisher> ServicePublisherFactory::Create( + const ServicePublisher::Config& config, + ServicePublisher::Observer* observer, + TaskRunner* task_runner) { + auto dns_sd_client = + std::make_unique<DnsSdPublisherClient>(observer, task_runner); + auto publisher_impl = std::make_unique<ServicePublisherImpl>( + observer, std::move(dns_sd_client)); + publisher_impl->SetConfig(config); + return publisher_impl; +} + +} // namespace osp +} // namespace openscreen diff --git a/osp/impl/internal_services.cc b/osp/impl/internal_services.cc deleted file mode 100644 index 19b55927..00000000 --- a/osp/impl/internal_services.cc +++ /dev/null @@ -1,229 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/internal_services.h" - -#include <algorithm> -#include <utility> - -#include "osp/impl/discovery/mdns/mdns_responder_adapter_impl.h" -#include "osp/impl/mdns_responder_service.h" -#include "platform/api/udp_socket.h" -#include "platform/base/error.h" -#include "util/osp_logging.h" - -namespace openscreen { -namespace osp { -namespace { - -constexpr char kServiceName[] = "_openscreen"; -constexpr char kServiceProtocol[] = "_udp"; -const IPAddress kMulticastAddress{224, 0, 0, 251}; -const IPAddress kMulticastIPv6Address{ - // ff02::fb - 0xff02, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x00fb, -}; -const uint16_t kMulticastListeningPort = 5353; - -class MdnsResponderAdapterImplFactory final - : public MdnsResponderAdapterFactory { - public: - MdnsResponderAdapterImplFactory() = default; - ~MdnsResponderAdapterImplFactory() override = default; - - std::unique_ptr<MdnsResponderAdapter> Create() override { - return std::make_unique<MdnsResponderAdapterImpl>(); - } -}; - -Error SetUpMulticastSocket(UdpSocket* socket, NetworkInterfaceIndex ifindex) { - const IPAddress broadcast_address = - socket->IsIPv6() ? kMulticastIPv6Address : kMulticastAddress; - - socket->JoinMulticastGroup(broadcast_address, ifindex); - socket->SetMulticastOutboundInterface(ifindex); - socket->Bind(); - - return Error::None(); -} - -// Ref-counted singleton instance of InternalServices. This lives only as long -// as there is at least one ServiceListener and/or ServicePublisher alive. -InternalServices* g_instance = nullptr; -int g_instance_ref_count = 0; - -} // namespace - -// static -std::unique_ptr<ServiceListener> InternalServices::CreateListener( - const MdnsServiceListenerConfig& config, - ServiceListener::Observer* observer, - TaskRunner* task_runner) { - auto* services = ReferenceSingleton(task_runner); - auto listener = - std::make_unique<ServiceListenerImpl>(&services->mdns_service_); - listener->AddObserver(observer); - listener->SetDestructionCallback(&InternalServices::DereferenceSingleton, - services); - return listener; -} - -// static -std::unique_ptr<ServicePublisher> InternalServices::CreatePublisher( - const ServicePublisher::Config& config, - ServicePublisher::Observer* observer, - TaskRunner* task_runner) { - auto* services = ReferenceSingleton(task_runner); - services->mdns_service_.SetServiceConfig( - config.hostname, config.service_instance_name, - config.connection_server_port, config.network_interface_indices, - {{"fn", config.friendly_name}}); - auto publisher = std::make_unique<ServicePublisherImpl>( - observer, &services->mdns_service_); - publisher->SetDestructionCallback(&InternalServices::DereferenceSingleton, - services); - return publisher; -} - -InternalServices::InternalPlatformLinkage::InternalPlatformLinkage( - InternalServices* parent) - : parent_(parent) {} - -InternalServices::InternalPlatformLinkage::~InternalPlatformLinkage() { - // If there are open sockets, then there will be dangling references to - // destroyed objects after destruction. - OSP_CHECK(open_sockets_.empty()); -} - -std::vector<MdnsPlatformService::BoundInterface> -InternalServices::InternalPlatformLinkage::RegisterInterfaces( - const std::vector<NetworkInterfaceIndex>& allowlist) { - const std::vector<InterfaceInfo> interfaces = GetNetworkInterfaces(); - const bool do_filter_using_allowlist = !allowlist.empty(); - std::vector<NetworkInterfaceIndex> index_list; - for (const auto& interface : interfaces) { - OSP_VLOG << "Found interface: " << interface; - if (do_filter_using_allowlist && - std::find(allowlist.begin(), allowlist.end(), interface.index) == - allowlist.end()) { - OSP_VLOG << "Ignoring interface not in allowed list: " << interface; - continue; - } - if (!interface.addresses.empty()) - index_list.push_back(interface.index); - } - OSP_LOG_IF(WARN, index_list.empty()) - << "No network interfaces had usable addresses for mDNS."; - - // Set up sockets to send and listen to mDNS multicast traffic on all - // interfaces. - std::vector<BoundInterface> result; - for (NetworkInterfaceIndex index : index_list) { - const auto& interface = *std::find_if( - interfaces.begin(), interfaces.end(), - [index](const InterfaceInfo& info) { return info.index == index; }); - if (interface.addresses.empty()) { - continue; - } - - // Pick any address for the given interface. - const IPSubnet& primary_subnet = interface.addresses.front(); - - auto create_result = - UdpSocket::Create(parent_->task_runner_, parent_, - IPEndpoint{{}, kMulticastListeningPort}); - if (!create_result) { - OSP_LOG_ERROR << "failed to create socket for interface " << index << ": " - << create_result.error().message(); - continue; - } - std::unique_ptr<UdpSocket> socket = std::move(create_result.value()); - if (!SetUpMulticastSocket(socket.get(), index).ok()) { - continue; - } - result.emplace_back(interface, primary_subnet, socket.get()); - parent_->RegisterMdnsSocket(socket.get()); - - open_sockets_.emplace_back(std::move(socket)); - } - - return result; -} - -void InternalServices::InternalPlatformLinkage::DeregisterInterfaces( - const std::vector<BoundInterface>& registered_interfaces) { - for (const auto& interface : registered_interfaces) { - UdpSocket* const socket = interface.socket; - parent_->DeregisterMdnsSocket(socket); - - const auto it = std::find_if(open_sockets_.begin(), open_sockets_.end(), - [socket](const std::unique_ptr<UdpSocket>& s) { - return s.get() == socket; - }); - OSP_DCHECK(it != open_sockets_.end()); - open_sockets_.erase(it); - } -} - -InternalServices::InternalServices(ClockNowFunctionPtr now_function, - TaskRunner* task_runner) - : mdns_service_(now_function, - task_runner, - kServiceName, - kServiceProtocol, - std::make_unique<MdnsResponderAdapterImplFactory>(), - std::make_unique<InternalPlatformLinkage>(this)), - task_runner_(task_runner) {} - -InternalServices::~InternalServices() = default; - -void InternalServices::RegisterMdnsSocket(UdpSocket* socket) { - OSP_CHECK(g_instance) << "No listener or publisher is alive."; - // TODO(rwkeane): Hook this up to the new mDNS library once we swap out the - // mDNSResponder. -} - -void InternalServices::DeregisterMdnsSocket(UdpSocket* socket) { - // TODO(rwkeane): Hook this up to the new mDNS library once we swap out the - // mDNSResponder. -} - -// static -InternalServices* InternalServices::ReferenceSingleton( - TaskRunner* task_runner) { - if (!g_instance) { - OSP_CHECK_EQ(g_instance_ref_count, 0); - g_instance = new InternalServices(&Clock::now, task_runner); - } - ++g_instance_ref_count; - return g_instance; -} - -// static -void InternalServices::DereferenceSingleton(void* instance) { - OSP_CHECK_EQ(static_cast<InternalServices*>(instance), g_instance); - OSP_CHECK_GT(g_instance_ref_count, 0); - --g_instance_ref_count; - if (g_instance_ref_count == 0) { - delete g_instance; - g_instance = nullptr; - } -} - -void InternalServices::OnError(UdpSocket* socket, Error error) { - OSP_LOG_ERROR << "failed to configure socket " << error.message(); - this->DeregisterMdnsSocket(socket); -} - -void InternalServices::OnSendError(UdpSocket* socket, Error error) { - // TODO(crbug.com/openscreen/67): Implement this method. - OSP_UNIMPLEMENTED(); -} - -void InternalServices::OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) { - g_instance->mdns_service_.OnRead(socket, std::move(packet)); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/internal_services.h b/osp/impl/internal_services.h deleted file mode 100644 index 042be4c1..00000000 --- a/osp/impl/internal_services.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef OSP_IMPL_INTERNAL_SERVICES_H_ -#define OSP_IMPL_INTERNAL_SERVICES_H_ - -#include <memory> -#include <vector> - -#include "osp/impl/mdns_platform_service.h" -#include "osp/impl/mdns_responder_service.h" -#include "osp/impl/quic/quic_connection_factory.h" -#include "osp/impl/service_listener_impl.h" -#include "osp/impl/service_publisher_impl.h" -#include "osp/public/mdns_service_listener_factory.h" -#include "osp/public/mdns_service_publisher_factory.h" -#include "osp/public/protocol_connection_client.h" -#include "osp/public/protocol_connection_server.h" -#include "platform/api/network_interface.h" -#include "platform/api/time.h" -#include "platform/api/udp_socket.h" -#include "platform/base/ip_address.h" -#include "platform/base/macros.h" - -namespace openscreen { - -class TaskRunner; - -namespace osp { - -// Factory for ServiceListener and ServicePublisher instances; owns internal -// objects needed to instantiate them such as MdnsResponderService and runs an -// event loop. -// TODO(btolsch): This may be renamed and/or split up once QUIC code lands and -// this use case is more concrete. -class InternalServices : UdpSocket::Client { - public: - static std::unique_ptr<ServiceListener> CreateListener( - const MdnsServiceListenerConfig& config, - ServiceListener::Observer* observer, - TaskRunner* task_runner); - static std::unique_ptr<ServicePublisher> CreatePublisher( - const ServicePublisher::Config& config, - ServicePublisher::Observer* observer, - TaskRunner* task_runner); - - // UdpSocket::Client overrides. - void OnError(UdpSocket* socket, Error error) override; - void OnSendError(UdpSocket* socket, Error error) override; - void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; - - private: - class InternalPlatformLinkage final : public MdnsPlatformService { - public: - explicit InternalPlatformLinkage(InternalServices* parent); - ~InternalPlatformLinkage() override; - - std::vector<BoundInterface> RegisterInterfaces( - const std::vector<NetworkInterfaceIndex>& allowlist) override; - void DeregisterInterfaces( - const std::vector<BoundInterface>& registered_interfaces) override; - - private: - InternalServices* const parent_; - std::vector<std::unique_ptr<UdpSocket>> open_sockets_; - }; - - // The TaskRunner provided here should live for the duration of this - // InternalService object's lifetime. - InternalServices(ClockNowFunctionPtr now_function, TaskRunner* task_runner); - ~InternalServices() override; - - void RegisterMdnsSocket(UdpSocket* socket); - void DeregisterMdnsSocket(UdpSocket* socket); - - static InternalServices* ReferenceSingleton(TaskRunner* task_runner); - static void DereferenceSingleton(void* instance); - - MdnsResponderService mdns_service_; - - TaskRunner* const task_runner_; - - OSP_DISALLOW_COPY_AND_ASSIGN(InternalServices); -}; - -} // namespace osp -} // namespace openscreen - -#endif // OSP_IMPL_INTERNAL_SERVICES_H_ diff --git a/osp/impl/mdns_platform_service.cc b/osp/impl/mdns_platform_service.cc deleted file mode 100644 index 4968c259..00000000 --- a/osp/impl/mdns_platform_service.cc +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/mdns_platform_service.h" - -#include <cstring> - -#include "util/osp_logging.h" - -namespace openscreen { -namespace osp { - -MdnsPlatformService::BoundInterface::BoundInterface( - const InterfaceInfo& interface_info, - const IPSubnet& subnet, - UdpSocket* socket) - : interface_info(interface_info), subnet(subnet), socket(socket) { - OSP_DCHECK(socket); -} - -MdnsPlatformService::BoundInterface::~BoundInterface() = default; - -bool MdnsPlatformService::BoundInterface::operator==( - const MdnsPlatformService::BoundInterface& other) const { - if (interface_info.index != other.interface_info.index) - return false; - - if (subnet.address != other.subnet.address || - subnet.prefix_length != other.subnet.prefix_length) { - return false; - } - - if (socket != other.socket) - return false; - - return true; -} - -bool MdnsPlatformService::BoundInterface::operator!=( - const MdnsPlatformService::BoundInterface& other) const { - return !(*this == other); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/mdns_platform_service.h b/osp/impl/mdns_platform_service.h deleted file mode 100644 index aca4ffd7..00000000 --- a/osp/impl/mdns_platform_service.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef OSP_IMPL_MDNS_PLATFORM_SERVICE_H_ -#define OSP_IMPL_MDNS_PLATFORM_SERVICE_H_ - -#include <vector> - -#include "platform/api/network_interface.h" -#include "platform/api/udp_socket.h" - -namespace openscreen { -namespace osp { - -class MdnsPlatformService { - public: - struct BoundInterface { - BoundInterface(const InterfaceInfo& interface_info, - const IPSubnet& subnet, - UdpSocket* socket); - ~BoundInterface(); - - bool operator==(const BoundInterface& other) const; - bool operator!=(const BoundInterface& other) const; - - InterfaceInfo interface_info; - IPSubnet subnet; - UdpSocket* socket; - }; - - virtual ~MdnsPlatformService() = default; - - virtual std::vector<BoundInterface> RegisterInterfaces( - const std::vector<NetworkInterfaceIndex>& allowlist) = 0; - virtual void DeregisterInterfaces( - const std::vector<BoundInterface>& registered_interfaces) = 0; -}; - -} // namespace osp -} // namespace openscreen - -#endif // OSP_IMPL_MDNS_PLATFORM_SERVICE_H_ diff --git a/osp/impl/mdns_responder_service.cc b/osp/impl/mdns_responder_service.cc deleted file mode 100644 index f9a80fae..00000000 --- a/osp/impl/mdns_responder_service.cc +++ /dev/null @@ -1,664 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/mdns_responder_service.h" - -#include <algorithm> -#include <memory> -#include <utility> - -#include "osp/impl/internal_services.h" -#include "platform/base/error.h" -#include "util/osp_logging.h" -#include "util/trace_logging.h" - -namespace openscreen { -namespace osp { -namespace { - -// TODO(btolsch): This should probably at least also contain network identity -// information. -std::string ServiceIdFromServiceInstanceName( - const DomainName& service_instance) { - std::string service_id; - service_id.assign( - reinterpret_cast<const char*>(service_instance.domain_name().data()), - service_instance.domain_name().size()); - return service_id; -} - -} // namespace - -MdnsResponderService::MdnsResponderService( - ClockNowFunctionPtr now_function, - TaskRunner* task_runner, - const std::string& service_name, - const std::string& service_protocol, - std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory, - std::unique_ptr<MdnsPlatformService> platform) - : service_type_{{service_name, service_protocol}}, - mdns_responder_factory_(std::move(mdns_responder_factory)), - platform_(std::move(platform)), - task_runner_(task_runner), - background_tasks_alarm_(now_function, task_runner) {} - -MdnsResponderService::~MdnsResponderService() = default; - -void MdnsResponderService::SetServiceConfig( - const std::string& hostname, - const std::string& instance, - uint16_t port, - const std::vector<NetworkInterfaceIndex> allowlist, - const std::map<std::string, std::string>& txt_data) { - OSP_DCHECK(!hostname.empty()); - OSP_DCHECK(!instance.empty()); - OSP_DCHECK_NE(0, port); - service_hostname_ = hostname; - service_instance_name_ = instance; - service_port_ = port; - interface_index_allowlist_ = allowlist; - service_txt_data_ = txt_data; -} - -void MdnsResponderService::OnRead(UdpSocket* socket, - ErrorOr<UdpPacket> packet) { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderService::OnRead"); - if (!mdns_responder_) { - return; - } - - mdns_responder_->OnRead(socket, std::move(packet)); - HandleMdnsEvents(); -} - -void MdnsResponderService::OnSendError(UdpSocket* socket, Error error) { - mdns_responder_->OnSendError(socket, std::move(error)); -} - -void MdnsResponderService::OnError(UdpSocket* socket, Error error) { - mdns_responder_->OnError(socket, std::move(error)); -} - -void MdnsResponderService::StartListener() { - task_runner_->PostTask([this]() { this->StartListenerInternal(); }); -} - -void MdnsResponderService::StartAndSuspendListener() { - task_runner_->PostTask([this]() { this->StartAndSuspendListenerInternal(); }); -} - -void MdnsResponderService::StopListener() { - task_runner_->PostTask([this]() { this->StopListenerInternal(); }); -} - -void MdnsResponderService::SuspendListener() { - task_runner_->PostTask([this]() { this->SuspendListenerInternal(); }); -} - -void MdnsResponderService::ResumeListener() { - task_runner_->PostTask([this]() { this->ResumeListenerInternal(); }); -} - -void MdnsResponderService::SearchNow(ServiceListener::State from) { - task_runner_->PostTask([this, from]() { this->SearchNowInternal(from); }); -} - -void MdnsResponderService::StartPublisher() { - task_runner_->PostTask([this]() { this->StartPublisherInternal(); }); -} - -void MdnsResponderService::StartAndSuspendPublisher() { - task_runner_->PostTask( - [this]() { this->StartAndSuspendPublisherInternal(); }); -} - -void MdnsResponderService::StopPublisher() { - task_runner_->PostTask([this]() { this->StopPublisherInternal(); }); -} - -void MdnsResponderService::SuspendPublisher() { - task_runner_->PostTask([this]() { this->SuspendPublisherInternal(); }); -} - -void MdnsResponderService::ResumePublisher() { - task_runner_->PostTask([this]() { this->ResumePublisherInternal(); }); -} - -void MdnsResponderService::StartListenerInternal() { - if (!mdns_responder_) { - mdns_responder_ = mdns_responder_factory_->Create(); - } - - StartListening(); - ServiceListenerImpl::Delegate::SetState(ServiceListener::State::kRunning); - RunBackgroundTasks(); -} - -void MdnsResponderService::StartAndSuspendListenerInternal() { - mdns_responder_ = mdns_responder_factory_->Create(); - ServiceListenerImpl::Delegate::SetState(ServiceListener::State::kSuspended); -} - -void MdnsResponderService::StopListenerInternal() { - StopListening(); - if (!publisher_ || publisher_->state() == ServicePublisher::State::kStopped || - publisher_->state() == ServicePublisher::State::kSuspended) { - StopMdnsResponder(); - if (!publisher_ || publisher_->state() == ServicePublisher::State::kStopped) - mdns_responder_.reset(); - } - ServiceListenerImpl::Delegate::SetState(ServiceListener::State::kStopped); -} - -void MdnsResponderService::SuspendListenerInternal() { - StopMdnsResponder(); - ServiceListenerImpl::Delegate::SetState(ServiceListener::State::kSuspended); -} - -void MdnsResponderService::ResumeListenerInternal() { - StartListening(); - ServiceListenerImpl::Delegate::SetState(ServiceListener::State::kRunning); -} - -void MdnsResponderService::SearchNowInternal(ServiceListener::State from) { - ServiceListenerImpl::Delegate::SetState(from); -} - -void MdnsResponderService::StartPublisherInternal() { - if (!mdns_responder_) { - mdns_responder_ = mdns_responder_factory_->Create(); - } - - StartService(); - ServicePublisherImpl::Delegate::SetState(ServicePublisher::State::kRunning); - RunBackgroundTasks(); -} - -void MdnsResponderService::StartAndSuspendPublisherInternal() { - mdns_responder_ = mdns_responder_factory_->Create(); - ServicePublisherImpl::Delegate::SetState(ServicePublisher::State::kSuspended); -} - -void MdnsResponderService::StopPublisherInternal() { - StopService(); - if (!listener_ || listener_->state() == ServiceListener::State::kStopped || - listener_->state() == ServiceListener::State::kSuspended) { - StopMdnsResponder(); - if (!listener_ || listener_->state() == ServiceListener::State::kStopped) - mdns_responder_.reset(); - } - ServicePublisherImpl::Delegate::SetState(ServicePublisher::State::kStopped); -} - -void MdnsResponderService::SuspendPublisherInternal() { - StopService(); - ServicePublisherImpl::Delegate::SetState(ServicePublisher::State::kSuspended); -} - -void MdnsResponderService::ResumePublisherInternal() { - StartService(); - ServicePublisherImpl::Delegate::SetState(ServicePublisher::State::kRunning); -} - -bool MdnsResponderService::NetworkScopedDomainNameComparator::operator()( - const NetworkScopedDomainName& a, - const NetworkScopedDomainName& b) const { - if (a.socket != b.socket) { - return (a.socket - b.socket) < 0; - } - return DomainNameComparator()(a.domain_name, b.domain_name); -} - -void MdnsResponderService::HandleMdnsEvents() { - TRACE_SCOPED(TraceCategory::kMdns, "MdnsResponderService::HandleMdnsEvents"); - // NOTE: In the common case, we will get a single combined packet for - // PTR/SRV/TXT/A and then no other packets. If we don't loop here, we would - // start SRV/TXT queries based on the PTR response, but never check for events - // again. This should no longer be a problem when we have correct scheduling - // of RunTasks. - bool events_possible = false; - // NOTE: This set will track which service instances were changed by all the - // events throughout all the loop iterations. At the end, we can dispatch our - // ServiceInfo updates to |listener_| just once (e.g. instead of - // OnReceiverChanged, OnReceiverChanged, ..., just a single - // OnReceiverChanged). - InstanceNameSet modified_instance_names; - do { - events_possible = false; - for (auto& ptr_event : mdns_responder_->TakePtrResponses()) { - events_possible = HandlePtrEvent(ptr_event, &modified_instance_names) || - events_possible; - } - for (auto& srv_event : mdns_responder_->TakeSrvResponses()) { - events_possible = HandleSrvEvent(srv_event, &modified_instance_names) || - events_possible; - } - for (auto& txt_event : mdns_responder_->TakeTxtResponses()) { - events_possible = HandleTxtEvent(txt_event, &modified_instance_names) || - events_possible; - } - for (const auto& a_event : mdns_responder_->TakeAResponses()) { - events_possible = - HandleAEvent(a_event, &modified_instance_names) || events_possible; - } - for (const auto& aaaa_event : mdns_responder_->TakeAaaaResponses()) { - events_possible = HandleAaaaEvent(aaaa_event, &modified_instance_names) || - events_possible; - } - if (events_possible) { - // NOTE: This still needs to be called here, even though it runs in the - // background regularly, because we just finished processing MDNS events. - RunBackgroundTasks(); - } - } while (events_possible); - - for (const auto& instance_name : modified_instance_names) { - auto service_entry = service_by_name_.find(instance_name); - std::unique_ptr<ServiceInstance>& service = service_entry->second; - - std::string service_id = ServiceIdFromServiceInstanceName(instance_name); - auto receiver_info_entry = receiver_info_.find(service_id); - HostInfo* host = GetHostInfo(service->ptr_socket, service->domain_name); - if (!IsServiceReady(*service, host)) { - if (receiver_info_entry != receiver_info_.end()) { - const ServiceInfo& receiver_info = receiver_info_entry->second; - listener_->OnReceiverRemoved(receiver_info); - receiver_info_.erase(receiver_info_entry); - } - if (!service->has_ptr_record && !service->has_srv()) - service_by_name_.erase(service_entry); - continue; - } - - // TODO(btolsch): Verify UTF-8 here. - std::string friendly_name(instance_name.GetLabels()[0]); - - if (receiver_info_entry == receiver_info_.end()) { - ServiceInfo receiver_info{ - std::move(service_id), - std::move(friendly_name), - GetNetworkInterfaceIndexFromSocket(service->ptr_socket), - {host->v4_address, service->port}, - {host->v6_address, service->port}}; - listener_->OnReceiverAdded(receiver_info); - receiver_info_.emplace(receiver_info.service_id, - std::move(receiver_info)); - } else { - ServiceInfo& receiver_info = receiver_info_entry->second; - if (receiver_info.Update( - std::move(friendly_name), - GetNetworkInterfaceIndexFromSocket(service->ptr_socket), - {host->v4_address, service->port}, - {host->v6_address, service->port})) { - listener_->OnReceiverChanged(receiver_info); - } - } - } -} - -void MdnsResponderService::StartListening() { - // TODO(btolsch): This needs the same |interface_index_allowlist_| logic as - // StartService, but this can also wait until the network-change TODO is - // addressed. - if (bound_interfaces_.empty()) { - mdns_responder_->Init(); - bound_interfaces_ = platform_->RegisterInterfaces({}); - for (auto& interface : bound_interfaces_) { - mdns_responder_->RegisterInterface(interface.interface_info, - interface.subnet, interface.socket); - } - } - ErrorOr<DomainName> service_type = - DomainName::FromLabels(service_type_.begin(), service_type_.end()); - OSP_CHECK(service_type); - for (const auto& interface : bound_interfaces_) { - mdns_responder_->StartPtrQuery(interface.socket, service_type.value()); - } -} - -void MdnsResponderService::StopListening() { - ErrorOr<DomainName> service_type = - DomainName::FromLabels(service_type_.begin(), service_type_.end()); - OSP_CHECK(service_type); - for (const auto& kv : network_scoped_domain_to_host_) { - const NetworkScopedDomainName& scoped_domain = kv.first; - - mdns_responder_->StopAQuery(scoped_domain.socket, - scoped_domain.domain_name); - mdns_responder_->StopAaaaQuery(scoped_domain.socket, - scoped_domain.domain_name); - } - network_scoped_domain_to_host_.clear(); - for (const auto& service : service_by_name_) { - UdpSocket* const socket = service.second->ptr_socket; - mdns_responder_->StopSrvQuery(socket, service.first); - mdns_responder_->StopTxtQuery(socket, service.first); - } - service_by_name_.clear(); - for (const auto& interface : bound_interfaces_) { - mdns_responder_->StopPtrQuery(interface.socket, service_type.value()); - } - RemoveAllReceivers(); -} - -void MdnsResponderService::StartService() { - // TODO(crbug.com/openscreen/45): This should really be a library-wide - // allowed list. - if (!bound_interfaces_.empty() && !interface_index_allowlist_.empty()) { - // TODO(btolsch): New interfaces won't be picked up on this path, but this - // also highlights a larger issue of the interface list being frozen while - // no state transitions are being made. There should be another interface - // on MdnsPlatformService for getting network interface updates. - std::vector<MdnsPlatformService::BoundInterface> deregistered_interfaces; - for (auto it = bound_interfaces_.begin(); it != bound_interfaces_.end();) { - if (std::find(interface_index_allowlist_.begin(), - interface_index_allowlist_.end(), - it->interface_info.index) == - interface_index_allowlist_.end()) { - mdns_responder_->DeregisterInterface(it->socket); - deregistered_interfaces.push_back(*it); - it = bound_interfaces_.erase(it); - } else { - ++it; - } - } - platform_->DeregisterInterfaces(deregistered_interfaces); - } else if (bound_interfaces_.empty()) { - mdns_responder_->Init(); - mdns_responder_->SetHostLabel(service_hostname_); - bound_interfaces_ = - platform_->RegisterInterfaces(interface_index_allowlist_); - for (auto& interface : bound_interfaces_) { - mdns_responder_->RegisterInterface(interface.interface_info, - interface.subnet, interface.socket); - } - } - - ErrorOr<DomainName> domain_name = - DomainName::FromLabels(&service_hostname_, &service_hostname_ + 1); - OSP_CHECK(domain_name) << "bad hostname configured: " << service_hostname_; - DomainName name = std::move(domain_name.value()); - - Error error = name.Append(DomainName::GetLocalDomain()); - OSP_CHECK(error.ok()); - - mdns_responder_->RegisterService(service_instance_name_, service_type_[0], - service_type_[1], name, service_port_, - service_txt_data_); -} - -void MdnsResponderService::StopService() { - mdns_responder_->DeregisterService(service_instance_name_, service_type_[0], - service_type_[1]); -} - -void MdnsResponderService::StopMdnsResponder() { - mdns_responder_->Close(); - platform_->DeregisterInterfaces(bound_interfaces_); - bound_interfaces_.clear(); - network_scoped_domain_to_host_.clear(); - service_by_name_.clear(); - RemoveAllReceivers(); -} - -void MdnsResponderService::UpdatePendingServiceInfoSet( - InstanceNameSet* modified_instance_names, - const DomainName& domain_name) { - for (auto& entry : service_by_name_) { - const auto& instance_name = entry.first; - const auto& instance = entry.second; - if (instance->domain_name == domain_name) { - modified_instance_names->emplace(instance_name); - } - } -} - -void MdnsResponderService::RemoveAllReceivers() { - bool had_receivers = !receiver_info_.empty(); - receiver_info_.clear(); - if (had_receivers) - listener_->OnAllReceiversRemoved(); -} - -bool MdnsResponderService::HandlePtrEvent( - const PtrEvent& ptr_event, - InstanceNameSet* modified_instance_names) { - bool events_possible = false; - const auto& instance_name = ptr_event.service_instance; - UdpSocket* const socket = ptr_event.header.socket; - auto entry = service_by_name_.find(ptr_event.service_instance); - switch (ptr_event.header.response_type) { - case QueryEventHeader::Type::kAddedNoCache: - break; - case QueryEventHeader::Type::kAdded: { - if (entry != service_by_name_.end()) { - entry->second->has_ptr_record = true; - modified_instance_names->emplace(instance_name); - break; - } - mdns_responder_->StartSrvQuery(socket, instance_name); - mdns_responder_->StartTxtQuery(socket, instance_name); - events_possible = true; - - auto new_instance = std::make_unique<ServiceInstance>(); - new_instance->ptr_socket = socket; - new_instance->has_ptr_record = true; - modified_instance_names->emplace(instance_name); - service_by_name_.emplace(std::move(instance_name), - std::move(new_instance)); - } break; - case QueryEventHeader::Type::kRemoved: - if (entry == service_by_name_.end()) - break; - if (entry->second->ptr_socket != socket) - break; - entry->second->has_ptr_record = false; - // NOTE: Occasionally, we can observe this situation in the wild where the - // PTR for a service is removed and then immediately re-added (like an odd - // refresh). Additionally, the recommended TTL of PTR records is much - // shorter than the other records. This means that short network drops or - // latency spikes could cause the PTR refresh queries and/or responses to - // be lost so the record isn't quite refreshed in time. The solution here - // and in HandleSrvEvent is to only remove the service records completely - // when both the PTR and SRV have been removed. - if (!entry->second->has_srv()) { - mdns_responder_->StopSrvQuery(socket, instance_name); - mdns_responder_->StopTxtQuery(socket, instance_name); - } - modified_instance_names->emplace(std::move(instance_name)); - break; - } - return events_possible; -} - -bool MdnsResponderService::HandleSrvEvent( - const SrvEvent& srv_event, - InstanceNameSet* modified_instance_names) { - bool events_possible = false; - auto& domain_name = srv_event.domain_name; - const auto& instance_name = srv_event.service_instance; - UdpSocket* const socket = srv_event.header.socket; - auto entry = service_by_name_.find(srv_event.service_instance); - if (entry == service_by_name_.end()) - return events_possible; - switch (srv_event.header.response_type) { - case QueryEventHeader::Type::kAddedNoCache: - break; - case QueryEventHeader::Type::kAdded: { - NetworkScopedDomainName scoped_domain_name{socket, domain_name}; - auto host_entry = network_scoped_domain_to_host_.find(scoped_domain_name); - if (host_entry == network_scoped_domain_to_host_.end()) { - mdns_responder_->StartAQuery(socket, domain_name); - mdns_responder_->StartAaaaQuery(socket, domain_name); - events_possible = true; - auto result = network_scoped_domain_to_host_.emplace( - std::move(scoped_domain_name), HostInfo{}); - host_entry = result.first; - } - auto& dependent_services = host_entry->second.services; - if (std::find_if(dependent_services.begin(), dependent_services.end(), - [entry](ServiceInstance* instance) { - return instance == entry->second.get(); - }) == dependent_services.end()) { - dependent_services.push_back(entry->second.get()); - } - entry->second->domain_name = std::move(domain_name); - entry->second->port = srv_event.port; - modified_instance_names->emplace(std::move(instance_name)); - } break; - case QueryEventHeader::Type::kRemoved: { - NetworkScopedDomainName scoped_domain_name{socket, domain_name}; - auto host_entry = network_scoped_domain_to_host_.find(scoped_domain_name); - if (host_entry != network_scoped_domain_to_host_.end()) { - auto& dependent_services = host_entry->second.services; - dependent_services.erase( - std::remove_if(dependent_services.begin(), dependent_services.end(), - [entry](ServiceInstance* instance) { - return instance == entry->second.get(); - }), - dependent_services.end()); - if (dependent_services.empty()) { - mdns_responder_->StopAQuery(socket, domain_name); - mdns_responder_->StopAaaaQuery(socket, domain_name); - network_scoped_domain_to_host_.erase(host_entry); - } - } - entry->second->domain_name = DomainName(); - entry->second->port = 0; - if (!entry->second->has_ptr_record) { - mdns_responder_->StopSrvQuery(socket, instance_name); - mdns_responder_->StopTxtQuery(socket, instance_name); - } - modified_instance_names->emplace(std::move(instance_name)); - } break; - } - return events_possible; -} - -bool MdnsResponderService::HandleTxtEvent( - const TxtEvent& txt_event, - InstanceNameSet* modified_instance_names) { - bool events_possible = false; - const auto& instance_name = txt_event.service_instance; - auto entry = service_by_name_.find(instance_name); - if (entry == service_by_name_.end()) - return events_possible; - switch (txt_event.header.response_type) { - case QueryEventHeader::Type::kAddedNoCache: - break; - case QueryEventHeader::Type::kAdded: - modified_instance_names->emplace(instance_name); - if (entry == service_by_name_.end()) { - auto result = service_by_name_.emplace( - std::move(instance_name), std::make_unique<ServiceInstance>()); - entry = result.first; - } - entry->second->txt_info = std::move(txt_event.txt_info); - break; - case QueryEventHeader::Type::kRemoved: - entry->second->txt_info.clear(); - modified_instance_names->emplace(std::move(instance_name)); - break; - } - return events_possible; -} - -bool MdnsResponderService::HandleAddressEvent( - UdpSocket* socket, - QueryEventHeader::Type response_type, - const DomainName& domain_name, - bool a_event, - const IPAddress& address, - InstanceNameSet* modified_instance_names) { - bool events_possible = false; - switch (response_type) { - case QueryEventHeader::Type::kAddedNoCache: - break; - case QueryEventHeader::Type::kAdded: { - HostInfo* host = AddOrGetHostInfo(socket, domain_name); - if (a_event) - host->v4_address = address; - else - host->v6_address = address; - UpdatePendingServiceInfoSet(modified_instance_names, domain_name); - } break; - case QueryEventHeader::Type::kRemoved: { - HostInfo* host = GetHostInfo(socket, domain_name); - - if (a_event) - host->v4_address = IPAddress(); - else - host->v6_address = IPAddress(); - - if (host->v4_address || host->v6_address) - UpdatePendingServiceInfoSet(modified_instance_names, domain_name); - } break; - } - return events_possible; -} - -bool MdnsResponderService::HandleAEvent( - const AEvent& a_event, - InstanceNameSet* modified_instance_names) { - return HandleAddressEvent(a_event.header.socket, a_event.header.response_type, - a_event.domain_name, true, a_event.address, - modified_instance_names); -} - -bool MdnsResponderService::HandleAaaaEvent( - const AaaaEvent& aaaa_event, - InstanceNameSet* modified_instance_names) { - return HandleAddressEvent(aaaa_event.header.socket, - aaaa_event.header.response_type, - aaaa_event.domain_name, false, aaaa_event.address, - modified_instance_names); -} - -MdnsResponderService::HostInfo* MdnsResponderService::AddOrGetHostInfo( - UdpSocket* socket, - const DomainName& domain_name) { - return &network_scoped_domain_to_host_[NetworkScopedDomainName{socket, - domain_name}]; -} - -MdnsResponderService::HostInfo* MdnsResponderService::GetHostInfo( - UdpSocket* socket, - const DomainName& domain_name) { - auto kv = network_scoped_domain_to_host_.find( - NetworkScopedDomainName{socket, domain_name}); - if (kv == network_scoped_domain_to_host_.end()) - return nullptr; - - return &kv->second; -} - -bool MdnsResponderService::IsServiceReady(const ServiceInstance& instance, - HostInfo* host) const { - return (host && instance.has_ptr_record && instance.has_srv() && - !instance.txt_info.empty() && (host->v4_address || host->v6_address)); -} - -NetworkInterfaceIndex MdnsResponderService::GetNetworkInterfaceIndexFromSocket( - const UdpSocket* socket) const { - auto it = std::find_if( - bound_interfaces_.begin(), bound_interfaces_.end(), - [socket](const MdnsPlatformService::BoundInterface& interface) { - return interface.socket == socket; - }); - if (it == bound_interfaces_.end()) - return kInvalidNetworkInterfaceIndex; - return it->interface_info.index; -} - -void MdnsResponderService::RunBackgroundTasks() { - if (!mdns_responder_) { - return; - } - const auto delay_until_next_run = mdns_responder_->RunTasks(); - background_tasks_alarm_.ScheduleFromNow([this] { RunBackgroundTasks(); }, - delay_until_next_run); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/mdns_responder_service.h b/osp/impl/mdns_responder_service.h deleted file mode 100644 index ddcd0dbd..00000000 --- a/osp/impl/mdns_responder_service.h +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef OSP_IMPL_MDNS_RESPONDER_SERVICE_H_ -#define OSP_IMPL_MDNS_RESPONDER_SERVICE_H_ - -#include <array> -#include <map> -#include <memory> -#include <set> -#include <string> -#include <vector> - -#include "osp/impl/discovery/mdns/mdns_responder_adapter.h" -#include "osp/impl/mdns_platform_service.h" -#include "osp/impl/service_listener_impl.h" -#include "osp/impl/service_publisher_impl.h" -#include "platform/api/network_interface.h" -#include "platform/api/task_runner.h" -#include "platform/api/time.h" -#include "platform/base/ip_address.h" -#include "util/alarm.h" - -namespace openscreen { -namespace osp { - -class MdnsResponderAdapterFactory { - public: - virtual ~MdnsResponderAdapterFactory() = default; - - virtual std::unique_ptr<MdnsResponderAdapter> Create() = 0; -}; - -class MdnsResponderService : public ServiceListenerImpl::Delegate, - public ServicePublisherImpl::Delegate, - public UdpSocket::Client { - public: - MdnsResponderService( - ClockNowFunctionPtr now_function, - TaskRunner* task_runner, - const std::string& service_name, - const std::string& service_protocol, - std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory, - std::unique_ptr<MdnsPlatformService> platform); - ~MdnsResponderService() override; - - void SetServiceConfig(const std::string& hostname, - const std::string& instance, - uint16_t port, - const std::vector<NetworkInterfaceIndex> allowlist, - const std::map<std::string, std::string>& txt_data); - - // UdpSocket::Client overrides. - void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; - void OnSendError(UdpSocket* socket, Error error) override; - void OnError(UdpSocket* socket, Error error) override; - - // ServiceListenerImpl::Delegate overrides. - void StartListener() override; - void StartAndSuspendListener() override; - void StopListener() override; - void SuspendListener() override; - void ResumeListener() override; - void SearchNow(ServiceListener::State from) override; - - // ServicePublisherImpl::Delegate overrides. - void StartPublisher() override; - void StartAndSuspendPublisher() override; - void StopPublisher() override; - void SuspendPublisher() override; - void ResumePublisher() override; - - protected: - void HandleMdnsEvents(); - - std::unique_ptr<MdnsResponderAdapter> mdns_responder_; - - private: - // Create internal versions of all public methods. These are used to push all - // calls to these methods to the task runner. - // TODO(rwkeane): Clean up these methods. Some result in multiple pushes to - // the task runner when just one would suffice. - // ServiceListenerImpl::Delegate overrides. - void StartListenerInternal(); - void StartAndSuspendListenerInternal(); - void StopListenerInternal(); - void SuspendListenerInternal(); - void ResumeListenerInternal(); - void SearchNowInternal(ServiceListener::State from); - void StartPublisherInternal(); - void StartAndSuspendPublisherInternal(); - void StopPublisherInternal(); - void SuspendPublisherInternal(); - void ResumePublisherInternal(); - - // NOTE: service_instance implicit in map key. - struct ServiceInstance { - UdpSocket* ptr_socket = nullptr; - DomainName domain_name; - uint16_t port = 0; - bool has_ptr_record = false; - std::vector<std::string> txt_info; - - // |port| == 0 signals that we have no SRV record. - bool has_srv() const { return port != 0; } - }; - - // NOTE: hostname implicit in map key. - struct HostInfo { - std::vector<ServiceInstance*> services; - IPAddress v4_address; - IPAddress v6_address; - }; - - struct NetworkScopedDomainName { - UdpSocket* socket; - DomainName domain_name; - }; - - struct NetworkScopedDomainNameComparator { - bool operator()(const NetworkScopedDomainName& a, - const NetworkScopedDomainName& b) const; - }; - - using InstanceNameSet = std::set<DomainName, DomainNameComparator>; - - void StartListening(); - void StopListening(); - void StartService(); - void StopService(); - void StopMdnsResponder(); - void UpdatePendingServiceInfoSet(InstanceNameSet* modified_instance_names, - const DomainName& domain_name); - void RemoveAllReceivers(); - - // NOTE: |modified_instance_names| is used to track which service instances - // are modified by the record events. See HandleMdnsEvents for more details. - bool HandlePtrEvent(const PtrEvent& ptr_event, - InstanceNameSet* modified_instance_names); - bool HandleSrvEvent(const SrvEvent& srv_event, - InstanceNameSet* modified_instance_names); - bool HandleTxtEvent(const TxtEvent& txt_event, - InstanceNameSet* modified_instance_names); - bool HandleAddressEvent(UdpSocket* socket, - QueryEventHeader::Type response_type, - const DomainName& domain_name, - bool a_event, - const IPAddress& address, - InstanceNameSet* modified_instance_names); - bool HandleAEvent(const AEvent& a_event, - InstanceNameSet* modified_instance_names); - bool HandleAaaaEvent(const AaaaEvent& aaaa_event, - InstanceNameSet* modified_instance_names); - - HostInfo* AddOrGetHostInfo(UdpSocket* socket, const DomainName& domain_name); - HostInfo* GetHostInfo(UdpSocket* socket, const DomainName& domain_name); - bool IsServiceReady(const ServiceInstance& instance, HostInfo* host) const; - NetworkInterfaceIndex GetNetworkInterfaceIndexFromSocket( - const UdpSocket* socket) const; - - // Runs background tasks to manage the internal mDNS state. - void RunBackgroundTasks(); - - // Service type separated as service name and service protocol for both - // listening and publishing (e.g. {"_openscreen", "_udp"}). - std::array<std::string, 2> service_type_; - - // The following variables all relate to what MdnsResponderService publishes, - // if anything. - std::string service_hostname_; - std::string service_instance_name_; - uint16_t service_port_; - std::vector<NetworkInterfaceIndex> interface_index_allowlist_; - std::map<std::string, std::string> service_txt_data_; - - std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory_; - std::unique_ptr<MdnsPlatformService> platform_; - std::vector<MdnsPlatformService::BoundInterface> bound_interfaces_; - - // A map of service information collected from PTR, SRV, and TXT records. It - // is keyed by service instance names. - std::map<DomainName, std::unique_ptr<ServiceInstance>, DomainNameComparator> - service_by_name_; - - // The map key is a combination of the interface to which the address records - // belong and the hostname of the address records. The values are IPAddresses - // for the given hostname on the given network and pointers to dependent - // service instances. The service instance pointers act as a reference count - // to keep the A/AAAA queries alive, when more than one service refers to the - // same hostname. This is not currently used by openscreen, but is used by - // Cast, so may be supported in openscreen in the future. - std::map<NetworkScopedDomainName, HostInfo, NetworkScopedDomainNameComparator> - network_scoped_domain_to_host_; - - std::map<std::string, ServiceInfo> receiver_info_; - - TaskRunner* const task_runner_; - - // Scheduled to run periodic background tasks. - Alarm background_tasks_alarm_; - - friend class TestingMdnsResponderService; -}; - -} // namespace osp -} // namespace openscreen - -#endif // OSP_IMPL_MDNS_RESPONDER_SERVICE_H_ diff --git a/osp/impl/mdns_responder_service_unittest.cc b/osp/impl/mdns_responder_service_unittest.cc deleted file mode 100644 index 1d542a3e..00000000 --- a/osp/impl/mdns_responder_service_unittest.cc +++ /dev/null @@ -1,884 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/mdns_responder_service.h" - -#include <cstdint> -#include <iostream> -#include <memory> -#include <utility> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "osp/impl/service_listener_impl.h" -#include "osp/impl/testing/fake_mdns_platform_service.h" -#include "osp/impl/testing/fake_mdns_responder_adapter.h" -#include "platform/test/fake_task_runner.h" - -namespace openscreen { -namespace osp { - -// Child of the MdnsResponderService for testing purposes. Only difference -// betweeen this and the base class is that methods on this class are executed -// synchronously, rather than pushed to the task runner for later execution. -class TestingMdnsResponderService final : public MdnsResponderService { - public: - TestingMdnsResponderService( - FakeTaskRunner* task_runner, - const std::string& service_name, - const std::string& service_protocol, - std::unique_ptr<MdnsResponderAdapterFactory> mdns_responder_factory, - std::unique_ptr<MdnsPlatformService> platform_service) - : MdnsResponderService(&FakeClock::now, - task_runner, - service_name, - service_protocol, - std::move(mdns_responder_factory), - std::move(platform_service)) {} - ~TestingMdnsResponderService() = default; - - // Override the default ServiceListenerImpl and ServicePublisherImpl - // implementations. These call the internal implementations of each of the - // methods provided, meaning that the end result of the call is the same, but - // without pushing to the task runner and waiting for it to be pulled off - // again. - // ServiceListenerImpl::Delegate overrides. - void StartListener() override { StartListenerInternal(); } - void StartAndSuspendListener() override { StartAndSuspendListenerInternal(); } - void StopListener() override { StopListenerInternal(); } - void SuspendListener() override { SuspendListenerInternal(); } - void ResumeListener() override { ResumeListenerInternal(); } - void SearchNow(ServiceListener::State from) override { - SearchNowInternal(from); - } - - // ServicePublisherImpl::Delegate overrides. - void StartPublisher() override { StartPublisherInternal(); } - void StartAndSuspendPublisher() override { - StartAndSuspendPublisherInternal(); - } - void StopPublisher() override { StopPublisherInternal(); } - void SuspendPublisher() override { SuspendPublisherInternal(); } - void ResumePublisher() override { ResumePublisherInternal(); } - - // Handles new events as OnRead does, but without the need of a TaskRunner. - void HandleNewEvents() { - if (!mdns_responder_) { - return; - } - - mdns_responder_->RunTasks(); - HandleMdnsEvents(); - } -}; - -class FakeMdnsResponderAdapterFactory final - : public MdnsResponderAdapterFactory, - public FakeMdnsResponderAdapter::LifetimeObserver { - public: - ~FakeMdnsResponderAdapterFactory() override = default; - - std::unique_ptr<MdnsResponderAdapter> Create() override { - auto mdns = std::make_unique<FakeMdnsResponderAdapter>(); - mdns->SetLifetimeObserver(this); - last_mdns_responder_ = mdns.get(); - ++instances_; - return mdns; - } - - void OnDestroyed() override { - last_running_ = last_mdns_responder_->running(); - last_registered_services_size_ = - last_mdns_responder_->registered_services().size(); - last_mdns_responder_ = nullptr; - } - - FakeMdnsResponderAdapter* last_mdns_responder() { - return last_mdns_responder_; - } - - int32_t instances() const { return instances_; } - bool last_running() const { return last_running_; } - size_t last_registered_services_size() const { - return last_registered_services_size_; - } - - private: - FakeMdnsResponderAdapter* last_mdns_responder_ = nullptr; - int32_t instances_ = 0; - bool last_running_ = false; - size_t last_registered_services_size_ = 0; -}; - -namespace { - -using ::testing::_; - -constexpr char kTestServiceInstance[] = "turtle"; -constexpr char kTestServiceName[] = "_foo"; -constexpr char kTestServiceProtocol[] = "_udp"; -constexpr char kTestHostname[] = "hostname"; -constexpr uint16_t kTestPort = 12345; - -// Wrapper around the above class. In MdnsResponderServiceTest, we need to both -// pass a unique_ptr to the created MdnsResponderService and to maintain a -// local pointer as well. Doing this with the same object causes a race -// condition, where ~FakeMdnsResponderAdapter() calls observer_->OnDestroyed() -// after the object is already deleted, resulting in a seg fault. This is to -// prevent that race condition. -class WrapperMdnsResponderAdapterFactory final - : public MdnsResponderAdapterFactory, - public FakeMdnsResponderAdapter::LifetimeObserver { - public: - explicit WrapperMdnsResponderAdapterFactory( - FakeMdnsResponderAdapterFactory* ptr) - : other_(ptr) {} - - std::unique_ptr<MdnsResponderAdapter> Create() override { - return other_->Create(); - } - - void OnDestroyed() override { other_->OnDestroyed(); } - - private: - FakeMdnsResponderAdapterFactory* other_; -}; - -class MockServiceListenerObserver final : public ServiceListener::Observer { - public: - ~MockServiceListenerObserver() override = default; - - MOCK_METHOD0(OnStarted, void()); - MOCK_METHOD0(OnStopped, void()); - MOCK_METHOD0(OnSuspended, void()); - MOCK_METHOD0(OnSearching, void()); - - MOCK_METHOD1(OnReceiverAdded, void(const ServiceInfo&)); - MOCK_METHOD1(OnReceiverChanged, void(const ServiceInfo&)); - MOCK_METHOD1(OnReceiverRemoved, void(const ServiceInfo&)); - MOCK_METHOD0(OnAllReceiversRemoved, void()); - - MOCK_METHOD1(OnError, void(ServiceListenerError)); - MOCK_METHOD1(OnMetrics, void(ServiceListener::Metrics)); -}; - -class MockServicePublisherObserver final : public ServicePublisher::Observer { - public: - ~MockServicePublisherObserver() override = default; - - MOCK_METHOD0(OnStarted, void()); - MOCK_METHOD0(OnStopped, void()); - MOCK_METHOD0(OnSuspended, void()); - MOCK_METHOD1(OnError, void(ServicePublisherError)); - MOCK_METHOD1(OnMetrics, void(ServicePublisher::Metrics)); -}; - -UdpSocket* const kDefaultSocket = - reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(16)); -UdpSocket* const kSecondSocket = - reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(24)); - -class MdnsResponderServiceTest : public ::testing::Test { - protected: - void SetUp() override { - mdns_responder_factory_ = - std::make_unique<FakeMdnsResponderAdapterFactory>(); - auto wrapper_factory = std::make_unique<WrapperMdnsResponderAdapterFactory>( - mdns_responder_factory_.get()); - clock_ = std::make_unique<FakeClock>(Clock::now()); - task_runner_ = std::make_unique<FakeTaskRunner>(clock_.get()); - auto platform_service = std::make_unique<FakeMdnsPlatformService>(); - fake_platform_service_ = platform_service.get(); - fake_platform_service_->set_interfaces(bound_interfaces_); - mdns_service_ = std::make_unique<TestingMdnsResponderService>( - task_runner_.get(), kTestServiceName, kTestServiceProtocol, - std::move(wrapper_factory), std::move(platform_service)); - service_listener_ = - std::make_unique<ServiceListenerImpl>(mdns_service_.get()); - service_listener_->AddObserver(&observer_); - - mdns_service_->SetServiceConfig(kTestHostname, kTestServiceInstance, - kTestPort, {}, {{"model", "shifty"}}); - service_publisher_ = std::make_unique<ServicePublisherImpl>( - &publisher_observer_, mdns_service_.get()); - } - - std::unique_ptr<FakeClock> clock_; - std::unique_ptr<FakeTaskRunner> task_runner_; - MockServiceListenerObserver observer_; - FakeMdnsPlatformService* fake_platform_service_; - std::unique_ptr<FakeMdnsResponderAdapterFactory> mdns_responder_factory_; - std::unique_ptr<TestingMdnsResponderService> mdns_service_; - std::unique_ptr<ServiceListenerImpl> service_listener_; - MockServicePublisherObserver publisher_observer_; - std::unique_ptr<ServicePublisherImpl> service_publisher_; - const uint8_t default_mac_[6] = {0, 11, 22, 33, 44, 55}; - const uint8_t second_mac_[6] = {55, 33, 22, 33, 44, 77}; - const IPSubnet default_subnet_{IPAddress{192, 168, 3, 2}, 24}; - const IPSubnet second_subnet_{IPAddress{10, 0, 0, 3}, 24}; - std::vector<MdnsPlatformService::BoundInterface> bound_interfaces_{ - MdnsPlatformService::BoundInterface{ - InterfaceInfo{1, - default_mac_, - "eth0", - InterfaceInfo::Type::kEthernet, - {default_subnet_}}, - default_subnet_, kDefaultSocket}, - MdnsPlatformService::BoundInterface{ - InterfaceInfo{2, - second_mac_, - "eth1", - InterfaceInfo::Type::kEthernet, - {second_subnet_}}, - second_subnet_, kSecondSocket}, - }; -}; - -} // namespace - -TEST_F(MdnsResponderServiceTest, BasicServiceStates) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kDefaultSocket); - - std::string service_id; - EXPECT_CALL(observer_, OnReceiverAdded(_)) - .WillOnce(::testing::Invoke([&service_id](const ServiceInfo& info) { - service_id = info.service_id; - EXPECT_EQ(kTestServiceInstance, info.friendly_name); - EXPECT_EQ((IPEndpoint{{192, 168, 3, 7}, kTestPort}), info.v4_endpoint); - EXPECT_FALSE(info.v6_endpoint.address); - })); - mdns_service_->HandleNewEvents(); - - mdns_responder->AddAEvent(MakeAEvent( - "gigliorononomicon", IPAddress{192, 168, 3, 8}, kDefaultSocket)); - - EXPECT_CALL(observer_, OnReceiverChanged(_)) - .WillOnce(::testing::Invoke([&service_id](const ServiceInfo& info) { - EXPECT_EQ(service_id, info.service_id); - EXPECT_EQ(kTestServiceInstance, info.friendly_name); - EXPECT_EQ((IPEndpoint{{192, 168, 3, 8}, kTestPort}), info.v4_endpoint); - EXPECT_FALSE(info.v6_endpoint.address); - })); - mdns_service_->HandleNewEvents(); - - auto ptr_remove = MakePtrEvent(kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, kDefaultSocket); - ptr_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddPtrEvent(std::move(ptr_remove)); - - EXPECT_CALL(observer_, OnReceiverRemoved(_)) - .WillOnce(::testing::Invoke([&service_id](const ServiceInfo& info) { - EXPECT_EQ(service_id, info.service_id); - })); - mdns_service_->HandleNewEvents(); -} - -TEST_F(MdnsResponderServiceTest, NetworkNetworkInterfaceIndex) { - constexpr uint8_t mac[6] = {12, 34, 56, 78, 90}; - const IPSubnet subnet{IPAddress{10, 0, 0, 2}, 24}; - bound_interfaces_.emplace_back( - InterfaceInfo{2, mac, "wlan0", InterfaceInfo::Type::kWifi, {subnet}}, - subnet, kSecondSocket); - fake_platform_service_->set_interfaces(bound_interfaces_); - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kSecondSocket); - - EXPECT_CALL(observer_, OnReceiverAdded(_)) - .WillOnce(::testing::Invoke([](const ServiceInfo& info) { - EXPECT_EQ(2, info.network_interface_index); - })); - mdns_service_->HandleNewEvents(); -} - -TEST_F(MdnsResponderServiceTest, SimultaneousFieldChanges) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kDefaultSocket); - - EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents(); - - mdns_responder->AddSrvEvent( - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "gigliorononomicon", 54321, kDefaultSocket)); - auto a_remove = MakeAEvent("gigliorononomicon", IPAddress{192, 168, 3, 7}, - kDefaultSocket); - a_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddAEvent(std::move(a_remove)); - mdns_responder->AddAEvent(MakeAEvent( - "gigliorononomicon", IPAddress{192, 168, 3, 8}, kDefaultSocket)); - - EXPECT_CALL(observer_, OnReceiverChanged(_)) - .WillOnce(::testing::Invoke([](const ServiceInfo& info) { - EXPECT_EQ((IPAddress{192, 168, 3, 8}), info.v4_endpoint.address); - EXPECT_EQ(54321, info.v4_endpoint.port); - EXPECT_FALSE(info.v6_endpoint.address); - })); - mdns_service_->HandleNewEvents(); -} - -TEST_F(MdnsResponderServiceTest, SimultaneousHostAndAddressChange) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kDefaultSocket); - - EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents(); - - auto srv_remove = - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "gigliorononomicon", kTestPort, kDefaultSocket); - srv_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddSrvEvent(std::move(srv_remove)); - mdns_responder->AddSrvEvent( - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "alpha", kTestPort, kDefaultSocket)); - mdns_responder->AddAEvent(MakeAEvent( - "gigliorononomicon", IPAddress{192, 168, 3, 8}, kDefaultSocket)); - mdns_responder->AddAEvent( - MakeAEvent("alpha", IPAddress{192, 168, 3, 10}, kDefaultSocket)); - - EXPECT_CALL(observer_, OnReceiverChanged(_)) - .WillOnce(::testing::Invoke([](const ServiceInfo& info) { - EXPECT_EQ((IPAddress{192, 168, 3, 10}), info.v4_endpoint.address); - EXPECT_FALSE(info.v6_endpoint.address); - })); - mdns_service_->HandleNewEvents(); -} - -TEST_F(MdnsResponderServiceTest, ListenerStateTransitions) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - - EXPECT_CALL(observer_, OnSuspended()); - service_listener_->Suspend(); - ASSERT_EQ(mdns_responder, mdns_responder_factory_->last_mdns_responder()); - EXPECT_FALSE(mdns_responder->running()); - - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Resume(); - ASSERT_EQ(mdns_responder, mdns_responder_factory_->last_mdns_responder()); - EXPECT_TRUE(mdns_responder->running()); - - EXPECT_CALL(observer_, OnStopped()); - service_listener_->Stop(); - ASSERT_FALSE(mdns_responder_factory_->last_mdns_responder()); - - EXPECT_CALL(observer_, OnSuspended()); - auto instances = mdns_responder_factory_->instances(); - service_listener_->StartAndSuspend(); - EXPECT_EQ(instances + 1, mdns_responder_factory_->instances()); - mdns_responder = mdns_responder_factory_->last_mdns_responder(); - EXPECT_FALSE(mdns_responder->running()); - - EXPECT_CALL(observer_, OnStopped()); - service_listener_->Stop(); - ASSERT_FALSE(mdns_responder_factory_->last_mdns_responder()); -} - -TEST_F(MdnsResponderServiceTest, BasicServicePublish) { - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - - const auto& services = mdns_responder->registered_services(); - ASSERT_EQ(1u, services.size()); - EXPECT_EQ(kTestServiceInstance, services[0].service_instance); - EXPECT_EQ(kTestServiceName, services[0].service_name); - EXPECT_EQ(kTestServiceProtocol, services[0].service_protocol); - auto host_labels = services[0].target_host.GetLabels(); - ASSERT_EQ(2u, host_labels.size()); - EXPECT_EQ(kTestHostname, host_labels[0]); - EXPECT_EQ("local", host_labels[1]); - EXPECT_EQ(kTestPort, services[0].target_port); - - EXPECT_CALL(publisher_observer_, OnStopped()); - service_publisher_->Stop(); - - EXPECT_FALSE(mdns_responder_factory_->last_mdns_responder()); - EXPECT_EQ(0u, mdns_responder_factory_->last_registered_services_size()); -} - -TEST_F(MdnsResponderServiceTest, PublisherStateTransitions) { - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - EXPECT_EQ(1u, mdns_responder->registered_services().size()); - - EXPECT_CALL(publisher_observer_, OnSuspended()); - service_publisher_->Suspend(); - EXPECT_EQ(0u, mdns_responder->registered_services().size()); - - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Resume(); - EXPECT_EQ(1u, mdns_responder->registered_services().size()); - - EXPECT_CALL(publisher_observer_, OnStopped()); - service_publisher_->Stop(); - EXPECT_EQ(0u, mdns_responder_factory_->last_registered_services_size()); - - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Start(); - mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - EXPECT_EQ(1u, mdns_responder->registered_services().size()); - EXPECT_CALL(publisher_observer_, OnSuspended()); - service_publisher_->Suspend(); - EXPECT_EQ(0u, mdns_responder->registered_services().size()); - EXPECT_CALL(publisher_observer_, OnStopped()); - service_publisher_->Stop(); - EXPECT_FALSE(mdns_responder_factory_->last_mdns_responder()); - EXPECT_EQ(0u, mdns_responder_factory_->last_registered_services_size()); -} - -TEST_F(MdnsResponderServiceTest, PublisherObeysInterfaceAllowlist) { - { - mdns_service_->SetServiceConfig(kTestHostname, kTestServiceInstance, - kTestPort, {}, {{"model", "shifty"}}); - - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - auto interfaces = mdns_responder->registered_interfaces(); - ASSERT_EQ(2u, interfaces.size()); - EXPECT_EQ(kDefaultSocket, interfaces[0].socket); - EXPECT_EQ(kSecondSocket, interfaces[1].socket); - - EXPECT_CALL(publisher_observer_, OnStopped()); - service_publisher_->Stop(); - } - { - mdns_service_->SetServiceConfig(kTestHostname, kTestServiceInstance, - kTestPort, {1, 2}, {{"model", "shifty"}}); - - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - auto interfaces = mdns_responder->registered_interfaces(); - ASSERT_EQ(2u, interfaces.size()); - EXPECT_EQ(kDefaultSocket, interfaces[0].socket); - EXPECT_EQ(kSecondSocket, interfaces[1].socket); - - EXPECT_CALL(publisher_observer_, OnStopped()); - service_publisher_->Stop(); - } - { - mdns_service_->SetServiceConfig(kTestHostname, kTestServiceInstance, - kTestPort, {2}, {{"model", "shifty"}}); - - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - auto interfaces = mdns_responder->registered_interfaces(); - ASSERT_EQ(1u, interfaces.size()); - EXPECT_EQ(kSecondSocket, interfaces[0].socket); - - EXPECT_CALL(publisher_observer_, OnStopped()); - service_publisher_->Stop(); - } -} - -TEST_F(MdnsResponderServiceTest, ListenAndPublish) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - - { - auto interfaces = mdns_responder->registered_interfaces(); - ASSERT_EQ(2u, interfaces.size()); - EXPECT_EQ(kDefaultSocket, interfaces[0].socket); - EXPECT_EQ(kSecondSocket, interfaces[1].socket); - } - - mdns_service_->SetServiceConfig(kTestHostname, kTestServiceInstance, - kTestPort, {2}, {{"model", "shifty"}}); - - auto instances = mdns_responder_factory_->instances(); - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Start(); - - EXPECT_EQ(instances, mdns_responder_factory_->instances()); - ASSERT_TRUE(mdns_responder->running()); - { - auto interfaces = mdns_responder->registered_interfaces(); - ASSERT_EQ(1u, interfaces.size()); - EXPECT_EQ(kSecondSocket, interfaces[0].socket); - } - - EXPECT_CALL(observer_, OnStopped()); - service_listener_->Stop(); - ASSERT_TRUE(mdns_responder->running()); - EXPECT_EQ(1u, mdns_responder->registered_interfaces().size()); - - EXPECT_CALL(publisher_observer_, OnStopped()); - service_publisher_->Stop(); - EXPECT_FALSE(mdns_responder_factory_->last_mdns_responder()); - EXPECT_EQ(0u, mdns_responder_factory_->last_registered_services_size()); -} - -TEST_F(MdnsResponderServiceTest, PublishAndListen) { - mdns_service_->SetServiceConfig(kTestHostname, kTestServiceInstance, - kTestPort, {2}, {{"model", "shifty"}}); - - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - { - auto interfaces = mdns_responder->registered_interfaces(); - ASSERT_EQ(1u, interfaces.size()); - EXPECT_EQ(kSecondSocket, interfaces[0].socket); - } - - auto instances = mdns_responder_factory_->instances(); - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - EXPECT_EQ(instances, mdns_responder_factory_->instances()); - ASSERT_TRUE(mdns_responder->running()); - { - auto interfaces = mdns_responder->registered_interfaces(); - ASSERT_EQ(1u, interfaces.size()); - EXPECT_EQ(kSecondSocket, interfaces[0].socket); - } - - EXPECT_CALL(publisher_observer_, OnStopped()); - service_publisher_->Stop(); - ASSERT_TRUE(mdns_responder->running()); - EXPECT_EQ(1u, mdns_responder->registered_interfaces().size()); - - EXPECT_CALL(observer_, OnStopped()); - service_listener_->Stop(); - EXPECT_FALSE(mdns_responder_factory_->last_mdns_responder()); - EXPECT_EQ(0u, mdns_responder_factory_->last_registered_services_size()); -} - -TEST_F(MdnsResponderServiceTest, AddressQueryStopped) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kDefaultSocket); - - EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents(); - - auto srv_remove = - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "gigliorononomicon", kTestPort, kDefaultSocket); - srv_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddSrvEvent(std::move(srv_remove)); - - EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents(); - - EXPECT_FALSE(mdns_responder->ptr_queries_empty()); - EXPECT_FALSE(mdns_responder->srv_queries_empty()); - EXPECT_FALSE(mdns_responder->txt_queries_empty()); - EXPECT_TRUE(mdns_responder->a_queries_empty()); - EXPECT_TRUE(mdns_responder->aaaa_queries_empty()); -} - -TEST_F(MdnsResponderServiceTest, AddressQueryRefCount) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kDefaultSocket); - AddEventsForNewService(mdns_responder, "instance-2", kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", 4321, - {"model=shwofty", "id=asdf"}, - IPAddress{192, 168, 3, 7}, kDefaultSocket); - - EXPECT_CALL(observer_, OnReceiverAdded(_)).Times(2); - mdns_service_->HandleNewEvents(); - - auto srv_remove = - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "gigliorononomicon", kTestPort, kDefaultSocket); - srv_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddSrvEvent(std::move(srv_remove)); - - EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents(); - - EXPECT_FALSE(mdns_responder->ptr_queries_empty()); - EXPECT_FALSE(mdns_responder->srv_queries_empty()); - EXPECT_FALSE(mdns_responder->txt_queries_empty()); - EXPECT_FALSE(mdns_responder->a_queries_empty()); - EXPECT_FALSE(mdns_responder->aaaa_queries_empty()); - - srv_remove = - MakeSrvEvent("instance-2", kTestServiceName, kTestServiceProtocol, - "gigliorononomicon", 4321, kDefaultSocket); - srv_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddSrvEvent(std::move(srv_remove)); - - EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents(); - - EXPECT_FALSE(mdns_responder->ptr_queries_empty()); - EXPECT_FALSE(mdns_responder->srv_queries_empty()); - EXPECT_FALSE(mdns_responder->txt_queries_empty()); - EXPECT_TRUE(mdns_responder->a_queries_empty()); - EXPECT_TRUE(mdns_responder->aaaa_queries_empty()); -} - -TEST_F(MdnsResponderServiceTest, ServiceQueriesStoppedSrvFirst) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kDefaultSocket); - - EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents(); - - auto srv_remove = - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "gigliorononomicon", kTestPort, kDefaultSocket); - srv_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddSrvEvent(std::move(srv_remove)); - - EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents(); - - EXPECT_FALSE(mdns_responder->ptr_queries_empty()); - EXPECT_FALSE(mdns_responder->srv_queries_empty()); - EXPECT_FALSE(mdns_responder->txt_queries_empty()); - EXPECT_TRUE(mdns_responder->a_queries_empty()); - EXPECT_TRUE(mdns_responder->aaaa_queries_empty()); - - auto ptr_remove = MakePtrEvent(kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, kDefaultSocket); - ptr_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddPtrEvent(std::move(ptr_remove)); - mdns_service_->HandleNewEvents(); - - EXPECT_FALSE(mdns_responder->ptr_queries_empty()); - EXPECT_TRUE(mdns_responder->srv_queries_empty()); - EXPECT_TRUE(mdns_responder->txt_queries_empty()); - EXPECT_TRUE(mdns_responder->a_queries_empty()); - EXPECT_TRUE(mdns_responder->aaaa_queries_empty()); -} - -TEST_F(MdnsResponderServiceTest, ServiceQueriesStoppedPtrFirst) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kDefaultSocket); - - EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents(); - - auto ptr_remove = MakePtrEvent(kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, kDefaultSocket); - ptr_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddPtrEvent(std::move(ptr_remove)); - - EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents(); - - EXPECT_FALSE(mdns_responder->ptr_queries_empty()); - EXPECT_FALSE(mdns_responder->srv_queries_empty()); - EXPECT_FALSE(mdns_responder->txt_queries_empty()); - EXPECT_FALSE(mdns_responder->a_queries_empty()); - EXPECT_FALSE(mdns_responder->aaaa_queries_empty()); - - auto srv_remove = - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "gigliorononomicon", kTestPort, kDefaultSocket); - srv_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddSrvEvent(std::move(srv_remove)); - mdns_service_->HandleNewEvents(); - - EXPECT_FALSE(mdns_responder->ptr_queries_empty()); - EXPECT_TRUE(mdns_responder->srv_queries_empty()); - EXPECT_TRUE(mdns_responder->txt_queries_empty()); - EXPECT_TRUE(mdns_responder->a_queries_empty()); - EXPECT_TRUE(mdns_responder->aaaa_queries_empty()); -} - -TEST_F(MdnsResponderServiceTest, MultipleInterfaceRemove) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kDefaultSocket); - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kSecondSocket); - - EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents(); - - auto srv_remove1 = - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "gigliorononomicon", kTestPort, kSecondSocket); - srv_remove1.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddSrvEvent(std::move(srv_remove1)); - EXPECT_CALL(observer_, OnReceiverChanged(_)).Times(0); - EXPECT_CALL(observer_, OnReceiverRemoved(_)).Times(0); - mdns_service_->HandleNewEvents(); - - auto srv_remove2 = - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "gigliorononomicon", kTestPort, kDefaultSocket); - srv_remove2.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddSrvEvent(std::move(srv_remove2)); - EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents(); - EXPECT_TRUE(mdns_responder->a_queries_empty()); - - auto ptr_remove = MakePtrEvent(kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, kDefaultSocket); - ptr_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddPtrEvent(std::move(ptr_remove)); - mdns_service_->HandleNewEvents(); - - EXPECT_FALSE(mdns_responder->ptr_queries_empty()); - EXPECT_TRUE(mdns_responder->srv_queries_empty()); - EXPECT_TRUE(mdns_responder->txt_queries_empty()); - EXPECT_TRUE(mdns_responder->a_queries_empty()); - EXPECT_TRUE(mdns_responder->aaaa_queries_empty()); -} - -TEST_F(MdnsResponderServiceTest, ResumeService) { - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - ASSERT_TRUE(mdns_responder); - ASSERT_TRUE(mdns_responder->running()); - - EXPECT_EQ(2u, mdns_responder->registered_interfaces().size()); - ASSERT_EQ(1u, mdns_responder->registered_services().size()); - - EXPECT_CALL(publisher_observer_, OnSuspended()); - service_publisher_->Suspend(); - - EXPECT_TRUE(mdns_responder_factory_->last_mdns_responder()); - EXPECT_EQ(0u, mdns_responder->registered_services().size()); - - EXPECT_CALL(publisher_observer_, OnStarted()); - service_publisher_->Resume(); - - EXPECT_EQ(2u, mdns_responder->registered_interfaces().size()); - ASSERT_EQ(1u, mdns_responder->registered_services().size()); -} - -TEST_F(MdnsResponderServiceTest, RestorePtrNotifiesObserver) { - EXPECT_CALL(observer_, OnStarted()); - service_listener_->Start(); - - auto* mdns_responder = mdns_responder_factory_->last_mdns_responder(); - - AddEventsForNewService(mdns_responder, kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, "gigliorononomicon", kTestPort, - {"model=shifty", "id=asdf"}, IPAddress{192, 168, 3, 7}, - kDefaultSocket); - - EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents(); - - auto ptr_remove = MakePtrEvent(kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, kDefaultSocket); - ptr_remove.header.response_type = QueryEventHeader::Type::kRemoved; - mdns_responder->AddPtrEvent(std::move(ptr_remove)); - - EXPECT_CALL(observer_, OnReceiverRemoved(_)); - mdns_service_->HandleNewEvents(); - - auto ptr_add = MakePtrEvent(kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, kDefaultSocket); - mdns_responder->AddPtrEvent(std::move(ptr_add)); - - EXPECT_CALL(observer_, OnReceiverAdded(_)); - mdns_service_->HandleNewEvents(); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/mdns_service_listener_factory.cc b/osp/impl/mdns_service_listener_factory.cc deleted file mode 100644 index cae4a341..00000000 --- a/osp/impl/mdns_service_listener_factory.cc +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/public/mdns_service_listener_factory.h" - -#include "osp/impl/internal_services.h" - -namespace openscreen { - -class TaskRunner; - -namespace osp { - -// static -std::unique_ptr<ServiceListener> MdnsServiceListenerFactory::Create( - const MdnsServiceListenerConfig& config, - ServiceListener::Observer* observer, - TaskRunner* task_runner) { - return InternalServices::CreateListener(config, observer, task_runner); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/mdns_service_publisher_factory.cc b/osp/impl/mdns_service_publisher_factory.cc deleted file mode 100644 index f055e772..00000000 --- a/osp/impl/mdns_service_publisher_factory.cc +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/public/mdns_service_publisher_factory.h" - -#include "osp/impl/internal_services.h" - -namespace openscreen { - -class TaskRunner; - -namespace osp { - -// static -std::unique_ptr<ServicePublisher> MdnsServicePublisherFactory::Create( - const ServicePublisher::Config& config, - ServicePublisher::Observer* observer, - TaskRunner* task_runner) { - return InternalServices::CreatePublisher(config, observer, task_runner); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/network_service_manager.cc b/osp/impl/network_service_manager.cc index d05192c0..10890ca8 100644 --- a/osp/impl/network_service_manager.cc +++ b/osp/impl/network_service_manager.cc @@ -17,14 +17,14 @@ namespace osp { // static NetworkServiceManager* NetworkServiceManager::Create( std::unique_ptr<ServiceListener> mdns_listener, - std::unique_ptr<ServicePublisher> mdns_publisher, + std::unique_ptr<ServicePublisher> service_publisher, std::unique_ptr<ProtocolConnectionClient> connection_client, std::unique_ptr<ProtocolConnectionServer> connection_server) { // TODO(mfoltz): Convert to assertion failure if (g_network_service_manager_instance) return nullptr; g_network_service_manager_instance = new NetworkServiceManager( - std::move(mdns_listener), std::move(mdns_publisher), + std::move(mdns_listener), std::move(service_publisher), std::move(connection_client), std::move(connection_server)); return g_network_service_manager_instance; } @@ -50,8 +50,8 @@ ServiceListener* NetworkServiceManager::GetMdnsServiceListener() { return mdns_listener_.get(); } -ServicePublisher* NetworkServiceManager::GetMdnsServicePublisher() { - return mdns_publisher_.get(); +ServicePublisher* NetworkServiceManager::GetServicePublisher() { + return service_publisher_.get(); } ProtocolConnectionClient* NetworkServiceManager::GetProtocolConnectionClient() { @@ -64,11 +64,11 @@ ProtocolConnectionServer* NetworkServiceManager::GetProtocolConnectionServer() { NetworkServiceManager::NetworkServiceManager( std::unique_ptr<ServiceListener> mdns_listener, - std::unique_ptr<ServicePublisher> mdns_publisher, + std::unique_ptr<ServicePublisher> service_publisher, std::unique_ptr<ProtocolConnectionClient> connection_client, std::unique_ptr<ProtocolConnectionServer> connection_server) : mdns_listener_(std::move(mdns_listener)), - mdns_publisher_(std::move(mdns_publisher)), + service_publisher_(std::move(service_publisher)), connection_client_(std::move(connection_client)), connection_server_(std::move(connection_server)) {} diff --git a/osp/impl/presentation/presentation_connection.cc b/osp/impl/presentation/presentation_connection.cc index 9924ed6e..4b616017 100644 --- a/osp/impl/presentation/presentation_connection.cc +++ b/osp/impl/presentation/presentation_connection.cc @@ -131,7 +131,6 @@ Error Connection::SendBinary(std::vector<uint8_t>&& data) { new (&cbor_message.message.bytes) std::vector<uint8_t>(std::move(data)); return WriteConnectionMessage(cbor_message, protocol_connection_.get()); - return Error::None(); } Error Connection::Close(CloseReason reason) { diff --git a/osp/impl/presentation/url_availability_requester.cc b/osp/impl/presentation/url_availability_requester.cc index a7b27972..5e4f4ade 100644 --- a/osp/impl/presentation/url_availability_requester.cc +++ b/osp/impl/presentation/url_availability_requester.cc @@ -461,7 +461,7 @@ ErrorOr<size_t> UrlAvailabilityRequester::ReceiverRequester::OnStreamMessage( StopWatching(&response_watch); return result; } - } break; + } case msgs::Type::kPresentationUrlAvailabilityEvent: { msgs::PresentationUrlAvailabilityEvent event; ssize_t result = msgs::DecodePresentationUrlAvailabilityEvent( @@ -483,7 +483,7 @@ ErrorOr<size_t> UrlAvailabilityRequester::ReceiverRequester::OnStreamMessage( } return result; } - } break; + } default: break; } diff --git a/osp/impl/quic/quic_connection.h b/osp/impl/quic/quic_connection.h index e00e25a0..6fbc81af 100644 --- a/osp/impl/quic/quic_connection.h +++ b/osp/impl/quic/quic_connection.h @@ -17,12 +17,14 @@ class QuicStream { public: class Delegate { public: - virtual ~Delegate() = default; virtual void OnReceived(QuicStream* stream, const char* data, size_t data_size) = 0; virtual void OnClose(uint64_t stream_id) = 0; + + protected: + virtual ~Delegate() = default; }; QuicStream(Delegate* delegate, uint64_t id) : delegate_(delegate), id_(id) {} @@ -41,7 +43,6 @@ class QuicConnection : public UdpSocket::Client { public: class Delegate { public: - virtual ~Delegate() = default; // Called when the QUIC handshake has successfully completed. virtual void OnCryptoHandshakeComplete(uint64_t connection_id) = 0; @@ -63,6 +64,9 @@ class QuicConnection : public UdpSocket::Client { // will be returned via OnIncomingStream immediately after this call. virtual QuicStream::Delegate* NextStreamDelegate(uint64_t connection_id, uint64_t stream_id) = 0; + + protected: + virtual ~Delegate() = default; }; explicit QuicConnection(Delegate* delegate) : delegate_(delegate) {} diff --git a/osp/impl/service_listener_impl.h b/osp/impl/service_listener_impl.h index b94dcdb1..9516ff18 100644 --- a/osp/impl/service_listener_impl.h +++ b/osp/impl/service_listener_impl.h @@ -22,7 +22,6 @@ class ServiceListenerImpl final : public ServiceListener, class Delegate { public: Delegate(); - virtual ~Delegate(); void SetListenerImpl(ServiceListenerImpl* listener); @@ -34,6 +33,7 @@ class ServiceListenerImpl final : public ServiceListener, virtual void SearchNow(State from) = 0; protected: + virtual ~Delegate(); void SetState(State state) { listener_->SetState(state); } ServiceListenerImpl* listener_ = nullptr; diff --git a/osp/impl/service_publisher_impl.cc b/osp/impl/service_publisher_impl.cc index bde3e528..8a984969 100644 --- a/osp/impl/service_publisher_impl.cc +++ b/osp/impl/service_publisher_impl.cc @@ -4,6 +4,8 @@ #include "osp/impl/service_publisher_impl.h" +#include <utility> + #include "util/osp_logging.h" namespace openscreen { @@ -44,8 +46,8 @@ void ServicePublisherImpl::Delegate::SetPublisherImpl( } ServicePublisherImpl::ServicePublisherImpl(Observer* observer, - Delegate* delegate) - : ServicePublisher(observer), delegate_(delegate) { + std::unique_ptr<Delegate> delegate) + : ServicePublisher(observer), delegate_(std::move(delegate)) { delegate_->SetPublisherImpl(this); } @@ -55,14 +57,14 @@ bool ServicePublisherImpl::Start() { if (state_ != State::kStopped) return false; state_ = State::kStarting; - delegate_->StartPublisher(); + delegate_->StartPublisher(config_); return true; } bool ServicePublisherImpl::StartAndSuspend() { if (state_ != State::kStopped) return false; state_ = State::kStarting; - delegate_->StartAndSuspendPublisher(); + delegate_->StartAndSuspendPublisher(config_); return true; } bool ServicePublisherImpl::Stop() { @@ -84,7 +86,7 @@ bool ServicePublisherImpl::Resume() { if (state_ != State::kSuspended) return false; - delegate_->ResumePublisher(); + delegate_->ResumePublisher(config_); return true; } diff --git a/osp/impl/service_publisher_impl.h b/osp/impl/service_publisher_impl.h index 1817ab1f..fa2c3890 100644 --- a/osp/impl/service_publisher_impl.h +++ b/osp/impl/service_publisher_impl.h @@ -5,6 +5,8 @@ #ifndef OSP_IMPL_SERVICE_PUBLISHER_IMPL_H_ #define OSP_IMPL_SERVICE_PUBLISHER_IMPL_H_ +#include <memory> + #include "osp/impl/with_destruction_callback.h" #include "osp/public/service_publisher.h" #include "platform/base/macros.h" @@ -22,11 +24,12 @@ class ServicePublisherImpl final : public ServicePublisher, void SetPublisherImpl(ServicePublisherImpl* publisher); - virtual void StartPublisher() = 0; - virtual void StartAndSuspendPublisher() = 0; + virtual void StartPublisher(const ServicePublisher::Config& config) = 0; + virtual void StartAndSuspendPublisher( + const ServicePublisher::Config& config) = 0; virtual void StopPublisher() = 0; virtual void SuspendPublisher() = 0; - virtual void ResumePublisher() = 0; + virtual void ResumePublisher(const ServicePublisher::Config& config) = 0; protected: void SetState(State state) { publisher_->SetState(state); } @@ -37,7 +40,7 @@ class ServicePublisherImpl final : public ServicePublisher, // |observer| is optional. If it is provided, it will receive appropriate // notifications about this ServicePublisher. |delegate| is required and // is used to implement state transitions. - ServicePublisherImpl(Observer* observer, Delegate* delegate); + ServicePublisherImpl(Observer* observer, std::unique_ptr<Delegate> delegate); ~ServicePublisherImpl() override; // ServicePublisher overrides. @@ -56,7 +59,7 @@ class ServicePublisherImpl final : public ServicePublisher, // by the observer interface. void MaybeNotifyObserver(); - Delegate* const delegate_; + std::unique_ptr<Delegate> delegate_; OSP_DISALLOW_COPY_AND_ASSIGN(ServicePublisherImpl); }; diff --git a/osp/impl/service_publisher_impl_unittest.cc b/osp/impl/service_publisher_impl_unittest.cc index 8c8bc9ec..b77a8aa5 100644 --- a/osp/impl/service_publisher_impl_unittest.cc +++ b/osp/impl/service_publisher_impl_unittest.cc @@ -5,6 +5,7 @@ #include "osp/impl/service_publisher_impl.h" #include <memory> +#include <utility> #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -13,6 +14,7 @@ namespace openscreen { namespace osp { namespace { +using ::testing::_; using ::testing::Expectation; using ::testing::NiceMock; @@ -26,7 +28,7 @@ class MockObserver final : public ServicePublisher::Observer { MOCK_METHOD0(OnStopped, void()); MOCK_METHOD0(OnSuspended, void()); - MOCK_METHOD1(OnError, void(ServicePublisherError)); + MOCK_METHOD1(OnError, void(Error)); MOCK_METHOD1(OnMetrics, void(ServicePublisher::Metrics)); }; @@ -38,23 +40,27 @@ class MockMdnsDelegate : public ServicePublisherImpl::Delegate { using ServicePublisherImpl::Delegate::SetState; - MOCK_METHOD0(StartPublisher, void()); - MOCK_METHOD0(StartAndSuspendPublisher, void()); + MOCK_METHOD1(StartPublisher, void(const ServicePublisher::Config&)); + MOCK_METHOD1(StartAndSuspendPublisher, void(const ServicePublisher::Config&)); MOCK_METHOD0(StopPublisher, void()); MOCK_METHOD0(SuspendPublisher, void()); - MOCK_METHOD0(ResumePublisher, void()); + MOCK_METHOD1(ResumePublisher, void(const ServicePublisher::Config&)); MOCK_METHOD0(RunTasksPublisher, void()); }; class ServicePublisherImplTest : public ::testing::Test { protected: void SetUp() override { - service_publisher_ = - std::make_unique<ServicePublisherImpl>(nullptr, &mock_delegate_); + auto mock_delegate = std::make_unique<NiceMock<MockMdnsDelegate>>(); + mock_delegate_ = mock_delegate.get(); + service_publisher_ = std::make_unique<ServicePublisherImpl>( + nullptr, std::move(mock_delegate)); + service_publisher_->SetConfig(config); } - NiceMock<MockMdnsDelegate> mock_delegate_; + NiceMock<MockMdnsDelegate>* mock_delegate_ = nullptr; std::unique_ptr<ServicePublisherImpl> service_publisher_; + ServicePublisher::Config config; }; } // namespace @@ -62,99 +68,100 @@ class ServicePublisherImplTest : public ::testing::Test { TEST_F(ServicePublisherImplTest, NormalStartStop) { ASSERT_EQ(State::kStopped, service_publisher_->state()); - EXPECT_CALL(mock_delegate_, StartPublisher()); + EXPECT_CALL(*mock_delegate_, StartPublisher(_)); EXPECT_TRUE(service_publisher_->Start()); EXPECT_FALSE(service_publisher_->Start()); EXPECT_EQ(State::kStarting, service_publisher_->state()); - mock_delegate_.SetState(State::kRunning); + mock_delegate_->SetState(State::kRunning); EXPECT_EQ(State::kRunning, service_publisher_->state()); - EXPECT_CALL(mock_delegate_, StopPublisher()); + EXPECT_CALL(*mock_delegate_, StopPublisher()); EXPECT_TRUE(service_publisher_->Stop()); EXPECT_FALSE(service_publisher_->Stop()); EXPECT_EQ(State::kStopping, service_publisher_->state()); - mock_delegate_.SetState(State::kStopped); + mock_delegate_->SetState(State::kStopped); EXPECT_EQ(State::kStopped, service_publisher_->state()); } TEST_F(ServicePublisherImplTest, StopBeforeRunning) { - EXPECT_CALL(mock_delegate_, StartPublisher()); + EXPECT_CALL(*mock_delegate_, StartPublisher(_)); EXPECT_TRUE(service_publisher_->Start()); EXPECT_EQ(State::kStarting, service_publisher_->state()); - EXPECT_CALL(mock_delegate_, StopPublisher()); + EXPECT_CALL(*mock_delegate_, StopPublisher()); EXPECT_TRUE(service_publisher_->Stop()); EXPECT_FALSE(service_publisher_->Stop()); EXPECT_EQ(State::kStopping, service_publisher_->state()); - mock_delegate_.SetState(State::kStopped); + mock_delegate_->SetState(State::kStopped); EXPECT_EQ(State::kStopped, service_publisher_->state()); } TEST_F(ServicePublisherImplTest, StartSuspended) { - EXPECT_CALL(mock_delegate_, StartAndSuspendPublisher()); - EXPECT_CALL(mock_delegate_, StartPublisher()).Times(0); + EXPECT_CALL(*mock_delegate_, StartAndSuspendPublisher(_)); + EXPECT_CALL(*mock_delegate_, StartPublisher(_)).Times(0); EXPECT_TRUE(service_publisher_->StartAndSuspend()); EXPECT_FALSE(service_publisher_->Start()); EXPECT_EQ(State::kStarting, service_publisher_->state()); - mock_delegate_.SetState(State::kSuspended); + mock_delegate_->SetState(State::kSuspended); EXPECT_EQ(State::kSuspended, service_publisher_->state()); } TEST_F(ServicePublisherImplTest, SuspendAndResume) { EXPECT_TRUE(service_publisher_->Start()); - mock_delegate_.SetState(State::kRunning); + mock_delegate_->SetState(State::kRunning); - EXPECT_CALL(mock_delegate_, ResumePublisher()).Times(0); - EXPECT_CALL(mock_delegate_, SuspendPublisher()).Times(2); + EXPECT_CALL(*mock_delegate_, ResumePublisher(_)).Times(0); + EXPECT_CALL(*mock_delegate_, SuspendPublisher()).Times(2); EXPECT_FALSE(service_publisher_->Resume()); EXPECT_TRUE(service_publisher_->Suspend()); EXPECT_TRUE(service_publisher_->Suspend()); - mock_delegate_.SetState(State::kSuspended); + mock_delegate_->SetState(State::kSuspended); EXPECT_EQ(State::kSuspended, service_publisher_->state()); - EXPECT_CALL(mock_delegate_, StartPublisher()).Times(0); - EXPECT_CALL(mock_delegate_, SuspendPublisher()).Times(0); - EXPECT_CALL(mock_delegate_, ResumePublisher()).Times(2); + EXPECT_CALL(*mock_delegate_, StartPublisher(_)).Times(0); + EXPECT_CALL(*mock_delegate_, SuspendPublisher()).Times(0); + EXPECT_CALL(*mock_delegate_, ResumePublisher(_)).Times(2); EXPECT_FALSE(service_publisher_->Start()); EXPECT_FALSE(service_publisher_->Suspend()); EXPECT_TRUE(service_publisher_->Resume()); EXPECT_TRUE(service_publisher_->Resume()); - mock_delegate_.SetState(State::kRunning); + mock_delegate_->SetState(State::kRunning); EXPECT_EQ(State::kRunning, service_publisher_->state()); - EXPECT_CALL(mock_delegate_, ResumePublisher()).Times(0); + EXPECT_CALL(*mock_delegate_, ResumePublisher(_)).Times(0); EXPECT_FALSE(service_publisher_->Resume()); } TEST_F(ServicePublisherImplTest, ObserverTransitions) { MockObserver observer; - NiceMock<MockMdnsDelegate> mock_delegate; - service_publisher_ = - std::make_unique<ServicePublisherImpl>(&observer, &mock_delegate); + auto mock_delegate = std::make_unique<NiceMock<MockMdnsDelegate>>(); + NiceMock<MockMdnsDelegate>* const mock_delegate_ptr = mock_delegate.get(); + auto service_publisher = std::make_unique<ServicePublisherImpl>( + &observer, std::move(mock_delegate)); - service_publisher_->Start(); + service_publisher->Start(); Expectation start_from_stopped = EXPECT_CALL(observer, OnStarted()); - mock_delegate.SetState(State::kRunning); + mock_delegate_ptr->SetState(State::kRunning); - service_publisher_->Suspend(); + service_publisher->Suspend(); Expectation suspend_from_running = EXPECT_CALL(observer, OnSuspended()).After(start_from_stopped); - mock_delegate.SetState(State::kSuspended); + mock_delegate_ptr->SetState(State::kSuspended); - service_publisher_->Resume(); + service_publisher->Resume(); Expectation resume_from_suspended = EXPECT_CALL(observer, OnStarted()).After(suspend_from_running); - mock_delegate.SetState(State::kRunning); + mock_delegate_ptr->SetState(State::kRunning); - service_publisher_->Stop(); + service_publisher->Stop(); EXPECT_CALL(observer, OnStopped()).After(resume_from_suspended); - mock_delegate.SetState(State::kStopped); + mock_delegate_ptr->SetState(State::kStopped); } } // namespace osp diff --git a/osp/impl/testing/BUILD.gn b/osp/impl/testing/BUILD.gn deleted file mode 100644 index 94bcc504..00000000 --- a/osp/impl/testing/BUILD.gn +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2018 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -import("../../build/config/services.gni") -assert(use_mdns_responder) - -source_set("testing") { - testonly = true - sources = [ - "fake_mdns_platform_service.cc", - "fake_mdns_platform_service.h", - "fake_mdns_responder_adapter.cc", - "fake_mdns_responder_adapter.h", - ] - - deps = [ - "../discovery/mdns:mdns_interface", - ] - - public_deps = [ - "../../../platform", - ] -} - -source_set("unittests") { - testonly = true - sources = [ - "fake_mdns_platform_service_unittest.cc", - "fake_mdns_responder_adapter_unittest.cc", - ] - - deps = [ - ":testing", - "../../../third_party/abseil", - "../../../third_party/googletest:gtest", - ] -} diff --git a/osp/impl/testing/fake_mdns_platform_service.cc b/osp/impl/testing/fake_mdns_platform_service.cc deleted file mode 100644 index 8ce6faf7..00000000 --- a/osp/impl/testing/fake_mdns_platform_service.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/testing/fake_mdns_platform_service.h" - -#include <algorithm> - -#include "util/osp_logging.h" - -namespace openscreen { -namespace osp { - -FakeMdnsPlatformService::FakeMdnsPlatformService() = default; -FakeMdnsPlatformService::~FakeMdnsPlatformService() = default; - -std::vector<MdnsPlatformService::BoundInterface> -FakeMdnsPlatformService::RegisterInterfaces( - const std::vector<NetworkInterfaceIndex>& allowlist) { - OSP_CHECK(registered_interfaces_.empty()); - if (allowlist.empty()) { - registered_interfaces_ = interfaces_; - } else { - for (const auto& interface : interfaces_) { - if (std::find(allowlist.begin(), allowlist.end(), - interface.interface_info.index) != allowlist.end()) { - registered_interfaces_.push_back(interface); - } - } - } - return registered_interfaces_; -} - -void FakeMdnsPlatformService::DeregisterInterfaces( - const std::vector<BoundInterface>& interfaces) { - for (const auto& interface : interfaces) { - auto index = interface.interface_info.index; - auto it = std::find_if(registered_interfaces_.begin(), - registered_interfaces_.end(), - [index](const BoundInterface& interface) { - return interface.interface_info.index == index; - }); - OSP_CHECK(it != registered_interfaces_.end()) - << "Must deregister a previously returned interface: " << index; - registered_interfaces_.erase(it); - } -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/testing/fake_mdns_platform_service.h b/osp/impl/testing/fake_mdns_platform_service.h deleted file mode 100644 index a21c48cd..00000000 --- a/osp/impl/testing/fake_mdns_platform_service.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef OSP_IMPL_TESTING_FAKE_MDNS_PLATFORM_SERVICE_H_ -#define OSP_IMPL_TESTING_FAKE_MDNS_PLATFORM_SERVICE_H_ - -#include <vector> - -#include "osp/impl/mdns_platform_service.h" - -namespace openscreen { -namespace osp { - -class FakeMdnsPlatformService final : public MdnsPlatformService { - public: - FakeMdnsPlatformService(); - ~FakeMdnsPlatformService() override; - - void set_interfaces(const std::vector<BoundInterface>& interfaces) { - interfaces_ = interfaces; - } - - // PlatformService overrides. - std::vector<BoundInterface> RegisterInterfaces( - const std::vector<NetworkInterfaceIndex>& interface_index_allowlist) - override; - void DeregisterInterfaces( - const std::vector<BoundInterface>& registered_interfaces) override; - - private: - std::vector<BoundInterface> registered_interfaces_; - std::vector<BoundInterface> interfaces_; -}; - -} // namespace osp -} // namespace openscreen - -#endif // OSP_IMPL_TESTING_FAKE_MDNS_PLATFORM_SERVICE_H_ diff --git a/osp/impl/testing/fake_mdns_platform_service_unittest.cc b/osp/impl/testing/fake_mdns_platform_service_unittest.cc deleted file mode 100644 index 5a1d5f4c..00000000 --- a/osp/impl/testing/fake_mdns_platform_service_unittest.cc +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/testing/fake_mdns_platform_service.h" - -#include <cstdint> - -#include "gtest/gtest.h" - -namespace openscreen { -namespace osp { -namespace { - -UdpSocket* const kDefaultSocket = - reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(16)); -UdpSocket* const kSecondSocket = - reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(24)); - -class FakeMdnsPlatformServiceTest : public ::testing::Test { - protected: - const uint8_t mac1_[6] = {11, 22, 33, 44, 55, 66}; - const uint8_t mac2_[6] = {12, 23, 34, 45, 56, 67}; - const IPSubnet subnet1_{IPAddress{192, 168, 3, 2}, 24}; - const IPSubnet subnet2_{ - IPAddress{0x0102, 0x0304, 0x0504, 0x0302, 0x0102, 0x0304, 0x0506, 0x0708}, - 24}; - std::vector<MdnsPlatformService::BoundInterface> bound_interfaces_{ - MdnsPlatformService::BoundInterface{ - InterfaceInfo{1, - mac1_, - "eth0", - InterfaceInfo::Type::kEthernet, - {subnet1_}}, - subnet1_, kDefaultSocket}, - MdnsPlatformService::BoundInterface{ - InterfaceInfo{2, - mac2_, - "eth1", - InterfaceInfo::Type::kEthernet, - {subnet2_}}, - subnet2_, kSecondSocket}}; -}; - -} // namespace - -TEST_F(FakeMdnsPlatformServiceTest, SimpleRegistration) { - FakeMdnsPlatformService platform_service; - std::vector<MdnsPlatformService::BoundInterface> bound_interfaces{ - bound_interfaces_[0]}; - - platform_service.set_interfaces(bound_interfaces); - - auto registered_interfaces = platform_service.RegisterInterfaces({}); - EXPECT_EQ(bound_interfaces, registered_interfaces); - platform_service.DeregisterInterfaces(registered_interfaces); - - registered_interfaces = platform_service.RegisterInterfaces({}); - EXPECT_EQ(bound_interfaces, registered_interfaces); - platform_service.DeregisterInterfaces(registered_interfaces); - platform_service.set_interfaces({}); - - registered_interfaces = platform_service.RegisterInterfaces({}); - EXPECT_TRUE(registered_interfaces.empty()); - platform_service.DeregisterInterfaces(registered_interfaces); - - std::vector<MdnsPlatformService::BoundInterface> new_interfaces{ - bound_interfaces_[1]}; - - platform_service.set_interfaces(new_interfaces); - - registered_interfaces = platform_service.RegisterInterfaces({}); - EXPECT_EQ(new_interfaces, registered_interfaces); - platform_service.DeregisterInterfaces(registered_interfaces); -} - -TEST_F(FakeMdnsPlatformServiceTest, ObeyIndexAllowlist) { - FakeMdnsPlatformService platform_service; - platform_service.set_interfaces(bound_interfaces_); - - auto eth0_only = platform_service.RegisterInterfaces({1}); - EXPECT_EQ( - (std::vector<MdnsPlatformService::BoundInterface>{bound_interfaces_[0]}), - eth0_only); - platform_service.DeregisterInterfaces(eth0_only); - - auto eth1_only = platform_service.RegisterInterfaces({2}); - EXPECT_EQ( - (std::vector<MdnsPlatformService::BoundInterface>{bound_interfaces_[1]}), - eth1_only); - platform_service.DeregisterInterfaces(eth1_only); - - auto both = platform_service.RegisterInterfaces({1, 2}); - EXPECT_EQ(bound_interfaces_, both); - platform_service.DeregisterInterfaces(both); -} - -TEST_F(FakeMdnsPlatformServiceTest, PartialDeregister) { - FakeMdnsPlatformService platform_service; - platform_service.set_interfaces(bound_interfaces_); - - auto both = platform_service.RegisterInterfaces({}); - std::vector<MdnsPlatformService::BoundInterface> eth0_only{ - bound_interfaces_[0]}; - std::vector<MdnsPlatformService::BoundInterface> eth1_only{ - bound_interfaces_[1]}; - platform_service.DeregisterInterfaces(eth0_only); - platform_service.DeregisterInterfaces(eth1_only); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/testing/fake_mdns_responder_adapter.cc b/osp/impl/testing/fake_mdns_responder_adapter.cc deleted file mode 100644 index 7b5a3b5e..00000000 --- a/osp/impl/testing/fake_mdns_responder_adapter.cc +++ /dev/null @@ -1,600 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/testing/fake_mdns_responder_adapter.h" - -#include <algorithm> -#include <map> -#include <string> -#include <utility> - -#include "platform/base/error.h" -#include "util/osp_logging.h" - -namespace openscreen { -namespace osp { - -constexpr char kLocalDomain[] = "local"; - -PtrEvent MakePtrEvent(const std::string& service_instance, - const std::string& service_type, - const std::string& service_protocol, - UdpSocket* socket) { - const auto labels = std::vector<std::string>{service_instance, service_type, - service_protocol, kLocalDomain}; - ErrorOr<DomainName> full_instance_name = - DomainName::FromLabels(labels.begin(), labels.end()); - OSP_CHECK(full_instance_name); - PtrEvent result{QueryEventHeader{QueryEventHeader::Type::kAdded, socket}, - full_instance_name.value()}; - return result; -} - -SrvEvent MakeSrvEvent(const std::string& service_instance, - const std::string& service_type, - const std::string& service_protocol, - const std::string& hostname, - uint16_t port, - UdpSocket* socket) { - const auto instance_labels = std::vector<std::string>{ - service_instance, service_type, service_protocol, kLocalDomain}; - ErrorOr<DomainName> full_instance_name = - DomainName::FromLabels(instance_labels.begin(), instance_labels.end()); - OSP_CHECK(full_instance_name); - - const auto host_labels = std::vector<std::string>{hostname, kLocalDomain}; - ErrorOr<DomainName> domain_name = - DomainName::FromLabels(host_labels.begin(), host_labels.end()); - OSP_CHECK(domain_name); - - SrvEvent result{QueryEventHeader{QueryEventHeader::Type::kAdded, socket}, - full_instance_name.value(), domain_name.value(), port}; - return result; -} - -TxtEvent MakeTxtEvent(const std::string& service_instance, - const std::string& service_type, - const std::string& service_protocol, - const std::vector<std::string>& txt_lines, - UdpSocket* socket) { - const auto labels = std::vector<std::string>{service_instance, service_type, - service_protocol, kLocalDomain}; - ErrorOr<DomainName> domain_name = - DomainName::FromLabels(labels.begin(), labels.end()); - OSP_CHECK(domain_name); - TxtEvent result{QueryEventHeader{QueryEventHeader::Type::kAdded, socket}, - domain_name.value(), txt_lines}; - return result; -} - -AEvent MakeAEvent(const std::string& hostname, - IPAddress address, - UdpSocket* socket) { - const auto labels = std::vector<std::string>{hostname, kLocalDomain}; - ErrorOr<DomainName> domain_name = - DomainName::FromLabels(labels.begin(), labels.end()); - OSP_CHECK(domain_name); - AEvent result{QueryEventHeader{QueryEventHeader::Type::kAdded, socket}, - domain_name.value(), address}; - return result; -} - -AaaaEvent MakeAaaaEvent(const std::string& hostname, - IPAddress address, - UdpSocket* socket) { - const auto labels = std::vector<std::string>{hostname, kLocalDomain}; - ErrorOr<DomainName> domain_name = - DomainName::FromLabels(labels.begin(), labels.end()); - OSP_CHECK(domain_name); - AaaaEvent result{QueryEventHeader{QueryEventHeader::Type::kAdded, socket}, - domain_name.value(), address}; - return result; -} - -void AddEventsForNewService(FakeMdnsResponderAdapter* mdns_responder, - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const std::string& hostname, - uint16_t port, - const std::vector<std::string>& txt_lines, - const IPAddress& address, - UdpSocket* socket) { - mdns_responder->AddPtrEvent( - MakePtrEvent(service_instance, service_name, service_protocol, socket)); - mdns_responder->AddSrvEvent(MakeSrvEvent(service_instance, service_name, - service_protocol, hostname, port, - socket)); - mdns_responder->AddTxtEvent(MakeTxtEvent( - service_instance, service_name, service_protocol, txt_lines, socket)); - mdns_responder->AddAEvent(MakeAEvent(hostname, address, socket)); -} - -FakeMdnsResponderAdapter::~FakeMdnsResponderAdapter() { - if (observer_) { - observer_->OnDestroyed(); - } -} - -void FakeMdnsResponderAdapter::AddPtrEvent(PtrEvent&& ptr_event) { - if (running_) - ptr_events_.push_back(std::move(ptr_event)); -} - -void FakeMdnsResponderAdapter::AddSrvEvent(SrvEvent&& srv_event) { - if (running_) - srv_events_.push_back(std::move(srv_event)); -} - -void FakeMdnsResponderAdapter::AddTxtEvent(TxtEvent&& txt_event) { - if (running_) - txt_events_.push_back(std::move(txt_event)); -} - -void FakeMdnsResponderAdapter::AddAEvent(AEvent&& a_event) { - if (running_) - a_events_.push_back(std::move(a_event)); -} - -void FakeMdnsResponderAdapter::AddAaaaEvent(AaaaEvent&& aaaa_event) { - if (running_) - aaaa_events_.push_back(std::move(aaaa_event)); -} - -bool FakeMdnsResponderAdapter::ptr_queries_empty() const { - for (const auto& queries : queries_) { - if (!queries.second.ptr_queries.empty()) - return false; - } - return true; -} - -bool FakeMdnsResponderAdapter::srv_queries_empty() const { - for (const auto& queries : queries_) { - if (!queries.second.srv_queries.empty()) - return false; - } - return true; -} - -bool FakeMdnsResponderAdapter::txt_queries_empty() const { - for (const auto& queries : queries_) { - if (!queries.second.txt_queries.empty()) - return false; - } - return true; -} - -bool FakeMdnsResponderAdapter::a_queries_empty() const { - for (const auto& queries : queries_) { - if (!queries.second.a_queries.empty()) - return false; - } - return true; -} - -bool FakeMdnsResponderAdapter::aaaa_queries_empty() const { - for (const auto& queries : queries_) { - if (!queries.second.aaaa_queries.empty()) - return false; - } - return true; -} - -Error FakeMdnsResponderAdapter::Init() { - OSP_CHECK(!running_); - running_ = true; - return Error::None(); -} - -void FakeMdnsResponderAdapter::Close() { - queries_.clear(); - ptr_events_.clear(); - srv_events_.clear(); - txt_events_.clear(); - a_events_.clear(); - aaaa_events_.clear(); - registered_interfaces_.clear(); - registered_services_.clear(); - running_ = false; -} - -Error FakeMdnsResponderAdapter::SetHostLabel(const std::string& host_label) { - return Error::Code::kNotImplemented; -} - -Error FakeMdnsResponderAdapter::RegisterInterface( - const InterfaceInfo& interface_info, - const IPSubnet& interface_address, - UdpSocket* socket) { - if (!running_) - return Error::Code::kOperationInvalid; - - if (std::find_if(registered_interfaces_.begin(), registered_interfaces_.end(), - [&socket](const RegisteredInterface& interface) { - return interface.socket == socket; - }) != registered_interfaces_.end()) { - return Error::Code::kItemNotFound; - } - registered_interfaces_.push_back({interface_info, interface_address, socket}); - return Error::None(); -} - -Error FakeMdnsResponderAdapter::DeregisterInterface(UdpSocket* socket) { - auto it = - std::find_if(registered_interfaces_.begin(), registered_interfaces_.end(), - [&socket](const RegisteredInterface& interface) { - return interface.socket == socket; - }); - if (it == registered_interfaces_.end()) - return Error::Code::kItemNotFound; - - registered_interfaces_.erase(it); - return Error::None(); -} - -void FakeMdnsResponderAdapter::OnRead(UdpSocket* socket, - ErrorOr<UdpPacket> packet) { - OSP_NOTREACHED(); -} - -void FakeMdnsResponderAdapter::OnSendError(UdpSocket* socket, Error error) { - OSP_NOTREACHED(); -} - -void FakeMdnsResponderAdapter::OnError(UdpSocket* socket, Error error) { - OSP_NOTREACHED(); -} - -void FakeMdnsResponderAdapter::OnBound(UdpSocket* socket) { - OSP_NOTREACHED(); -} - -Clock::duration FakeMdnsResponderAdapter::RunTasks() { - return std::chrono::seconds(1); -} - -std::vector<PtrEvent> FakeMdnsResponderAdapter::TakePtrResponses() { - std::vector<PtrEvent> result; - for (auto& queries : queries_) { - const auto query_it = std::stable_partition( - ptr_events_.begin(), ptr_events_.end(), - [&queries](const PtrEvent& ptr_event) { - const auto instance_labels = ptr_event.service_instance.GetLabels(); - for (const auto& query : queries.second.ptr_queries) { - const auto query_labels = query.GetLabels(); - // TODO(btolsch): Just use qname if it's added to PtrEvent. - if (ptr_event.header.socket == queries.first && - std::equal(instance_labels.begin() + 1, instance_labels.end(), - query_labels.begin())) { - return false; - } - } - return true; - }); - for (auto it = query_it; it != ptr_events_.end(); ++it) { - result.push_back(std::move(*it)); - } - ptr_events_.erase(query_it, ptr_events_.end()); - } - OSP_LOG_INFO << "taking " << result.size() << " ptr response(s)"; - return result; -} - -std::vector<SrvEvent> FakeMdnsResponderAdapter::TakeSrvResponses() { - std::vector<SrvEvent> result; - for (auto& queries : queries_) { - const auto query_it = std::stable_partition( - srv_events_.begin(), srv_events_.end(), - [&queries](const SrvEvent& srv_event) { - for (const auto& query : queries.second.srv_queries) { - if (srv_event.header.socket == queries.first && - srv_event.service_instance == query) - return false; - } - return true; - }); - for (auto it = query_it; it != srv_events_.end(); ++it) { - result.push_back(std::move(*it)); - } - srv_events_.erase(query_it, srv_events_.end()); - } - OSP_LOG_INFO << "taking " << result.size() << " srv response(s)"; - return result; -} - -std::vector<TxtEvent> FakeMdnsResponderAdapter::TakeTxtResponses() { - std::vector<TxtEvent> result; - for (auto& queries : queries_) { - const auto query_it = std::stable_partition( - txt_events_.begin(), txt_events_.end(), - [&queries](const TxtEvent& txt_event) { - for (const auto& query : queries.second.txt_queries) { - if (txt_event.header.socket == queries.first && - txt_event.service_instance == query) { - return false; - } - } - return true; - }); - for (auto it = query_it; it != txt_events_.end(); ++it) { - result.push_back(std::move(*it)); - } - txt_events_.erase(query_it, txt_events_.end()); - } - OSP_LOG_INFO << "taking " << result.size() << " txt response(s)"; - return result; -} - -std::vector<AEvent> FakeMdnsResponderAdapter::TakeAResponses() { - std::vector<AEvent> result; - for (auto& queries : queries_) { - const auto query_it = std::stable_partition( - a_events_.begin(), a_events_.end(), [&queries](const AEvent& a_event) { - for (const auto& query : queries.second.a_queries) { - if (a_event.header.socket == queries.first && - a_event.domain_name == query) { - return false; - } - } - return true; - }); - for (auto it = query_it; it != a_events_.end(); ++it) { - result.push_back(std::move(*it)); - } - a_events_.erase(query_it, a_events_.end()); - } - OSP_LOG_INFO << "taking " << result.size() << " a response(s)"; - return result; -} - -std::vector<AaaaEvent> FakeMdnsResponderAdapter::TakeAaaaResponses() { - std::vector<AaaaEvent> result; - for (auto& queries : queries_) { - const auto query_it = std::stable_partition( - aaaa_events_.begin(), aaaa_events_.end(), - [&queries](const AaaaEvent& aaaa_event) { - for (const auto& query : queries.second.aaaa_queries) { - if (aaaa_event.header.socket == queries.first && - aaaa_event.domain_name == query) { - return false; - } - } - return true; - }); - for (auto it = query_it; it != aaaa_events_.end(); ++it) { - result.push_back(std::move(*it)); - } - aaaa_events_.erase(query_it, aaaa_events_.end()); - } - OSP_LOG_INFO << "taking " << result.size() << " a response(s)"; - return result; -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::StartPtrQuery( - UdpSocket* socket, - const DomainName& service_type) { - if (!running_) - return MdnsResponderErrorCode::kUnknownError; - - auto canonical_service_type = service_type; - if (!canonical_service_type.EndsWithLocalDomain()) - OSP_CHECK(canonical_service_type.Append(DomainName::GetLocalDomain()).ok()); - - auto maybe_inserted = - queries_[socket].ptr_queries.insert(canonical_service_type); - if (maybe_inserted.second) { - return MdnsResponderErrorCode::kNoError; - } else { - return MdnsResponderErrorCode::kUnknownError; - } -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::StartSrvQuery( - UdpSocket* socket, - const DomainName& service_instance) { - if (!running_) - return MdnsResponderErrorCode::kUnknownError; - - auto maybe_inserted = queries_[socket].srv_queries.insert(service_instance); - if (maybe_inserted.second) { - return MdnsResponderErrorCode::kNoError; - } else { - return MdnsResponderErrorCode::kUnknownError; - } -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::StartTxtQuery( - UdpSocket* socket, - const DomainName& service_instance) { - if (!running_) - return MdnsResponderErrorCode::kUnknownError; - - auto maybe_inserted = queries_[socket].txt_queries.insert(service_instance); - if (maybe_inserted.second) { - return MdnsResponderErrorCode::kNoError; - } else { - return MdnsResponderErrorCode::kUnknownError; - } -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::StartAQuery( - UdpSocket* socket, - const DomainName& domain_name) { - if (!running_) - return MdnsResponderErrorCode::kUnknownError; - - auto maybe_inserted = queries_[socket].a_queries.insert(domain_name); - if (maybe_inserted.second) { - return MdnsResponderErrorCode::kNoError; - } else { - return MdnsResponderErrorCode::kUnknownError; - } -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::StartAaaaQuery( - UdpSocket* socket, - const DomainName& domain_name) { - if (!running_) - return MdnsResponderErrorCode::kUnknownError; - - auto maybe_inserted = queries_[socket].aaaa_queries.insert(domain_name); - if (maybe_inserted.second) { - return MdnsResponderErrorCode::kNoError; - } else { - return MdnsResponderErrorCode::kUnknownError; - } -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::StopPtrQuery( - UdpSocket* socket, - const DomainName& service_type) { - auto interface_entry = queries_.find(socket); - if (interface_entry == queries_.end()) - return MdnsResponderErrorCode::kUnknownError; - auto& ptr_queries = interface_entry->second.ptr_queries; - auto canonical_service_type = service_type; - if (!canonical_service_type.EndsWithLocalDomain()) - OSP_CHECK(canonical_service_type.Append(DomainName::GetLocalDomain()).ok()); - - auto it = ptr_queries.find(canonical_service_type); - if (it == ptr_queries.end()) - return MdnsResponderErrorCode::kUnknownError; - - ptr_queries.erase(it); - return MdnsResponderErrorCode::kNoError; -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::StopSrvQuery( - UdpSocket* socket, - const DomainName& service_instance) { - auto interface_entry = queries_.find(socket); - if (interface_entry == queries_.end()) - return MdnsResponderErrorCode::kUnknownError; - auto& srv_queries = interface_entry->second.srv_queries; - auto it = srv_queries.find(service_instance); - if (it == srv_queries.end()) - return MdnsResponderErrorCode::kUnknownError; - - srv_queries.erase(it); - return MdnsResponderErrorCode::kNoError; -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::StopTxtQuery( - UdpSocket* socket, - const DomainName& service_instance) { - auto interface_entry = queries_.find(socket); - if (interface_entry == queries_.end()) - return MdnsResponderErrorCode::kUnknownError; - auto& txt_queries = interface_entry->second.txt_queries; - auto it = txt_queries.find(service_instance); - if (it == txt_queries.end()) - return MdnsResponderErrorCode::kUnknownError; - - txt_queries.erase(it); - return MdnsResponderErrorCode::kNoError; -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::StopAQuery( - UdpSocket* socket, - const DomainName& domain_name) { - auto interface_entry = queries_.find(socket); - if (interface_entry == queries_.end()) - return MdnsResponderErrorCode::kUnknownError; - auto& a_queries = interface_entry->second.a_queries; - auto it = a_queries.find(domain_name); - if (it == a_queries.end()) - return MdnsResponderErrorCode::kUnknownError; - - a_queries.erase(it); - return MdnsResponderErrorCode::kNoError; -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::StopAaaaQuery( - UdpSocket* socket, - const DomainName& domain_name) { - auto interface_entry = queries_.find(socket); - if (interface_entry == queries_.end()) - return MdnsResponderErrorCode::kUnknownError; - auto& aaaa_queries = interface_entry->second.aaaa_queries; - auto it = aaaa_queries.find(domain_name); - if (it == aaaa_queries.end()) - return MdnsResponderErrorCode::kUnknownError; - - aaaa_queries.erase(it); - return MdnsResponderErrorCode::kNoError; -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::RegisterService( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const DomainName& target_host, - uint16_t target_port, - const std::map<std::string, std::string>& txt_data) { - if (!running_) - return MdnsResponderErrorCode::kUnknownError; - - if (std::find_if(registered_services_.begin(), registered_services_.end(), - [&service_instance, &service_name, - &service_protocol](const RegisteredService& service) { - return service.service_instance == service_instance && - service.service_name == service_name && - service.service_protocol == service_protocol; - }) != registered_services_.end()) { - return MdnsResponderErrorCode::kUnknownError; - } - registered_services_.push_back({service_instance, service_name, - service_protocol, target_host, target_port, - txt_data}); - return MdnsResponderErrorCode::kNoError; -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::DeregisterService( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol) { - if (!running_) - return MdnsResponderErrorCode::kUnknownError; - - auto it = - std::find_if(registered_services_.begin(), registered_services_.end(), - [&service_instance, &service_name, - &service_protocol](const RegisteredService& service) { - return service.service_instance == service_instance && - service.service_name == service_name && - service.service_protocol == service_protocol; - }); - if (it == registered_services_.end()) - return MdnsResponderErrorCode::kUnknownError; - - registered_services_.erase(it); - return MdnsResponderErrorCode::kNoError; -} - -MdnsResponderErrorCode FakeMdnsResponderAdapter::UpdateTxtData( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const std::map<std::string, std::string>& txt_data) { - if (!running_) - return MdnsResponderErrorCode::kUnknownError; - - auto it = - std::find_if(registered_services_.begin(), registered_services_.end(), - [&service_instance, &service_name, - &service_protocol](const RegisteredService& service) { - return service.service_instance == service_instance && - service.service_name == service_name && - service.service_protocol == service_protocol; - }); - if (it == registered_services_.end()) - return MdnsResponderErrorCode::kUnknownError; - - it->txt_data = txt_data; - return MdnsResponderErrorCode::kNoError; -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/impl/testing/fake_mdns_responder_adapter.h b/osp/impl/testing/fake_mdns_responder_adapter.h deleted file mode 100644 index ecdb21cc..00000000 --- a/osp/impl/testing/fake_mdns_responder_adapter.h +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef OSP_IMPL_TESTING_FAKE_MDNS_RESPONDER_ADAPTER_H_ -#define OSP_IMPL_TESTING_FAKE_MDNS_RESPONDER_ADAPTER_H_ - -#include <map> -#include <set> -#include <string> -#include <vector> - -#include "osp/impl/discovery/mdns/mdns_responder_adapter.h" - -namespace openscreen { -namespace osp { - -class FakeMdnsResponderAdapter; - -PtrEvent MakePtrEvent(const std::string& service_instance, - const std::string& service_type, - const std::string& service_protocol, - UdpSocket* socket); - -SrvEvent MakeSrvEvent(const std::string& service_instance, - const std::string& service_type, - const std::string& service_protocol, - const std::string& hostname, - uint16_t port, - UdpSocket* socket); - -TxtEvent MakeTxtEvent(const std::string& service_instance, - const std::string& service_type, - const std::string& service_protocol, - const std::vector<std::string>& txt_lines, - UdpSocket* socket); - -AEvent MakeAEvent(const std::string& hostname, - IPAddress address, - UdpSocket* socket); - -AaaaEvent MakeAaaaEvent(const std::string& hostname, - IPAddress address, - UdpSocket* socket); - -void AddEventsForNewService(FakeMdnsResponderAdapter* mdns_responder, - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const std::string& hostname, - uint16_t port, - const std::vector<std::string>& txt_lines, - const IPAddress& address, - UdpSocket* socket); - -class FakeMdnsResponderAdapter final : public MdnsResponderAdapter { - public: - struct RegisteredInterface { - InterfaceInfo interface_info; - IPSubnet interface_address; - UdpSocket* socket; - }; - - struct RegisteredService { - std::string service_instance; - std::string service_name; - std::string service_protocol; - DomainName target_host; - uint16_t target_port; - std::map<std::string, std::string> txt_data; - }; - - class LifetimeObserver { - public: - virtual ~LifetimeObserver() = default; - - virtual void OnDestroyed() = 0; - }; - - ~FakeMdnsResponderAdapter() override; - - void SetLifetimeObserver(LifetimeObserver* observer) { observer_ = observer; } - - void AddPtrEvent(PtrEvent&& ptr_event); - void AddSrvEvent(SrvEvent&& srv_event); - void AddTxtEvent(TxtEvent&& txt_event); - void AddAEvent(AEvent&& a_event); - void AddAaaaEvent(AaaaEvent&& aaaa_event); - - const std::vector<RegisteredInterface>& registered_interfaces() { - return registered_interfaces_; - } - const std::vector<RegisteredService>& registered_services() { - return registered_services_; - } - bool ptr_queries_empty() const; - bool srv_queries_empty() const; - bool txt_queries_empty() const; - bool a_queries_empty() const; - bool aaaa_queries_empty() const; - bool running() const { return running_; } - - // UdpSocket::Client overrides. - void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) override; - void OnSendError(UdpSocket* socket, Error error) override; - void OnError(UdpSocket* socket, Error error) override; - void OnBound(UdpSocket* socket) override; - - // MdnsResponderAdapter overrides. - Error Init() override; - void Close() override; - - Error SetHostLabel(const std::string& host_label) override; - - // TODO(btolsch): Reject/OSP_CHECK events that don't match any registered - // interface? - Error RegisterInterface(const InterfaceInfo& interface_info, - const IPSubnet& interface_address, - UdpSocket* socket) override; - Error DeregisterInterface(UdpSocket* socket) override; - - Clock::duration RunTasks() override; - - std::vector<PtrEvent> TakePtrResponses() override; - std::vector<SrvEvent> TakeSrvResponses() override; - std::vector<TxtEvent> TakeTxtResponses() override; - std::vector<AEvent> TakeAResponses() override; - std::vector<AaaaEvent> TakeAaaaResponses() override; - - MdnsResponderErrorCode StartPtrQuery(UdpSocket* socket, - const DomainName& service_type) override; - MdnsResponderErrorCode StartSrvQuery( - UdpSocket* socket, - const DomainName& service_instance) override; - MdnsResponderErrorCode StartTxtQuery( - UdpSocket* socket, - const DomainName& service_instance) override; - MdnsResponderErrorCode StartAQuery(UdpSocket* socket, - const DomainName& domain_name) override; - MdnsResponderErrorCode StartAaaaQuery(UdpSocket* socket, - const DomainName& domain_name) override; - - MdnsResponderErrorCode StopPtrQuery(UdpSocket* socket, - const DomainName& service_type) override; - MdnsResponderErrorCode StopSrvQuery( - UdpSocket* socket, - const DomainName& service_instance) override; - MdnsResponderErrorCode StopTxtQuery( - UdpSocket* socket, - const DomainName& service_instance) override; - MdnsResponderErrorCode StopAQuery(UdpSocket* socket, - const DomainName& domain_name) override; - MdnsResponderErrorCode StopAaaaQuery(UdpSocket* socket, - const DomainName& domain_name) override; - - MdnsResponderErrorCode RegisterService( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const DomainName& target_host, - uint16_t target_port, - const std::map<std::string, std::string>& txt_data) override; - MdnsResponderErrorCode DeregisterService( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol) override; - MdnsResponderErrorCode UpdateTxtData( - const std::string& service_instance, - const std::string& service_name, - const std::string& service_protocol, - const std::map<std::string, std::string>& txt_data) override; - - private: - struct InterfaceQueries { - std::set<DomainName, DomainNameComparator> a_queries; - std::set<DomainName, DomainNameComparator> aaaa_queries; - std::set<DomainName, DomainNameComparator> ptr_queries; - std::set<DomainName, DomainNameComparator> srv_queries; - std::set<DomainName, DomainNameComparator> txt_queries; - }; - - bool running_ = false; - LifetimeObserver* observer_ = nullptr; - - std::map<UdpSocket*, InterfaceQueries> queries_; - // NOTE: One of many simplifications here is that there is no cache. This - // means that calling StartQuery, StopQuery, StartQuery will only return an - // event the first time, unless the test also adds the event a second time. - std::vector<PtrEvent> ptr_events_; - std::vector<SrvEvent> srv_events_; - std::vector<TxtEvent> txt_events_; - std::vector<AEvent> a_events_; - std::vector<AaaaEvent> aaaa_events_; - - std::vector<RegisteredInterface> registered_interfaces_; - std::vector<RegisteredService> registered_services_; -}; - -} // namespace osp -} // namespace openscreen - -#endif // OSP_IMPL_TESTING_FAKE_MDNS_RESPONDER_ADAPTER_H_ diff --git a/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc b/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc deleted file mode 100644 index 06cec89a..00000000 --- a/osp/impl/testing/fake_mdns_responder_adapter_unittest.cc +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright 2018 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "osp/impl/testing/fake_mdns_responder_adapter.h" - -#include "gtest/gtest.h" - -namespace openscreen { -namespace osp { - -namespace { - -constexpr char kTestServiceInstance[] = "turtle"; -constexpr char kTestServiceName[] = "_foo"; -constexpr char kTestServiceProtocol[] = "_udp"; - -UdpSocket* const kDefaultSocket = - reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(8)); -UdpSocket* const kSecondSocket = - reinterpret_cast<UdpSocket*>(static_cast<uintptr_t>(32)); - -} // namespace - -TEST(FakeMdnsResponderAdapterTest, AQueries) { - FakeMdnsResponderAdapter mdns_responder; - - mdns_responder.Init(); - ASSERT_TRUE(mdns_responder.running()); - auto event = MakeAEvent("alpha", IPAddress{1, 2, 3, 4}, kDefaultSocket); - auto domain_name = event.domain_name; - mdns_responder.AddAEvent(std::move(event)); - - auto a_events = mdns_responder.TakeAResponses(); - EXPECT_TRUE(a_events.empty()); - - auto result = mdns_responder.StartAQuery(kDefaultSocket, domain_name); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - - a_events = mdns_responder.TakeAResponses(); - ASSERT_EQ(1u, a_events.size()); - EXPECT_EQ(domain_name, a_events[0].domain_name); - EXPECT_EQ((IPAddress{1, 2, 3, 4}), a_events[0].address); - - result = mdns_responder.StopAQuery(kDefaultSocket, domain_name); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - - mdns_responder.Close(); - ASSERT_FALSE(mdns_responder.running()); - - mdns_responder.AddAEvent( - MakeAEvent("alpha", IPAddress{1, 2, 3, 4}, kDefaultSocket)); - result = mdns_responder.StartAQuery(kDefaultSocket, domain_name); - EXPECT_NE(MdnsResponderErrorCode::kNoError, result); - a_events = mdns_responder.TakeAResponses(); - EXPECT_TRUE(a_events.empty()); -} - -TEST(FakeMdnsResponderAdapterTest, AaaaQueries) { - FakeMdnsResponderAdapter mdns_responder; - - mdns_responder.Init(); - ASSERT_TRUE(mdns_responder.running()); - auto event = MakeAaaaEvent("alpha", IPAddress{1, 2, 3, 4}, kDefaultSocket); - auto domain_name = event.domain_name; - mdns_responder.AddAaaaEvent(std::move(event)); - - auto aaaa_events = mdns_responder.TakeAaaaResponses(); - EXPECT_TRUE(aaaa_events.empty()); - - auto result = mdns_responder.StartAaaaQuery(kDefaultSocket, domain_name); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - - aaaa_events = mdns_responder.TakeAaaaResponses(); - ASSERT_EQ(1u, aaaa_events.size()); - EXPECT_EQ(domain_name, aaaa_events[0].domain_name); - EXPECT_EQ((IPAddress{1, 2, 3, 4}), aaaa_events[0].address); - - result = mdns_responder.StopAaaaQuery(kDefaultSocket, domain_name); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - - mdns_responder.Close(); - ASSERT_FALSE(mdns_responder.running()); - - mdns_responder.AddAaaaEvent( - MakeAaaaEvent("alpha", IPAddress{1, 2, 3, 4}, kDefaultSocket)); - result = mdns_responder.StartAaaaQuery(kDefaultSocket, domain_name); - EXPECT_NE(MdnsResponderErrorCode::kNoError, result); - aaaa_events = mdns_responder.TakeAaaaResponses(); - EXPECT_TRUE(aaaa_events.empty()); -} - -TEST(FakeMdnsResponderAdapterTest, PtrQueries) { - const DomainName kTestServiceType{ - {4, '_', 'f', 'o', 'o', 4, '_', 'u', 'd', 'p', 0}}; - const DomainName kTestServiceTypeCanon{{4, '_', 'f', 'o', 'o', 4, '_', 'u', - 'd', 'p', 5, 'l', 'o', 'c', 'a', 'l', - 0}}; - - FakeMdnsResponderAdapter mdns_responder; - - mdns_responder.Init(); - ASSERT_TRUE(mdns_responder.running()); - mdns_responder.AddPtrEvent( - MakePtrEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - kDefaultSocket)); - - auto ptr_events = mdns_responder.TakePtrResponses(); - EXPECT_TRUE(ptr_events.empty()); - - auto result = mdns_responder.StartPtrQuery(kDefaultSocket, kTestServiceType); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - - ptr_events = mdns_responder.TakePtrResponses(); - ASSERT_EQ(1u, ptr_events.size()); - auto labels = ptr_events[0].service_instance.GetLabels(); - EXPECT_EQ(kTestServiceInstance, labels[0]); - - // TODO(btolsch): qname if PtrEvent gets it. - ErrorOr<DomainName> st = - DomainName::FromLabels(labels.begin() + 1, labels.end()); - ASSERT_TRUE(st); - EXPECT_EQ(kTestServiceTypeCanon, st.value()); - - result = mdns_responder.StopPtrQuery(kDefaultSocket, kTestServiceType); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - - mdns_responder.Close(); - ASSERT_FALSE(mdns_responder.running()); - - mdns_responder.AddPtrEvent( - MakePtrEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - kDefaultSocket)); - result = mdns_responder.StartPtrQuery(kDefaultSocket, kTestServiceType); - EXPECT_NE(MdnsResponderErrorCode::kNoError, result); - ptr_events = mdns_responder.TakePtrResponses(); - EXPECT_TRUE(ptr_events.empty()); -} - -TEST(FakeMdnsResponderAdapterTest, SrvQueries) { - FakeMdnsResponderAdapter mdns_responder; - - mdns_responder.Init(); - ASSERT_TRUE(mdns_responder.running()); - - auto event = - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "alpha", 12345, kDefaultSocket); - auto service_instance = event.service_instance; - auto domain_name = event.domain_name; - mdns_responder.AddSrvEvent(std::move(event)); - - auto srv_events = mdns_responder.TakeSrvResponses(); - EXPECT_TRUE(srv_events.empty()); - - auto result = mdns_responder.StartSrvQuery(kDefaultSocket, service_instance); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - - srv_events = mdns_responder.TakeSrvResponses(); - ASSERT_EQ(1u, srv_events.size()); - EXPECT_EQ(service_instance, srv_events[0].service_instance); - EXPECT_EQ(domain_name, srv_events[0].domain_name); - EXPECT_EQ(12345, srv_events[0].port); - - result = mdns_responder.StopSrvQuery(kDefaultSocket, service_instance); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - - mdns_responder.Close(); - ASSERT_FALSE(mdns_responder.running()); - - mdns_responder.AddSrvEvent( - MakeSrvEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - "alpha", 12345, kDefaultSocket)); - result = mdns_responder.StartSrvQuery(kDefaultSocket, service_instance); - EXPECT_NE(MdnsResponderErrorCode::kNoError, result); - srv_events = mdns_responder.TakeSrvResponses(); - EXPECT_TRUE(srv_events.empty()); -} - -TEST(FakeMdnsResponderAdapterTest, TxtQueries) { - FakeMdnsResponderAdapter mdns_responder; - - mdns_responder.Init(); - ASSERT_TRUE(mdns_responder.running()); - - const auto txt_lines = std::vector<std::string>{"asdf", "jkl;", "j"}; - auto event = MakeTxtEvent(kTestServiceInstance, kTestServiceName, - kTestServiceProtocol, txt_lines, kDefaultSocket); - auto service_instance = event.service_instance; - mdns_responder.AddTxtEvent(std::move(event)); - - auto txt_events = mdns_responder.TakeTxtResponses(); - EXPECT_TRUE(txt_events.empty()); - - auto result = mdns_responder.StartTxtQuery(kDefaultSocket, service_instance); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - - txt_events = mdns_responder.TakeTxtResponses(); - ASSERT_EQ(1u, txt_events.size()); - EXPECT_EQ(service_instance, txt_events[0].service_instance); - EXPECT_EQ(txt_lines, txt_events[0].txt_info); - - result = mdns_responder.StopTxtQuery(kDefaultSocket, service_instance); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - - mdns_responder.Close(); - ASSERT_FALSE(mdns_responder.running()); - - mdns_responder.AddTxtEvent( - MakeTxtEvent(kTestServiceInstance, kTestServiceName, kTestServiceProtocol, - txt_lines, kDefaultSocket)); - result = mdns_responder.StartTxtQuery(kDefaultSocket, service_instance); - EXPECT_NE(MdnsResponderErrorCode::kNoError, result); - txt_events = mdns_responder.TakeTxtResponses(); - EXPECT_TRUE(txt_events.empty()); -} - -TEST(FakeMdnsResponderAdapterTest, RegisterServices) { - FakeMdnsResponderAdapter mdns_responder; - - mdns_responder.Init(); - ASSERT_TRUE(mdns_responder.running()); - - auto result = mdns_responder.RegisterService( - "instance", "name", "proto", - DomainName{{1, 'a', 5, 'l', 'o', 'c', 'a', 'l', 0}}, 12345, - {{"k1", "asdf"}, {"k2", "jkl"}}); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - EXPECT_EQ(1u, mdns_responder.registered_services().size()); - - result = mdns_responder.RegisterService( - "instance2", "name", "proto", - DomainName{{1, 'b', 5, 'l', 'o', 'c', 'a', 'l', 0}}, 12346, - {{"k1", "asdf"}, {"k2", "jkl"}}); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - EXPECT_EQ(2u, mdns_responder.registered_services().size()); - - result = mdns_responder.DeregisterService("instance", "name", "proto"); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - result = mdns_responder.DeregisterService("instance", "name", "proto"); - EXPECT_NE(MdnsResponderErrorCode::kNoError, result); - EXPECT_EQ(1u, mdns_responder.registered_services().size()); - - mdns_responder.Close(); - ASSERT_FALSE(mdns_responder.running()); - EXPECT_EQ(0u, mdns_responder.registered_services().size()); - - result = mdns_responder.RegisterService( - "instance2", "name", "proto", - DomainName{{1, 'b', 5, 'l', 'o', 'c', 'a', 'l', 0}}, 12346, - {{"k1", "asdf"}, {"k2", "jkl"}}); - EXPECT_NE(MdnsResponderErrorCode::kNoError, result); - EXPECT_EQ(0u, mdns_responder.registered_services().size()); -} - -TEST(FakeMdnsResponderAdapterTest, RegisterInterfaces) { - FakeMdnsResponderAdapter mdns_responder; - - mdns_responder.Init(); - ASSERT_TRUE(mdns_responder.running()); - EXPECT_EQ(0u, mdns_responder.registered_interfaces().size()); - - Error result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{}, - kDefaultSocket); - EXPECT_TRUE(result.ok()); - EXPECT_EQ(1u, mdns_responder.registered_interfaces().size()); - - result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{}, - kDefaultSocket); - EXPECT_FALSE(result.ok()); - EXPECT_EQ(1u, mdns_responder.registered_interfaces().size()); - - result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{}, - kSecondSocket); - EXPECT_TRUE(result.ok()); - EXPECT_EQ(2u, mdns_responder.registered_interfaces().size()); - - result = mdns_responder.DeregisterInterface(kSecondSocket); - EXPECT_TRUE(result.ok()); - EXPECT_EQ(1u, mdns_responder.registered_interfaces().size()); - result = mdns_responder.DeregisterInterface(kSecondSocket); - EXPECT_FALSE(result.ok()); - EXPECT_EQ(1u, mdns_responder.registered_interfaces().size()); - - mdns_responder.Close(); - ASSERT_FALSE(mdns_responder.running()); - EXPECT_EQ(0u, mdns_responder.registered_interfaces().size()); - - result = mdns_responder.RegisterInterface(InterfaceInfo{}, IPSubnet{}, - kDefaultSocket); - EXPECT_FALSE(result.ok()); - EXPECT_EQ(0u, mdns_responder.registered_interfaces().size()); -} - -TEST(FakeMdnsResponderAdapterTest, UpdateTxtData) { - FakeMdnsResponderAdapter mdns_responder; - - mdns_responder.Init(); - ASSERT_TRUE(mdns_responder.running()); - - const std::map<std::string, std::string> txt_data1{{"k1", "asdf"}, - {"k2", "jkl"}}; - auto result = mdns_responder.RegisterService( - "instance", "name", "proto", - DomainName{{1, 'a', 5, 'l', 'o', 'c', 'a', 'l', 0}}, 12345, txt_data1); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - ASSERT_EQ(1u, mdns_responder.registered_services().size()); - EXPECT_EQ(txt_data1, mdns_responder.registered_services()[0].txt_data); - - const std::map<std::string, std::string> txt_data2{ - {"k1", "monkey"}, {"k2", "panda"}, {"k3", "turtle"}, {"k4", "rhino"}}; - result = mdns_responder.UpdateTxtData("instance", "name", "proto", txt_data2); - EXPECT_EQ(MdnsResponderErrorCode::kNoError, result); - ASSERT_EQ(1u, mdns_responder.registered_services().size()); - EXPECT_EQ(txt_data2, mdns_responder.registered_services()[0].txt_data); -} - -} // namespace osp -} // namespace openscreen diff --git a/osp/public/BUILD.gn b/osp/public/BUILD.gn index cc915c60..3606dcd5 100644 --- a/osp/public/BUILD.gn +++ b/osp/public/BUILD.gn @@ -11,7 +11,6 @@ source_set("public") { "endpoint_request_ids.cc", "endpoint_request_ids.h", "mdns_service_listener_factory.h", - "mdns_service_publisher_factory.h", "message_demuxer.h", "network_metrics.h", "network_service_manager.h", @@ -35,6 +34,7 @@ source_set("public") { "service_listener.h", "service_publisher.cc", "service_publisher.h", + "service_publisher_factory.h", "timestamp.h", ] diff --git a/osp/public/mdns_service_listener_factory.h b/osp/public/mdns_service_listener_factory.h index 663d060f..fa33f54f 100644 --- a/osp/public/mdns_service_listener_factory.h +++ b/osp/public/mdns_service_listener_factory.h @@ -8,6 +8,7 @@ #include <memory> #include "osp/public/service_listener.h" +#include "util/osp_logging.h" namespace openscreen { @@ -25,7 +26,9 @@ class MdnsServiceListenerFactory { static std::unique_ptr<ServiceListener> Create( const MdnsServiceListenerConfig& config, ServiceListener::Observer* observer, - TaskRunner* task_runner); + TaskRunner* task_runner) { + OSP_NOTREACHED(); + } }; } // namespace osp diff --git a/osp/public/network_service_manager.h b/osp/public/network_service_manager.h index 8289e178..b1a225c4 100644 --- a/osp/public/network_service_manager.h +++ b/osp/public/network_service_manager.h @@ -28,7 +28,7 @@ class NetworkServiceManager final { // be passed for services not provided by the embedder. static NetworkServiceManager* Create( std::unique_ptr<ServiceListener> mdns_listener, - std::unique_ptr<ServicePublisher> mdns_publisher, + std::unique_ptr<ServicePublisher> service_publisher, std::unique_ptr<ProtocolConnectionClient> connection_client, std::unique_ptr<ProtocolConnectionServer> connection_server); @@ -47,7 +47,7 @@ class NetworkServiceManager final { // Returns an instance of the mDNS receiver publisher, or nullptr if not // provided. - ServicePublisher* GetMdnsServicePublisher(); + ServicePublisher* GetServicePublisher(); // Returns an instance of the protocol connection client, or nullptr if not // provided. @@ -60,14 +60,14 @@ class NetworkServiceManager final { private: NetworkServiceManager( std::unique_ptr<ServiceListener> mdns_listener, - std::unique_ptr<ServicePublisher> mdns_publisher, + std::unique_ptr<ServicePublisher> service_publisher, std::unique_ptr<ProtocolConnectionClient> connection_client, std::unique_ptr<ProtocolConnectionServer> connection_server); ~NetworkServiceManager(); std::unique_ptr<ServiceListener> mdns_listener_; - std::unique_ptr<ServicePublisher> mdns_publisher_; + std::unique_ptr<ServicePublisher> service_publisher_; std::unique_ptr<ProtocolConnectionClient> connection_client_; std::unique_ptr<ProtocolConnectionServer> connection_server_; }; diff --git a/osp/public/presentation/presentation_connection.h b/osp/public/presentation/presentation_connection.h index ee0cff1b..4ea37ce7 100644 --- a/osp/public/presentation/presentation_connection.h +++ b/osp/public/presentation/presentation_connection.h @@ -62,7 +62,6 @@ class Connection { class Delegate { public: Delegate() = default; - virtual ~Delegate() = default; // State changes. virtual void OnConnected() = 0; @@ -85,6 +84,9 @@ class Connection { // A binary message was received. virtual void OnBinaryMessage(const std::vector<uint8_t>& data) = 0; + protected: + virtual ~Delegate() = default; + private: OSP_DISALLOW_COPY_AND_ASSIGN(Delegate); }; diff --git a/osp/public/request_response_handler.h b/osp/public/request_response_handler.h index de783efc..0ae97f88 100644 --- a/osp/public/request_response_handler.h +++ b/osp/public/request_response_handler.h @@ -59,12 +59,14 @@ class RequestResponseHandler : public MessageDemuxer::MessageCallback { public: class Delegate { public: - virtual ~Delegate() = default; virtual void OnMatchedResponse(RequestT* request, typename RequestT::ResponseMsgType* response, uint64_t endpoint_id) = 0; virtual void OnError(RequestT* request, Error error) = 0; + + protected: + virtual ~Delegate() = default; }; explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {} diff --git a/osp/public/service_info.h b/osp/public/service_info.h index e486052c..7c95ff2e 100644 --- a/osp/public/service_info.h +++ b/osp/public/service_info.h @@ -14,6 +14,8 @@ namespace openscreen { namespace osp { +constexpr char kOpenScreenServiceName[] = "_openscreen._udp"; + // This contains canonical information about a specific Open Screen service // found on the network via our discovery mechanism (mDNS). struct ServiceInfo { diff --git a/osp/public/service_publisher.cc b/osp/public/service_publisher.cc index 3a8e70b7..2268f2d1 100644 --- a/osp/public/service_publisher.cc +++ b/osp/public/service_publisher.cc @@ -7,26 +7,25 @@ namespace openscreen { namespace osp { -ServicePublisherError::ServicePublisherError() = default; -ServicePublisherError::ServicePublisherError(Code error, - const std::string& message) - : error(error), message(message) {} -ServicePublisherError::ServicePublisherError( - const ServicePublisherError& other) = default; -ServicePublisherError::~ServicePublisherError() = default; - -ServicePublisherError& ServicePublisherError::operator=( - const ServicePublisherError& other) = default; - ServicePublisher::Metrics::Metrics() = default; ServicePublisher::Metrics::~Metrics() = default; ServicePublisher::Config::Config() = default; ServicePublisher::Config::~Config() = default; +bool ServicePublisher::Config::IsValid() const { + return !friendly_name.empty() && !service_instance_name.empty() && + connection_server_port > 0 && !network_interfaces.empty(); +} + +ServicePublisher::~ServicePublisher() = default; + +void ServicePublisher::SetConfig(const Config& config) { + config_ = config; +} + ServicePublisher::ServicePublisher(Observer* observer) : state_(State::kStopped), observer_(observer) {} -ServicePublisher::~ServicePublisher() = default; } // namespace osp } // namespace openscreen diff --git a/osp/public/service_publisher.h b/osp/public/service_publisher.h index b31f59fc..190d22df 100644 --- a/osp/public/service_publisher.h +++ b/osp/public/service_publisher.h @@ -10,30 +10,13 @@ #include <vector> #include "osp/public/timestamp.h" -#include "platform/api/network_interface.h" +#include "platform/base/error.h" +#include "platform/base/interface_info.h" #include "platform/base/macros.h" namespace openscreen { namespace osp { -// Used to report an error from a ServiceListener implementation. -struct ServicePublisherError { - // TODO(mfoltz): Add additional error types, as implementations progress. - enum class Code { - kNone = 0, - }; - - ServicePublisherError(); - ServicePublisherError(Code error, const std::string& message); - ServicePublisherError(const ServicePublisherError& other); - ~ServicePublisherError(); - - ServicePublisherError& operator=(const ServicePublisherError& other); - - Code error; - std::string message; -}; - class ServicePublisher { public: enum class State { @@ -74,7 +57,7 @@ class ServicePublisher { virtual void OnSuspended() = 0; // Reports an error. - virtual void OnError(ServicePublisherError) = 0; + virtual void OnError(Error) = 0; // Reports metrics. virtual void OnMetrics(Metrics) = 0; @@ -103,15 +86,21 @@ class ServicePublisher { // configured in the ProtocolConnectionServer. uint16_t connection_server_port = 0; - // A list of network interface names that the publisher should use. + // A list of network interfaces that the publisher should use. // By default, all enabled Ethernet and WiFi interfaces are used. // This configuration must be identical to the interfaces configured // in the ScreenConnectionServer. - std::vector<NetworkInterfaceIndex> network_interface_indices; + std::vector<InterfaceInfo> network_interfaces; + + // Returns true if the config object is valid. + bool IsValid() const; }; virtual ~ServicePublisher(); + // Sets the service configuration for this publisher. + virtual void SetConfig(const Config& config); + // Starts publishing this service using the config object. // Returns true if state() == kStopped and the service will be started, false // otherwise. @@ -139,14 +128,15 @@ class ServicePublisher { State state() const { return state_; } // Returns the last error reported by this publisher. - ServicePublisherError last_error() const { return last_error_; } + Error last_error() const { return last_error_; } protected: explicit ServicePublisher(Observer* observer); State state_; - ServicePublisherError last_error_; + Error last_error_; Observer* observer_; + Config config_; OSP_DISALLOW_COPY_AND_ASSIGN(ServicePublisher); }; diff --git a/osp/public/mdns_service_publisher_factory.h b/osp/public/service_publisher_factory.h index 075137a7..93193d03 100644 --- a/osp/public/mdns_service_publisher_factory.h +++ b/osp/public/service_publisher_factory.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef OSP_PUBLIC_MDNS_SERVICE_PUBLISHER_FACTORY_H_ -#define OSP_PUBLIC_MDNS_SERVICE_PUBLISHER_FACTORY_H_ +#ifndef OSP_PUBLIC_SERVICE_PUBLISHER_FACTORY_H_ +#define OSP_PUBLIC_SERVICE_PUBLISHER_FACTORY_H_ #include <memory> @@ -15,7 +15,7 @@ class TaskRunner; namespace osp { -class MdnsServicePublisherFactory { +class ServicePublisherFactory { public: static std::unique_ptr<ServicePublisher> Create( const ServicePublisher::Config& config, @@ -26,4 +26,4 @@ class MdnsServicePublisherFactory { } // namespace osp } // namespace openscreen -#endif // OSP_PUBLIC_MDNS_SERVICE_PUBLISHER_FACTORY_H_ +#endif // OSP_PUBLIC_SERVICE_PUBLISHER_FACTORY_H_ diff --git a/platform/BUILD.gn b/platform/BUILD.gn index e98067f1..304c0853 100644 --- a/platform/BUILD.gn +++ b/platform/BUILD.gn @@ -33,6 +33,15 @@ source_set("base") { public_configs = [ "../build:openscreen_include_dirs" ] } +# Public API source files. May depend on nothing except :base. +source_set("logging") { + defines = [] + + sources = [ "api/logging.h" ] + + public_deps = [ ":base" ] +} + # Public API source files. These may depend on nothing except :base. source_set("api") { defines = [] @@ -56,7 +65,10 @@ source_set("api") { "api/udp_socket.h", ] - public_deps = [ ":base" ] + public_deps = [ + ":base", + ":logging", + ] } # The following target is only activated in standalone builds (see :platform). @@ -211,10 +223,22 @@ source_set("unittests") { "base/udp_packet_unittest.cc", ] + deps = [ + ":platform", + ":test", + "../third_party/abseil", + "../third_party/boringssl", + "../third_party/googletest:gmock", + "../third_party/googletest:gtest", + "../util", + ] + # The socket integration tests assume that you can Bind with UDP sockets, # which is simply not true when we are built inside of Chromium. if (!build_with_chromium) { sources += [ "api/socket_integration_unittest.cc" ] + + deps += [ ":standalone_impl" ] } # The unit tests in impl/ assume the standalone implementation is being used. @@ -238,14 +262,4 @@ source_set("unittests") { ] } } - - deps = [ - ":platform", - ":test", - "../third_party/abseil", - "../third_party/boringssl", - "../third_party/googletest:gmock", - "../third_party/googletest:gtest", - "../util", - ] } diff --git a/platform/api/tls_connection.cc b/platform/api/tls_connection.cc index 9668c114..12fdc5cd 100644 --- a/platform/api/tls_connection.cc +++ b/platform/api/tls_connection.cc @@ -9,4 +9,6 @@ namespace openscreen { TlsConnection::TlsConnection() = default; TlsConnection::~TlsConnection() = default; +TlsConnection::Client::~Client() = default; + } // namespace openscreen diff --git a/platform/api/tls_connection.h b/platform/api/tls_connection.h index 4d409cb5..1f47bce6 100644 --- a/platform/api/tls_connection.h +++ b/platform/api/tls_connection.h @@ -26,7 +26,7 @@ class TlsConnection { std::vector<uint8_t> block) = 0; protected: - virtual ~Client() = default; + virtual ~Client(); }; virtual ~TlsConnection(); @@ -40,9 +40,6 @@ class TlsConnection { // Sends a message. Returns true iff the message will be sent. [[nodiscard]] virtual bool Send(const void* data, size_t len) = 0; - // Get the local address. - virtual IPEndpoint GetLocalEndpoint() const = 0; - // Get the connected remote address. virtual IPEndpoint GetRemoteEndpoint() const = 0; diff --git a/platform/api/tls_connection_factory.cc b/platform/api/tls_connection_factory.cc index e64078f1..c23c9e7e 100644 --- a/platform/api/tls_connection_factory.cc +++ b/platform/api/tls_connection_factory.cc @@ -9,4 +9,6 @@ namespace openscreen { TlsConnectionFactory::TlsConnectionFactory() = default; TlsConnectionFactory::~TlsConnectionFactory() = default; +TlsConnectionFactory::Client::~Client() = default; + } // namespace openscreen diff --git a/platform/api/tls_connection_factory.h b/platform/api/tls_connection_factory.h index 80dc8ac6..b9d1e2f5 100644 --- a/platform/api/tls_connection_factory.h +++ b/platform/api/tls_connection_factory.h @@ -46,6 +46,9 @@ class TlsConnectionFactory { // Called when a non-recoverable error occurs. virtual void OnError(TlsConnectionFactory* factory, Error error) = 0; + + protected: + virtual ~Client(); }; // The connection factory requires a client for yielding creation results diff --git a/platform/api/udp_socket.cc b/platform/api/udp_socket.cc index 47eba8bd..d895cf04 100644 --- a/platform/api/udp_socket.cc +++ b/platform/api/udp_socket.cc @@ -9,4 +9,6 @@ namespace openscreen { UdpSocket::UdpSocket() = default; UdpSocket::~UdpSocket() = default; +UdpSocket::Client::~Client() = default; + } // namespace openscreen diff --git a/platform/api/udp_socket.h b/platform/api/udp_socket.h index 3baf4119..d77d95f9 100644 --- a/platform/api/udp_socket.h +++ b/platform/api/udp_socket.h @@ -30,7 +30,6 @@ class UdpSocket { // Client for the UdpSocket class. class Client { public: - virtual ~Client() = default; // Method called when the UDP socket is bound. Default implementation // does nothing, as clients may not care about the socket bind state. @@ -49,6 +48,9 @@ class UdpSocket { // Method called when a packet is read. virtual void OnRead(UdpSocket* socket, ErrorOr<UdpPacket> packet) = 0; + + protected: + virtual ~Client(); }; // Constants used to specify how we want packets sent from this socket. diff --git a/platform/base/error.cc b/platform/base/error.cc index bad8a84f..a9a146e3 100644 --- a/platform/base/error.cc +++ b/platform/base/error.cc @@ -254,10 +254,16 @@ std::ostream& operator<<(std::ostream& os, const Error::Code& code) { return os << "ProcessReceivedRecordFailure"; case Error::Code::kUnknownCodec: return os << "UnknownCodec"; + case Error::Code::kInvalidCodecParameter: + return os << "InvalidCodecParameter"; case Error::Code::kSocketFailure: return os << "SocketFailure"; case Error::Code::kUnencryptedOffer: return os << "UnencryptedOffer"; + case Error::Code::kRemotingNotSupported: + return os << "RemotingNotSupported"; + case Error::Code::kNegotiationFailure: + return os << "NegotiationFailure"; case Error::Code::kNone: break; } diff --git a/platform/base/error.h b/platform/base/error.h index 9deacd2f..2f9216f2 100644 --- a/platform/base/error.h +++ b/platform/base/error.h @@ -186,8 +186,14 @@ class Error { // Cast streaming errors kTypeError, kUnknownCodec, + kInvalidCodecParameter, kSocketFailure, - kUnencryptedOffer + kUnencryptedOffer, + kRemotingNotSupported, + + // A negotiation failure means that the current negotiation must be + // restarted by the sender. + kNegotiationFailure, }; Error(); diff --git a/platform/base/interface_info.cc b/platform/base/interface_info.cc index 2ada91be..5fb8c62f 100644 --- a/platform/base/interface_info.cc +++ b/platform/base/interface_info.cc @@ -5,6 +5,7 @@ #include "platform/base/interface_info.h" #include <algorithm> +#include <utility> namespace openscreen { @@ -46,6 +47,11 @@ IPAddress InterfaceInfo::GetIpAddressV6() const { return IPAddress{}; } +bool InterfaceInfo::HasHardwareAddress() const { + return std::any_of(hardware_address.begin(), hardware_address.end(), + [](uint8_t e) { return e != 0; }); +} + std::ostream& operator<<(std::ostream& out, const IPSubnet& subnet) { if (subnet.address.IsV6()) { out << '['; diff --git a/platform/base/interface_info.h b/platform/base/interface_info.h index 81686063..01944870 100644 --- a/platform/base/interface_info.h +++ b/platform/base/interface_info.h @@ -63,6 +63,9 @@ struct InterfaceInfo { IPAddress GetIpAddressV4() const; IPAddress GetIpAddressV6() const; + // Returns true if |hardware_address| is non-zero. + bool HasHardwareAddress() const; + InterfaceInfo(); InterfaceInfo(NetworkInterfaceIndex index, const uint8_t hardware_address[6], diff --git a/platform/base/trace_logging_types.h b/platform/base/trace_logging_types.h index 257feaff..73592557 100644 --- a/platform/base/trace_logging_types.h +++ b/platform/base/trace_logging_types.h @@ -62,6 +62,8 @@ struct TraceCategory { kStandaloneReceiver = 0x01 << 4, kDiscovery = 0x01 << 5, kStandaloneSender = 0x01 << 6, + kReceiver = 0x01 << 7, + kSender = 0x01 << 8 }; }; diff --git a/platform/impl/network_interface_linux.cc b/platform/impl/network_interface_linux.cc index 48351ae8..1678acd7 100644 --- a/platform/impl/network_interface_linux.cc +++ b/platform/impl/network_interface_linux.cc @@ -168,13 +168,14 @@ std::vector<InterfaceInfo> GetLinkInfo() { request.header.nlmsg_pid = 0; request.msg.ifi_family = AF_UNSPEC; struct iovec iov = {&request, request.header.nlmsg_len}; - struct msghdr msg = {&peer, - sizeof(peer), - &iov, - /* msg_iovlen */ 1, - /* msg_control */ nullptr, - /* msg_controllen */ 0, - /* msg_flags */ 0}; + struct msghdr msg = {}; + msg.msg_name = &peer; + msg.msg_namelen = sizeof(peer); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = nullptr; + msg.msg_controllen = 0; + msg.msg_flags = 0; if (sendmsg(fd.get(), &msg, 0) < 0) { OSP_LOG_ERROR << "netlink sendmsg() failed: " << errno << " - " << strerror(errno); @@ -187,14 +188,16 @@ std::vector<InterfaceInfo> GetLinkInfo() { char buf[kNetlinkRecvmsgBufSize]; struct iovec iov = {buf, sizeof(buf)}; struct sockaddr_nl source_address; - struct msghdr msg; + struct msghdr msg = {}; struct nlmsghdr* netlink_header; - msg = {&source_address, sizeof(source_address), &iov, - /* msg_iovlen */ 1, - /* msg_control */ nullptr, - /* msg_controllen */ 0, - /* msg_flags */ 0}; + msg.msg_name = &source_address; + msg.msg_namelen = sizeof(source_address); + msg.msg_iov = &iov; + msg.msg_iovlen = 1, + msg.msg_control = nullptr, + msg.msg_controllen = 0, + msg.msg_flags = 0; bool done = false; while (!done) { @@ -269,13 +272,14 @@ void PopulateSubnetsOrClearList(std::vector<InterfaceInfo>* info_list) { request.header.nlmsg_pid = 0; request.msg.ifa_family = AF_UNSPEC; struct iovec iov = {&request, request.header.nlmsg_len}; - struct msghdr msg = {&peer, - sizeof(peer), - &iov, - /* msg_iovlen */ 1, - /* msg_control */ nullptr, - /* msg_controllen */ 0, - /* msg_flags */ 0}; + struct msghdr msg = {}; + msg.msg_name = &peer; + msg.msg_namelen = sizeof(peer); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = nullptr; + msg.msg_controllen = 0; + msg.msg_flags = 0; if (sendmsg(fd.get(), &msg, 0) < 0) { OSP_LOG_ERROR << "sendmsg failed: " << errno << " - " << strerror(errno); info_list->clear(); @@ -287,14 +291,16 @@ void PopulateSubnetsOrClearList(std::vector<InterfaceInfo>* info_list) { char buf[kNetlinkRecvmsgBufSize]; struct iovec iov = {buf, sizeof(buf)}; struct sockaddr_nl source_address; - struct msghdr msg; + struct msghdr msg = {}; struct nlmsghdr* netlink_header; - msg = {&source_address, sizeof(source_address), &iov, - /* msg_iovlen */ 1, - /* msg_control */ nullptr, - /* msg_controllen */ 0, - /* msg_flags */ 0}; + msg.msg_name = &source_address; + msg.msg_namelen = sizeof(source_address); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = nullptr; + msg.msg_controllen = 0; + msg.msg_flags = 0; bool done = false; while (!done) { size_t len = recvmsg(fd.get(), &msg, 0); diff --git a/platform/impl/network_interface_mac.cc b/platform/impl/network_interface_mac.cc index bb7bd583..e101beba 100644 --- a/platform/impl/network_interface_mac.cc +++ b/platform/impl/network_interface_mac.cc @@ -6,6 +6,7 @@ #include <net/if_dl.h> #include <net/if_media.h> #include <netinet/in.h> +#include <netinet/in_var.h> #include <sys/ioctl.h> #include <sys/socket.h> #include <sys/types.h> @@ -133,6 +134,15 @@ std::vector<InterfaceInfo> ProcessInterfacesList(ifaddrs* interfaces) { memcpy(&interface->hardware_address[0], &lladdr[0], sizeof(interface->hardware_address)); } else if (cur->ifa_addr->sa_family == AF_INET6) { // Ipv6 address. + struct in6_ifreq ifr = {}; + // Reject network interfaces that have a deprecated flag set. + strncpy(ifr.ifr_name, cur->ifa_name, sizeof(ifr.ifr_name) - 1); + memcpy(&ifr.ifr_ifru.ifru_addr, cur->ifa_addr, cur->ifa_addr->sa_len); + if (ioctl(ioctl_socket.get(), SIOCGIFAFLAG_IN6, &ifr) != 0 || + ifr.ifr_ifru.ifru_flags & IN6_IFF_DEPRECATED) { + continue; + } + auto* const addr_in6 = reinterpret_cast<const sockaddr_in6*>(cur->ifa_addr); uint8_t tmp[sizeof(addr_in6->sin6_addr.s6_addr)]; diff --git a/platform/impl/task_runner.cc b/platform/impl/task_runner.cc index f306e792..35b9a5e2 100644 --- a/platform/impl/task_runner.cc +++ b/platform/impl/task_runner.cc @@ -129,7 +129,6 @@ void TaskRunnerImpl::RequestStopSoon() { } void TaskRunnerImpl::RunRunnableTasks() { - OSP_DVLOG << "Running " << running_tasks_.size() << " tasks..."; for (TaskWithMetadata& running_task : running_tasks_) { // Move the task to the stack so that its bound state is freed immediately // after being run. diff --git a/platform/impl/tls_connection_posix.cc b/platform/impl/tls_connection_posix.cc index ad77fdec..779541bc 100644 --- a/platform/impl/tls_connection_posix.cc +++ b/platform/impl/tls_connection_posix.cc @@ -102,14 +102,6 @@ bool TlsConnectionPosix::Send(const void* data, size_t len) { return buffer_.Push(data, len); } -IPEndpoint TlsConnectionPosix::GetLocalEndpoint() const { - OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); - - absl::optional<IPEndpoint> endpoint = socket_->local_address(); - OSP_DCHECK(endpoint.has_value()); - return endpoint.value(); -} - IPEndpoint TlsConnectionPosix::GetRemoteEndpoint() const { OSP_DCHECK(task_runner_->IsRunningOnTaskRunner()); diff --git a/platform/impl/tls_connection_posix.h b/platform/impl/tls_connection_posix.h index c78bf5f1..5655ffe0 100644 --- a/platform/impl/tls_connection_posix.h +++ b/platform/impl/tls_connection_posix.h @@ -34,7 +34,6 @@ class TlsConnectionPosix : public TlsConnection { // TlsConnection overrides. void SetClient(Client* client) override; bool Send(const void* data, size_t len) override; - IPEndpoint GetLocalEndpoint() const override; IPEndpoint GetRemoteEndpoint() const override; // Registers |this| with the platform TlsDataRouterPosix. This is called diff --git a/platform/test/mock_tls_connection.h b/platform/test/mock_tls_connection.h index 1865c9e9..66d7adce 100644 --- a/platform/test/mock_tls_connection.h +++ b/platform/test/mock_tls_connection.h @@ -5,6 +5,9 @@ #ifndef PLATFORM_TEST_MOCK_TLS_CONNECTION_H_ #define PLATFORM_TEST_MOCK_TLS_CONNECTION_H_ +#include <utility> +#include <vector> + #include "gmock/gmock.h" #include "platform/api/tls_connection.h" @@ -24,7 +27,6 @@ class MockTlsConnection : public TlsConnection { MOCK_METHOD(bool, Send, (const void* data, size_t len), (override)); - IPEndpoint GetLocalEndpoint() const override { return local_address_; } IPEndpoint GetRemoteEndpoint() const override { return remote_address_; } void OnError(Error error) { diff --git a/test/BUILD.gn b/test/BUILD.gn index 663b26d4..be8a1534 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -7,9 +7,7 @@ import("//build_overrides/build.gni") source_set("test_main") { testonly = true - sources = [ - "test_main.cc", - ] + sources = [ "test_main.cc" ] if (!build_with_chromium) { defines = [ "ENABLE_PLATFORM_IMPL" ] @@ -17,6 +15,7 @@ source_set("test_main") { deps = [ "../platform", + "../platform:standalone_impl", "../third_party/googletest:gtest", ] } diff --git a/test/test_main.cc b/test/test_main.cc index 443278c7..b678ce9d 100644 --- a/test/test_main.cc +++ b/test/test_main.cc @@ -87,8 +87,9 @@ GlobalTestState InitFromArgs(int argc, char** argv) { // Googletest strongly recommends that we roll our own main // function if we want to do global test environment setup. // See the below link for more info; -// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md#sharing-resources-between-tests-in-the-same-test-case -// +// https://github.com/google/googletest/blob/master/docs/advanced.md#sharing-resources-between-tests-in-the-same-test-suite +// TODO(issuetracker.google.com/172242670): rename reference to "main" +// once googletest has a "main" branch. // This main method is a drop-in replacement for anywhere that currently // depends on gtest_main, meaning it can be linked into any test-only binary // to provide a main implementation that supports setting flags and other diff --git a/testing/libfuzzer/archive_corpus.py b/testing/libfuzzer/archive_corpus.py index e80848dd..e1960854 100755 --- a/testing/libfuzzer/archive_corpus.py +++ b/testing/libfuzzer/archive_corpus.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python2 +#!/usr/bin/env python3 # # Copyright 2019 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be diff --git a/testing/libfuzzer/fuzzer_test.gni b/testing/libfuzzer/fuzzer_test.gni index e88231c3..8ffd5afc 100644 --- a/testing/libfuzzer/fuzzer_test.gni +++ b/testing/libfuzzer/fuzzer_test.gni @@ -154,7 +154,7 @@ template("openscreen_fuzzer_test") { } if (is_mac) { - sources += [ "//testing/libfuzzer/libfuzzer_exports.h" ] + sources += [ "//testing/libfuzzer/libfuzzer_exports_mac.h" ] } } } else { diff --git a/testing/libfuzzer/gen_fuzzer_config.py b/testing/libfuzzer/gen_fuzzer_config.py index bde9e146..d195b5a1 100755 --- a/testing/libfuzzer/gen_fuzzer_config.py +++ b/testing/libfuzzer/gen_fuzzer_config.py @@ -1,4 +1,4 @@ -#!/usr/bin/python2 +#!/usr/bin/env python3 # # Copyright (c) 2015 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be @@ -8,8 +8,8 @@ Invoked by GN from fuzzer_test.gni. """ -import ConfigParser import argparse +import configparser import os import sys @@ -52,7 +52,7 @@ def main(): args.asan_options or args.msan_options or args.ubsan_options): return - config = ConfigParser.ConfigParser() + config = configparser.ConfigParser() libfuzzer_options = [] if args.dict: libfuzzer_options.append(('dict', os.path.basename(args.dict))) diff --git a/testing/libfuzzer/libfuzzer_exports_mac.h b/testing/libfuzzer/libfuzzer_exports_mac.h new file mode 100644 index 00000000..ce34e6a0 --- /dev/null +++ b/testing/libfuzzer/libfuzzer_exports_mac.h @@ -0,0 +1,41 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef TESTING_LIBFUZZER_LIBFUZZER_EXPORTS_MAC_H_ +#define TESTING_LIBFUZZER_LIBFUZZER_EXPORTS_MAC_H_ + +// On macOS, the linker may strip symbols for functions that are not reachable +// by the program entrypoint. Several libFuzzer functions are resolved via +// dlsym at runtime and therefore may be dead-stripped as a result. Including +// this header in the fuzzer's implementation file will ensure that all the +// symbols are kept and exported. + +#define EXPORT_FUZZER_FUNCTION \ + __attribute__((used)) __attribute__((visibility("default"))) + +extern "C" { + +EXPORT_FUZZER_FUNCTION int LLVMFuzzerInitialize(int* argc, char*** argv); +EXPORT_FUZZER_FUNCTION int LLVMFuzzerTestOneInput(const uint8_t* data, + size_t size); +EXPORT_FUZZER_FUNCTION size_t LLVMFuzzerCustomMutator(uint8_t* data, + size_t size, + size_t max_size, + unsigned int seed); +EXPORT_FUZZER_FUNCTION size_t LLVMFuzzerCustomCrossOver(const uint8_t* data1, + size_t size1, + const uint8_t* data2, + size_t size2, + uint8_t* out, + size_t max_out_size, + unsigned int seed); +EXPORT_FUZZER_FUNCTION size_t LLVMFuzzerMutate(uint8_t* data, + size_t size, + size_t max_size); + +} // extern "C" + +#undef EXPORT_FUZZER_FUNCTION + +#endif // TESTING_LIBFUZZER_LIBFUZZER_EXPORTS_MAC_H_ diff --git a/third_party/aomedia/BUILD.gn b/third_party/aomedia/BUILD.gn new file mode 100644 index 00000000..1dd90696 --- /dev/null +++ b/third_party/aomedia/BUILD.gn @@ -0,0 +1,14 @@ +# Copyright (c) 2021 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +config("aomedia_config") { + include_dirs = [ "src" ] +} + +source_set("aomedia") { + sources = [ + "aom/aom_encoder.h", + "aom/aomcx.h", + ] +} diff --git a/third_party/aomedia/README.chromium b/third_party/aomedia/README.chromium new file mode 100644 index 00000000..c735109a --- /dev/null +++ b/third_party/aomedia/README.chromium @@ -0,0 +1,9 @@ +Name: AOM +URL: https://aomedia.googlesource.com/aom +Version: git +License: BSD +License File: src/LICENSE +Security Critical: no + +Description: +AOM is an AV1 codec library for encoding and decoding. diff --git a/third_party/googletest/BUILD.gn b/third_party/googletest/BUILD.gn index 27006582..93705bbf 100644 --- a/third_party/googletest/BUILD.gn +++ b/third_party/googletest/BUILD.gn @@ -11,9 +11,7 @@ if (build_with_chromium) { "//build/config/compiler:default_include_dirs", "../../build:openscreen_include_dirs", ] - public_deps = [ - "//third_party/googletest:gmock", - ] + public_deps = [ "//third_party/googletest:gmock" ] } source_set("gtest") { @@ -22,9 +20,7 @@ if (build_with_chromium) { "//build/config/compiler:default_include_dirs", "../../build:openscreen_include_dirs", ] - public_deps = [ - "//third_party/googletest:gtest", - ] + public_deps = [ "//third_party/googletest:gtest" ] } source_set("gtest_main") { @@ -33,9 +29,7 @@ if (build_with_chromium) { "//build/config/compiler:default_include_dirs", "../../build:openscreen_include_dirs", ] - public_deps = [ - "//third_party/googletest:gtest_main", - ] + public_deps = [ "//third_party/googletest:gtest_main" ] } } else { config("gmock_config") { @@ -68,7 +62,7 @@ if (build_with_chromium) { source_set("gmock") { testonly = true sources = [ - "src/googlemock/include/gmock.h", + "src/googlemock/include/gmock/gmock.h", "src/googlemock/src/gmock-all.cc", ] @@ -77,9 +71,7 @@ if (build_with_chromium) { ":gtest_config", ] - public_deps = [ - ":gtest", - ] + public_deps = [ ":gtest" ] include_dirs = [ "src/googlemock" ] } @@ -87,7 +79,7 @@ if (build_with_chromium) { source_set("gtest") { testonly = true sources = [ - "src/googletest/include/gtest.h", + "src/googletest/include/gtest/gtest.h", "src/googletest/src/gtest-all.cc", ] @@ -98,11 +90,7 @@ if (build_with_chromium) { source_set("gtest_main") { testonly = true - sources = [ - "src/googletest/src/gtest_main.cc", - ] - deps = [ - ":gtest", - ] + sources = [ "src/googletest/src/gtest_main.cc" ] + deps = [ ":gtest" ] } } diff --git a/third_party/jsoncpp/BUILD.gn b/third_party/jsoncpp/BUILD.gn index bc96a235..2350a72e 100644 --- a/third_party/jsoncpp/BUILD.gn +++ b/third_party/jsoncpp/BUILD.gn @@ -22,7 +22,6 @@ if (build_with_chromium) { source_set("jsoncpp") { sources = [ "src/include/json/allocator.h", - "src/include/json/autolink.h", "src/include/json/config.h", "src/include/json/forwards.h", "src/include/json/json.h", diff --git a/third_party/libprotobuf-mutator/BUILD.gn b/third_party/libprotobuf-mutator/BUILD.gn index cc3eeaeb..cc8c6a01 100644 --- a/third_party/libprotobuf-mutator/BUILD.gn +++ b/third_party/libprotobuf-mutator/BUILD.gn @@ -1,4 +1,4 @@ -# Copyright 2020 The Chromium Authors. All rights reserved. +# Copyright 2017 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. @@ -7,14 +7,17 @@ import("//testing/libfuzzer/fuzzer_test.gni") import("//third_party/libprotobuf-mutator/fuzzable_proto_library.gni") config("include_config") { - include_dirs = [ "src/" ] + include_dirs = [ + "src/", + "//", + ] + cflags_cc = [ "-Wno-exit-time-destructors" ] } source_set("libprotobuf-mutator") { testonly = true configs += [ ":include_config" ] - public_configs = [ ":include_config" ] sources = [ "src/src/binary_format.cc", @@ -30,49 +33,28 @@ source_set("libprotobuf-mutator") { public_deps = [ "//third_party/protobuf:protobuf_full" ] } -# This protoc plugin, like the compiler, should only be built for the host -# architecture. -if (current_toolchain == host_toolchain) { - # This plugin will be needed to fuzz most protobuf code in Chromium. That's - # because production protobuf code must contain the line: - # "option optimize_for = LITE_RUNTIME", which instructs the proto compiler not - # to compile the proto using the full protobuf runtime. This allows Chromium - # not to depend on the full protobuf library, but prevents - # libprotobuf-mutator from fuzzing because the lite runtime lacks needed - # features (such as reflection). The plugin simply compiles a proto library - # as normal but ensures that is compiled with the full protobuf runtime. - executable("override_lite_runtime_plugin") { - sources = [ "protoc_plugin/protoc_plugin.cc" ] - deps = [ "//third_party/protobuf:protoc_lib" ] - public_configs = [ "//third_party/protobuf:protobuf_config" ] - } - # To use the plugin in a proto_library you want to fuzz, change the build - # target to fuzzable_proto_library (defined in - # //third_party/libprotobuf-mutator/fuzzable_proto_library.gni) -} - # The CQ will try building this target without "use_libfuzzer" if it is defined. # That will cause the build to fail, so don't define it when "use_libfuzzer" is # is false. if (use_libfuzzer) { - # Test that override_lite_runtime_plugin is working when built. This target - # contains files that are optimized for LITE_RUNTIME and which import other - # files that are also optimized for LITE_RUNTIME. - openscreen_fuzzer_test("override_lite_runtime_plugin_test_fuzzer") { - sources = [ "protoc_plugin/test_fuzzer.cc" ] + # Test that fuzzable_proto_library works. This target contains files that are + # optimized for LITE_RUNTIME and which import other files that are also + # optimized for LITE_RUNTIME. + openscreen_fuzzer_test("lpm_test_fuzzer") { + sources = [ "test_fuzzer/test_fuzzer.cc" ] deps = [ ":libprotobuf-mutator", - ":override_lite_runtime_plugin_test_fuzzer_proto", + ":lpm_test_fuzzer_proto", ] } } -# Proto library for override_lite_runtime_plugin_test_fuzzer -fuzzable_proto_library("override_lite_runtime_plugin_test_fuzzer_proto") { +# Proto library for lpm_test_fuzzer +fuzzable_proto_library("lpm_test_fuzzer_proto") { sources = [ - "protoc_plugin/imported.proto", - "protoc_plugin/imported_publicly.proto", - "protoc_plugin/test_fuzzer_input.proto", + "test_fuzzer/imported.proto", + "test_fuzzer/imported_publicly.proto", + "test_fuzzer/test_fuzzer_input.proto", ] } @@ -83,5 +65,6 @@ if (use_libfuzzer) { # Component that can provide protobuf_full to non-testonly targets static_library("protobuf_full") { public_deps = [ "//third_party/protobuf:protobuf_full" ] + sources = [ "dummy.cc" ] } } diff --git a/third_party/libprotobuf-mutator/dummy.cc b/third_party/libprotobuf-mutator/dummy.cc new file mode 100644 index 00000000..8df1899d --- /dev/null +++ b/third_party/libprotobuf-mutator/dummy.cc @@ -0,0 +1,6 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// Dummy file used to ensure that wrapper libraries get built on non-Linux +// platforms.
\ No newline at end of file diff --git a/third_party/libprotobuf-mutator/fuzzable_proto_library.gni b/third_party/libprotobuf-mutator/fuzzable_proto_library.gni index fee136c6..2d357db7 100644 --- a/third_party/libprotobuf-mutator/fuzzable_proto_library.gni +++ b/third_party/libprotobuf-mutator/fuzzable_proto_library.gni @@ -6,7 +6,7 @@ # non-fuzzer builds (ie: use_libfuzzer=false). However, in fuzzer builds, the # proto_library is built with the full protobuf runtime and any "optimize_for = # LITE_RUNTIME" options are ignored. This is done because libprotobuf-mutator -# needs the full protobuf runtime, but proto_libraries shipped in chrome must +# needs the full protobuf runtime, but proto_libraries shipped in Chrome must # use the optimize for LITE_RUNTIME option which is incompatible with the full # protobuf runtime. tl;dr: A fuzzable_proto_library is a proto_library that can # be fuzzed with libprotobuf-mutator and shipped in Chrome. @@ -16,18 +16,12 @@ import("//testing/libfuzzer/fuzzer_test.gni") import("//third_party/protobuf/proto_library.gni") template("fuzzable_proto_library") { - # Only make the proto library fuzzable if we are doing a build that we can - # use LPM on (i.e. libFuzzer not on Chrome OS). - if (use_libfuzzer && current_toolchain != "//build/toolchain/cros:target") { + if (use_libfuzzer) { proto_library("proto_library_" + target_name) { forward_variables_from(invoker, "*") assert(current_toolchain == host_toolchain) - if (!defined(proto_deps)) { - proto_deps = [] - } - proto_deps += - [ "//third_party/libprotobuf-mutator:override_lite_runtime_plugin" ] + cc_generator_options = "speed" extra_configs = [ "//third_party/protobuf:protobuf_config" ] } diff --git a/third_party/libprotobuf-mutator/test_fuzzer/imported.proto b/third_party/libprotobuf-mutator/test_fuzzer/imported.proto new file mode 100644 index 00000000..f347c366 --- /dev/null +++ b/third_party/libprotobuf-mutator/test_fuzzer/imported.proto @@ -0,0 +1,17 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Ensure imported files are handled properly. This file is imported by +// test_fuzzer_input.proto and imports imported_publicly publicly. + +syntax = "proto2"; +option optimize_for = LITE_RUNTIME; +package lpm_test_fuzzer; + +// Test public imported files are handled properly. +import public "imported_publicly.proto"; + +message Imported { + required ImportedPublicly imported_publicly = 1; +} diff --git a/third_party/libprotobuf-mutator/test_fuzzer/imported_publicly.proto b/third_party/libprotobuf-mutator/test_fuzzer/imported_publicly.proto new file mode 100644 index 00000000..10768495 --- /dev/null +++ b/third_party/libprotobuf-mutator/test_fuzzer/imported_publicly.proto @@ -0,0 +1,14 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Ensure publicly imported files are handled properly. This file is imported +// publicly by test_fuzzer_input2.proto + +syntax = "proto2"; +option optimize_for = LITE_RUNTIME; +package lpm_test_fuzzer; + +message ImportedPublicly { + required int32 input = 1; +} diff --git a/third_party/libprotobuf-mutator/test_fuzzer/test_fuzzer.cc b/third_party/libprotobuf-mutator/test_fuzzer/test_fuzzer.cc new file mode 100644 index 00000000..e7af5346 --- /dev/null +++ b/third_party/libprotobuf-mutator/test_fuzzer/test_fuzzer.cc @@ -0,0 +1,16 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Test fuzzer that when built successfully proves that fuzzable_proto_library +// is working. Building this fuzzer without using fuzzable_proto_library will +// fail because of test_fuzzer_input.proto + +#include <iostream> + +#include "third_party/libprotobuf-mutator/src/src/libfuzzer/libfuzzer_macro.h" +#include "third_party/libprotobuf-mutator/test_fuzzer/test_fuzzer_input.pb.h" + +DEFINE_PROTO_FUZZER(const lpm_test_fuzzer::TestFuzzerInput& input) { + std::cout << input.imported().imported_publicly().input() << std::endl; +} diff --git a/third_party/libprotobuf-mutator/test_fuzzer/test_fuzzer_input.proto b/third_party/libprotobuf-mutator/test_fuzzer/test_fuzzer_input.proto new file mode 100644 index 00000000..45645fc9 --- /dev/null +++ b/third_party/libprotobuf-mutator/test_fuzzer/test_fuzzer_input.proto @@ -0,0 +1,22 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Depended on by lpm_test_fuzzer. Tests whether fuzzable_proto_library is +// working since without it builds will fail because of the optimize_for +// LITE_RUNTIME option this file has set. Also imports a file that does the same +// thing. + +syntax = "proto2"; + +// This line is essentially the purpose of this test fuzzer. The build rule, if +// working, ignores this line. If it is not working or isn't used, then this +// build will fail. +option optimize_for = LITE_RUNTIME; + +package lpm_test_fuzzer; +import "imported.proto"; + +message TestFuzzerInput { + required Imported imported = 1; +}
\ No newline at end of file diff --git a/third_party/mDNSResponder/BUILD.gn b/third_party/mDNSResponder/BUILD.gn deleted file mode 100644 index bb0047ee..00000000 --- a/third_party/mDNSResponder/BUILD.gn +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2018 The Chromium Authors. All rights reserved. -# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. - -config("mdnsresponder_config") { - cflags = [ "-w" ] # Disable all warnings. - - cflags_c = [ - # We need to rename some linked symbols in order to avoid multiple - # definitions. - "-DMD5_Update=MD5_Update_mDNS", - "-DMD5_Init=MD5_Init_mDNS", - "-DMD5_Final=MD5_Final_mDNS", - "-DMD5_Transform=MD5_Transform_mDNS", - ] -} - -source_set("core") { - sources = [ - "src/mDNSCore/DNSCommon.c", - "src/mDNSCore/DNSCommon.h", - "src/mDNSCore/DNSDigest.c", - "src/mDNSCore/mDNS.c", - "src/mDNSCore/mDNSDebug.h", - "src/mDNSCore/mDNSEmbeddedAPI.h", - "src/mDNSCore/uDNS.c", - "src/mDNSCore/uDNS.h", - "src/mDNSShared/mDNSDebug.c", - ] - - configs += [ ":mdnsresponder_config" ] - - if (is_debug) { - defines = [ "MDNS_DEBUGMSGS=2" ] - } - - include_dirs = [ "src/mDNSCore" ] -} diff --git a/third_party/mDNSResponder/README.chromium b/third_party/mDNSResponder/README.chromium deleted file mode 100644 index 77469188..00000000 --- a/third_party/mDNSResponder/README.chromium +++ /dev/null @@ -1,10 +0,0 @@ -Name: mDNSResponder -URL: https://github.com/jevinskie/mDNSResponder -License: Apache License, Version 2.0 -License File: src/LICENSE -Security Critical: no - -Description: - -Pull from Apple Bonjour's MDNS/DNS-SD implementation. Will eventually be -replaced with our custom implementation, currently only used in osp. diff --git a/third_party/protobuf/BUILD.gn b/third_party/protobuf/BUILD.gn index 13793df7..1616b80e 100644 --- a/third_party/protobuf/BUILD.gn +++ b/third_party/protobuf/BUILD.gn @@ -29,6 +29,7 @@ config("protobuf_warnings") { "-Wno-extra-semi", "-Wno-unneeded-internal-declaration", "-Wno-unused-private-field", + "-Wno-inconsistent-missing-override", ] } @@ -73,7 +74,6 @@ lite_sources = [ "src/src/google/protobuf/has_bits.h", "src/src/google/protobuf/implicit_weak_message.cc", "src/src/google/protobuf/implicit_weak_message.h", - "src/src/google/protobuf/inlined_string_field.h", "src/src/google/protobuf/io/coded_stream.cc", "src/src/google/protobuf/io/coded_stream.h", "src/src/google/protobuf/io/io_win32.cc", @@ -103,7 +103,6 @@ lite_sources = [ "src/src/google/protobuf/stubs/casts.h", "src/src/google/protobuf/stubs/common.cc", "src/src/google/protobuf/stubs/common.h", - "src/src/google/protobuf/stubs/fastmem.h", "src/src/google/protobuf/stubs/hash.h", "src/src/google/protobuf/stubs/int128.cc", "src/src/google/protobuf/stubs/int128.h", @@ -313,6 +312,8 @@ if (current_toolchain == host_toolchain) { "src/src/google/protobuf/compiler/cpp/cpp_options.h", "src/src/google/protobuf/compiler/cpp/cpp_padding_optimizer.cc", "src/src/google/protobuf/compiler/cpp/cpp_padding_optimizer.h", + "src/src/google/protobuf/compiler/cpp/cpp_parse_function_generator.cc", + "src/src/google/protobuf/compiler/cpp/cpp_parse_function_generator.h", "src/src/google/protobuf/compiler/cpp/cpp_primitive_field.cc", "src/src/google/protobuf/compiler/cpp/cpp_primitive_field.h", "src/src/google/protobuf/compiler/cpp/cpp_service.cc", @@ -378,6 +379,8 @@ if (current_toolchain == host_toolchain) { "src/src/google/protobuf/compiler/java/java_generator_factory.h", "src/src/google/protobuf/compiler/java/java_helpers.cc", "src/src/google/protobuf/compiler/java/java_helpers.h", + "src/src/google/protobuf/compiler/java/java_kotlin_generator.cc", + "src/src/google/protobuf/compiler/java/java_kotlin_generator.h", "src/src/google/protobuf/compiler/java/java_map_field.cc", "src/src/google/protobuf/compiler/java/java_map_field.h", "src/src/google/protobuf/compiler/java/java_map_field_lite.cc", diff --git a/third_party/protobuf/proto_library.gni b/third_party/protobuf/proto_library.gni index 7bbc73b7..9c550457 100644 --- a/third_party/protobuf/proto_library.gni +++ b/third_party/protobuf/proto_library.gni @@ -58,6 +58,12 @@ template("proto_library") { rel_cc_out_dir, ] + if (defined(invoker.cc_generator_options)) { + args += [ + "--cc-options", + invoker.cc_generator_options, + ] + } inputs = [ protoc_path ] deps = [ protoc_label ] } diff --git a/third_party/quiche/BUILD.gn b/third_party/quiche/BUILD.gn new file mode 100644 index 00000000..b6c7a5f6 --- /dev/null +++ b/third_party/quiche/BUILD.gn @@ -0,0 +1,626 @@ +# Copyright (c) 2021 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +config("quiche_config") { + include_dirs = [ + # The ordering here is important, since headers in overrides/ replace + # headers in src/common/platform/default. + "overrides", + "src/common/platform/default", + "src", + ] +} + +# TODO(https://issuetracker.google.com/issues/169447969): This is not expected +# to compile because the QUICHE platform depends on Chromium //net and //base, +# which are not available in Open Screen. +source_set("quiche") { + sources = [ + "overrides/quiche_platform_impl/quic_mutex_impl.cc", + "overrides/quiche_platform_impl/quic_mutex_impl.h", + "overrides/quiche_platform_impl/quiche_bug_tracker_impl.h", + "overrides/quiche_platform_impl/quiche_export_impl.h", + "overrides/quiche_platform_impl/quiche_logging_impl.h", + "overrides/quiche_platform_impl/quiche_thread_local_impl.h", + "overrides/quiche_platform_impl/quiche_time_utils_impl.cc", + "overrides/quiche_platform_impl/quiche_time_utils_impl.h", + "src/common/platform/api/quiche_export.h", + "src/common/platform/api/quiche_flag_utils.h", + "src/common/platform/api/quiche_flags.h", + "src/common/platform/api/quiche_logging.h", + "src/common/platform/api/quiche_prefetch.h", + "src/common/platform/api/quiche_thread_local.h", + "src/common/platform/api/quiche_time_utils.h", + "src/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h", + "src/common/print_elements.h", + "src/common/quiche_circular_deque.h", + "src/common/quiche_data_reader.cc", + "src/common/quiche_data_reader.h", + "src/common/quiche_data_writer.cc", + "src/common/quiche_data_writer.h", + "src/common/quiche_endian.h", + "src/common/quiche_linked_hash_map.h", + "src/common/quiche_text_utils.cc", + "src/common/quiche_text_utils.h", + "src/http2/core/http2_priority_write_scheduler.h", + "src/http2/core/priority_write_scheduler.h", + "src/http2/core/write_scheduler.h", + "src/http2/decoder/decode_buffer.cc", + "src/http2/decoder/decode_buffer.h", + "src/http2/decoder/decode_http2_structures.cc", + "src/http2/decoder/decode_http2_structures.h", + "src/http2/decoder/decode_status.cc", + "src/http2/decoder/decode_status.h", + "src/http2/decoder/frame_decoder_state.cc", + "src/http2/decoder/frame_decoder_state.h", + "src/http2/decoder/http2_frame_decoder.cc", + "src/http2/decoder/http2_frame_decoder.h", + "src/http2/decoder/http2_frame_decoder_listener.cc", + "src/http2/decoder/http2_frame_decoder_listener.h", + "src/http2/decoder/http2_structure_decoder.cc", + "src/http2/decoder/http2_structure_decoder.h", + "src/http2/decoder/payload_decoders/altsvc_payload_decoder.cc", + "src/http2/decoder/payload_decoders/altsvc_payload_decoder.h", + "src/http2/decoder/payload_decoders/continuation_payload_decoder.cc", + "src/http2/decoder/payload_decoders/continuation_payload_decoder.h", + "src/http2/decoder/payload_decoders/data_payload_decoder.cc", + "src/http2/decoder/payload_decoders/data_payload_decoder.h", + "src/http2/decoder/payload_decoders/goaway_payload_decoder.cc", + "src/http2/decoder/payload_decoders/goaway_payload_decoder.h", + "src/http2/decoder/payload_decoders/headers_payload_decoder.cc", + "src/http2/decoder/payload_decoders/headers_payload_decoder.h", + "src/http2/decoder/payload_decoders/ping_payload_decoder.cc", + "src/http2/decoder/payload_decoders/ping_payload_decoder.h", + "src/http2/decoder/payload_decoders/priority_payload_decoder.cc", + "src/http2/decoder/payload_decoders/priority_payload_decoder.h", + "src/http2/decoder/payload_decoders/priority_update_payload_decoder.cc", + "src/http2/decoder/payload_decoders/priority_update_payload_decoder.h", + "src/http2/decoder/payload_decoders/push_promise_payload_decoder.cc", + "src/http2/decoder/payload_decoders/push_promise_payload_decoder.h", + "src/http2/decoder/payload_decoders/rst_stream_payload_decoder.cc", + "src/http2/decoder/payload_decoders/rst_stream_payload_decoder.h", + "src/http2/decoder/payload_decoders/settings_payload_decoder.cc", + "src/http2/decoder/payload_decoders/settings_payload_decoder.h", + "src/http2/decoder/payload_decoders/unknown_payload_decoder.cc", + "src/http2/decoder/payload_decoders/unknown_payload_decoder.h", + "src/http2/decoder/payload_decoders/window_update_payload_decoder.cc", + "src/http2/decoder/payload_decoders/window_update_payload_decoder.h", + "src/http2/hpack/decoder/hpack_block_decoder.cc", + "src/http2/hpack/decoder/hpack_block_decoder.h", + "src/http2/hpack/decoder/hpack_decoder.cc", + "src/http2/hpack/decoder/hpack_decoder.h", + "src/http2/hpack/decoder/hpack_decoder_listener.cc", + "src/http2/hpack/decoder/hpack_decoder_listener.h", + "src/http2/hpack/decoder/hpack_decoder_state.cc", + "src/http2/hpack/decoder/hpack_decoder_state.h", + "src/http2/hpack/decoder/hpack_decoder_string_buffer.cc", + "src/http2/hpack/decoder/hpack_decoder_string_buffer.h", + "src/http2/hpack/decoder/hpack_decoder_tables.cc", + "src/http2/hpack/decoder/hpack_decoder_tables.h", + "src/http2/hpack/decoder/hpack_decoding_error.cc", + "src/http2/hpack/decoder/hpack_decoding_error.h", + "src/http2/hpack/decoder/hpack_entry_decoder.cc", + "src/http2/hpack/decoder/hpack_entry_decoder.h", + "src/http2/hpack/decoder/hpack_entry_decoder_listener.cc", + "src/http2/hpack/decoder/hpack_entry_decoder_listener.h", + "src/http2/hpack/decoder/hpack_entry_type_decoder.cc", + "src/http2/hpack/decoder/hpack_entry_type_decoder.h", + "src/http2/hpack/decoder/hpack_string_decoder.cc", + "src/http2/hpack/decoder/hpack_string_decoder.h", + "src/http2/hpack/decoder/hpack_string_decoder_listener.cc", + "src/http2/hpack/decoder/hpack_string_decoder_listener.h", + "src/http2/hpack/decoder/hpack_whole_entry_buffer.cc", + "src/http2/hpack/decoder/hpack_whole_entry_buffer.h", + "src/http2/hpack/decoder/hpack_whole_entry_listener.cc", + "src/http2/hpack/decoder/hpack_whole_entry_listener.h", + "src/http2/hpack/hpack_static_table_entries.inc", + "src/http2/hpack/http2_hpack_constants.cc", + "src/http2/hpack/http2_hpack_constants.h", + "src/http2/hpack/huffman/hpack_huffman_decoder.cc", + "src/http2/hpack/huffman/hpack_huffman_decoder.h", + "src/http2/hpack/huffman/hpack_huffman_encoder.cc", + "src/http2/hpack/huffman/hpack_huffman_encoder.h", + "src/http2/hpack/huffman/huffman_spec_tables.cc", + "src/http2/hpack/huffman/huffman_spec_tables.h", + "src/http2/hpack/varint/hpack_varint_decoder.cc", + "src/http2/hpack/varint/hpack_varint_decoder.h", + "src/http2/hpack/varint/hpack_varint_encoder.cc", + "src/http2/hpack/varint/hpack_varint_encoder.h", + "src/http2/http2_constants.cc", + "src/http2/http2_constants.h", + "src/http2/http2_structures.cc", + "src/http2/http2_structures.h", + "src/http2/platform/api/http2_bug_tracker.h", + "src/http2/platform/api/http2_flag_utils.h", + "src/http2/platform/api/http2_flags.h", + "src/http2/platform/api/http2_logging.h", + "src/http2/platform/api/http2_macros.h", + "src/quic/core/congestion_control/bandwidth_sampler.cc", + "src/quic/core/congestion_control/bandwidth_sampler.h", + "src/quic/core/congestion_control/bbr2_drain.cc", + "src/quic/core/congestion_control/bbr2_drain.h", + "src/quic/core/congestion_control/bbr2_misc.cc", + "src/quic/core/congestion_control/bbr2_misc.h", + "src/quic/core/congestion_control/bbr2_probe_bw.cc", + "src/quic/core/congestion_control/bbr2_probe_bw.h", + "src/quic/core/congestion_control/bbr2_probe_rtt.cc", + "src/quic/core/congestion_control/bbr2_probe_rtt.h", + "src/quic/core/congestion_control/bbr2_sender.cc", + "src/quic/core/congestion_control/bbr2_sender.h", + "src/quic/core/congestion_control/bbr2_startup.cc", + "src/quic/core/congestion_control/bbr2_startup.h", + "src/quic/core/congestion_control/bbr_sender.cc", + "src/quic/core/congestion_control/bbr_sender.h", + "src/quic/core/congestion_control/cubic_bytes.cc", + "src/quic/core/congestion_control/cubic_bytes.h", + "src/quic/core/congestion_control/general_loss_algorithm.cc", + "src/quic/core/congestion_control/general_loss_algorithm.h", + "src/quic/core/congestion_control/hybrid_slow_start.cc", + "src/quic/core/congestion_control/hybrid_slow_start.h", + "src/quic/core/congestion_control/loss_detection_interface.h", + "src/quic/core/congestion_control/pacing_sender.cc", + "src/quic/core/congestion_control/pacing_sender.h", + "src/quic/core/congestion_control/prr_sender.cc", + "src/quic/core/congestion_control/prr_sender.h", + "src/quic/core/congestion_control/rtt_stats.cc", + "src/quic/core/congestion_control/rtt_stats.h", + "src/quic/core/congestion_control/send_algorithm_interface.cc", + "src/quic/core/congestion_control/send_algorithm_interface.h", + "src/quic/core/congestion_control/tcp_cubic_sender_bytes.cc", + "src/quic/core/congestion_control/tcp_cubic_sender_bytes.h", + "src/quic/core/congestion_control/uber_loss_algorithm.cc", + "src/quic/core/congestion_control/uber_loss_algorithm.h", + "src/quic/core/congestion_control/windowed_filter.h", + "src/quic/core/crypto/aead_base_decrypter.cc", + "src/quic/core/crypto/aead_base_decrypter.h", + "src/quic/core/crypto/aead_base_encrypter.cc", + "src/quic/core/crypto/aead_base_encrypter.h", + "src/quic/core/crypto/aes_128_gcm_12_decrypter.cc", + "src/quic/core/crypto/aes_128_gcm_12_decrypter.h", + "src/quic/core/crypto/aes_128_gcm_12_encrypter.cc", + "src/quic/core/crypto/aes_128_gcm_12_encrypter.h", + "src/quic/core/crypto/aes_128_gcm_decrypter.cc", + "src/quic/core/crypto/aes_128_gcm_decrypter.h", + "src/quic/core/crypto/aes_128_gcm_encrypter.cc", + "src/quic/core/crypto/aes_128_gcm_encrypter.h", + "src/quic/core/crypto/aes_256_gcm_decrypter.cc", + "src/quic/core/crypto/aes_256_gcm_decrypter.h", + "src/quic/core/crypto/aes_256_gcm_encrypter.cc", + "src/quic/core/crypto/aes_256_gcm_encrypter.h", + "src/quic/core/crypto/aes_base_decrypter.cc", + "src/quic/core/crypto/aes_base_decrypter.h", + "src/quic/core/crypto/aes_base_encrypter.cc", + "src/quic/core/crypto/aes_base_encrypter.h", + "src/quic/core/crypto/boring_utils.h", + "src/quic/core/crypto/cert_compressor.cc", + "src/quic/core/crypto/cert_compressor.h", + "src/quic/core/crypto/certificate_view.cc", + "src/quic/core/crypto/certificate_view.h", + "src/quic/core/crypto/chacha20_poly1305_decrypter.cc", + "src/quic/core/crypto/chacha20_poly1305_decrypter.h", + "src/quic/core/crypto/chacha20_poly1305_encrypter.cc", + "src/quic/core/crypto/chacha20_poly1305_encrypter.h", + "src/quic/core/crypto/chacha20_poly1305_tls_decrypter.cc", + "src/quic/core/crypto/chacha20_poly1305_tls_decrypter.h", + "src/quic/core/crypto/chacha20_poly1305_tls_encrypter.cc", + "src/quic/core/crypto/chacha20_poly1305_tls_encrypter.h", + "src/quic/core/crypto/chacha_base_decrypter.cc", + "src/quic/core/crypto/chacha_base_decrypter.h", + "src/quic/core/crypto/chacha_base_encrypter.cc", + "src/quic/core/crypto/chacha_base_encrypter.h", + "src/quic/core/crypto/channel_id.cc", + "src/quic/core/crypto/channel_id.h", + "src/quic/core/crypto/common_cert_set.cc", + "src/quic/core/crypto/common_cert_set.h", + "src/quic/core/crypto/crypto_framer.cc", + "src/quic/core/crypto/crypto_framer.h", + "src/quic/core/crypto/crypto_handshake.cc", + "src/quic/core/crypto/crypto_handshake.h", + "src/quic/core/crypto/crypto_handshake_message.cc", + "src/quic/core/crypto/crypto_handshake_message.h", + "src/quic/core/crypto/crypto_message_parser.h", + "src/quic/core/crypto/crypto_protocol.h", + "src/quic/core/crypto/crypto_secret_boxer.cc", + "src/quic/core/crypto/crypto_secret_boxer.h", + "src/quic/core/crypto/crypto_utils.cc", + "src/quic/core/crypto/crypto_utils.h", + "src/quic/core/crypto/curve25519_key_exchange.cc", + "src/quic/core/crypto/curve25519_key_exchange.h", + "src/quic/core/crypto/key_exchange.cc", + "src/quic/core/crypto/key_exchange.h", + "src/quic/core/crypto/null_decrypter.cc", + "src/quic/core/crypto/null_decrypter.h", + "src/quic/core/crypto/null_encrypter.cc", + "src/quic/core/crypto/null_encrypter.h", + "src/quic/core/crypto/p256_key_exchange.cc", + "src/quic/core/crypto/p256_key_exchange.h", + "src/quic/core/crypto/proof_source.cc", + "src/quic/core/crypto/proof_source.h", + "src/quic/core/crypto/proof_verifier.h", + "src/quic/core/crypto/quic_compressed_certs_cache.cc", + "src/quic/core/crypto/quic_compressed_certs_cache.h", + "src/quic/core/crypto/quic_crypter.cc", + "src/quic/core/crypto/quic_crypter.h", + "src/quic/core/crypto/quic_crypto_client_config.cc", + "src/quic/core/crypto/quic_crypto_client_config.h", + "src/quic/core/crypto/quic_crypto_proof.cc", + "src/quic/core/crypto/quic_crypto_proof.h", + "src/quic/core/crypto/quic_crypto_server_config.cc", + "src/quic/core/crypto/quic_crypto_server_config.h", + "src/quic/core/crypto/quic_decrypter.cc", + "src/quic/core/crypto/quic_decrypter.h", + "src/quic/core/crypto/quic_encrypter.cc", + "src/quic/core/crypto/quic_encrypter.h", + "src/quic/core/crypto/quic_hkdf.cc", + "src/quic/core/crypto/quic_hkdf.h", + "src/quic/core/crypto/quic_random.cc", + "src/quic/core/crypto/quic_random.h", + "src/quic/core/crypto/server_proof_verifier.h", + "src/quic/core/crypto/tls_client_connection.cc", + "src/quic/core/crypto/tls_client_connection.h", + "src/quic/core/crypto/tls_connection.cc", + "src/quic/core/crypto/tls_connection.h", + "src/quic/core/crypto/tls_server_connection.cc", + "src/quic/core/crypto/tls_server_connection.h", + "src/quic/core/crypto/transport_parameters.cc", + "src/quic/core/crypto/transport_parameters.h", + "src/quic/core/frames/quic_ack_frame.cc", + "src/quic/core/frames/quic_ack_frame.h", + "src/quic/core/frames/quic_ack_frequency_frame.cc", + "src/quic/core/frames/quic_ack_frequency_frame.h", + "src/quic/core/frames/quic_blocked_frame.cc", + "src/quic/core/frames/quic_blocked_frame.h", + "src/quic/core/frames/quic_connection_close_frame.cc", + "src/quic/core/frames/quic_connection_close_frame.h", + "src/quic/core/frames/quic_crypto_frame.cc", + "src/quic/core/frames/quic_crypto_frame.h", + "src/quic/core/frames/quic_frame.cc", + "src/quic/core/frames/quic_frame.h", + "src/quic/core/frames/quic_goaway_frame.cc", + "src/quic/core/frames/quic_goaway_frame.h", + "src/quic/core/frames/quic_handshake_done_frame.cc", + "src/quic/core/frames/quic_handshake_done_frame.h", + "src/quic/core/frames/quic_inlined_frame.h", + "src/quic/core/frames/quic_max_streams_frame.cc", + "src/quic/core/frames/quic_max_streams_frame.h", + "src/quic/core/frames/quic_message_frame.cc", + "src/quic/core/frames/quic_message_frame.h", + "src/quic/core/frames/quic_mtu_discovery_frame.h", + "src/quic/core/frames/quic_new_connection_id_frame.cc", + "src/quic/core/frames/quic_new_connection_id_frame.h", + "src/quic/core/frames/quic_new_token_frame.cc", + "src/quic/core/frames/quic_new_token_frame.h", + "src/quic/core/frames/quic_padding_frame.cc", + "src/quic/core/frames/quic_padding_frame.h", + "src/quic/core/frames/quic_path_challenge_frame.cc", + "src/quic/core/frames/quic_path_challenge_frame.h", + "src/quic/core/frames/quic_path_response_frame.cc", + "src/quic/core/frames/quic_path_response_frame.h", + "src/quic/core/frames/quic_ping_frame.cc", + "src/quic/core/frames/quic_ping_frame.h", + "src/quic/core/frames/quic_retire_connection_id_frame.cc", + "src/quic/core/frames/quic_retire_connection_id_frame.h", + "src/quic/core/frames/quic_rst_stream_frame.cc", + "src/quic/core/frames/quic_rst_stream_frame.h", + "src/quic/core/frames/quic_stop_sending_frame.cc", + "src/quic/core/frames/quic_stop_sending_frame.h", + "src/quic/core/frames/quic_stop_waiting_frame.cc", + "src/quic/core/frames/quic_stop_waiting_frame.h", + "src/quic/core/frames/quic_stream_frame.cc", + "src/quic/core/frames/quic_stream_frame.h", + "src/quic/core/frames/quic_streams_blocked_frame.cc", + "src/quic/core/frames/quic_streams_blocked_frame.h", + "src/quic/core/frames/quic_window_update_frame.cc", + "src/quic/core/frames/quic_window_update_frame.h", + "src/quic/core/handshaker_delegate_interface.h", + "src/quic/core/http/http_constants.cc", + "src/quic/core/http/http_constants.h", + "src/quic/core/http/http_decoder.cc", + "src/quic/core/http/http_decoder.h", + "src/quic/core/http/http_encoder.cc", + "src/quic/core/http/http_encoder.h", + "src/quic/core/http/http_frames.h", + "src/quic/core/http/quic_client_promised_info.cc", + "src/quic/core/http/quic_client_promised_info.h", + "src/quic/core/http/quic_client_push_promise_index.cc", + "src/quic/core/http/quic_client_push_promise_index.h", + "src/quic/core/http/quic_header_list.cc", + "src/quic/core/http/quic_header_list.h", + "src/quic/core/http/quic_headers_stream.cc", + "src/quic/core/http/quic_headers_stream.h", + "src/quic/core/http/quic_receive_control_stream.cc", + "src/quic/core/http/quic_receive_control_stream.h", + "src/quic/core/http/quic_send_control_stream.cc", + "src/quic/core/http/quic_send_control_stream.h", + "src/quic/core/http/quic_server_initiated_spdy_stream.cc", + "src/quic/core/http/quic_server_initiated_spdy_stream.h", + "src/quic/core/http/quic_server_session_base.cc", + "src/quic/core/http/quic_server_session_base.h", + "src/quic/core/http/quic_spdy_client_session.cc", + "src/quic/core/http/quic_spdy_client_session.h", + "src/quic/core/http/quic_spdy_client_session_base.cc", + "src/quic/core/http/quic_spdy_client_session_base.h", + "src/quic/core/http/quic_spdy_client_stream.cc", + "src/quic/core/http/quic_spdy_client_stream.h", + "src/quic/core/http/quic_spdy_session.cc", + "src/quic/core/http/quic_spdy_session.h", + "src/quic/core/http/quic_spdy_stream.cc", + "src/quic/core/http/quic_spdy_stream.h", + "src/quic/core/http/quic_spdy_stream_body_manager.cc", + "src/quic/core/http/quic_spdy_stream_body_manager.h", + "src/quic/core/http/spdy_server_push_utils.cc", + "src/quic/core/http/spdy_server_push_utils.h", + "src/quic/core/http/spdy_utils.cc", + "src/quic/core/http/spdy_utils.h", + "src/quic/core/http/web_transport_http3.cc", + "src/quic/core/http/web_transport_http3.h", + "src/quic/core/legacy_quic_stream_id_manager.cc", + "src/quic/core/legacy_quic_stream_id_manager.h", + "src/quic/core/packet_number_indexed_queue.h", + "src/quic/core/proto/cached_network_parameters_proto.h", + "src/quic/core/proto/crypto_server_config_proto.h", + "src/quic/core/proto/source_address_token_proto.h", + "src/quic/core/qpack/qpack_blocking_manager.cc", + "src/quic/core/qpack/qpack_blocking_manager.h", + "src/quic/core/qpack/qpack_decoded_headers_accumulator.cc", + "src/quic/core/qpack/qpack_decoded_headers_accumulator.h", + "src/quic/core/qpack/qpack_decoder.cc", + "src/quic/core/qpack/qpack_decoder.h", + "src/quic/core/qpack/qpack_decoder_stream_receiver.cc", + "src/quic/core/qpack/qpack_decoder_stream_receiver.h", + "src/quic/core/qpack/qpack_decoder_stream_sender.cc", + "src/quic/core/qpack/qpack_decoder_stream_sender.h", + "src/quic/core/qpack/qpack_encoder.cc", + "src/quic/core/qpack/qpack_encoder.h", + "src/quic/core/qpack/qpack_encoder_stream_receiver.cc", + "src/quic/core/qpack/qpack_encoder_stream_receiver.h", + "src/quic/core/qpack/qpack_encoder_stream_sender.cc", + "src/quic/core/qpack/qpack_encoder_stream_sender.h", + "src/quic/core/qpack/qpack_header_table.cc", + "src/quic/core/qpack/qpack_header_table.h", + "src/quic/core/qpack/qpack_index_conversions.cc", + "src/quic/core/qpack/qpack_index_conversions.h", + "src/quic/core/qpack/qpack_instruction_decoder.cc", + "src/quic/core/qpack/qpack_instruction_decoder.h", + "src/quic/core/qpack/qpack_instruction_encoder.cc", + "src/quic/core/qpack/qpack_instruction_encoder.h", + "src/quic/core/qpack/qpack_instructions.cc", + "src/quic/core/qpack/qpack_instructions.h", + "src/quic/core/qpack/qpack_progressive_decoder.cc", + "src/quic/core/qpack/qpack_progressive_decoder.h", + "src/quic/core/qpack/qpack_receive_stream.cc", + "src/quic/core/qpack/qpack_receive_stream.h", + "src/quic/core/qpack/qpack_required_insert_count.cc", + "src/quic/core/qpack/qpack_required_insert_count.h", + "src/quic/core/qpack/qpack_send_stream.cc", + "src/quic/core/qpack/qpack_send_stream.h", + "src/quic/core/qpack/qpack_static_table.cc", + "src/quic/core/qpack/qpack_static_table.h", + "src/quic/core/qpack/qpack_stream_receiver.h", + "src/quic/core/qpack/qpack_stream_sender_delegate.h", + "src/quic/core/qpack/value_splitting_header_list.cc", + "src/quic/core/qpack/value_splitting_header_list.h", + "src/quic/core/quic_ack_listener_interface.cc", + "src/quic/core/quic_ack_listener_interface.h", + "src/quic/core/quic_alarm.cc", + "src/quic/core/quic_alarm.h", + "src/quic/core/quic_alarm_factory.h", + "src/quic/core/quic_arena_scoped_ptr.h", + "src/quic/core/quic_bandwidth.cc", + "src/quic/core/quic_bandwidth.h", + "src/quic/core/quic_blocked_writer_interface.h", + "src/quic/core/quic_buffer_allocator.cc", + "src/quic/core/quic_buffer_allocator.h", + "src/quic/core/quic_chaos_protector.cc", + "src/quic/core/quic_chaos_protector.h", + "src/quic/core/quic_clock.cc", + "src/quic/core/quic_clock.h", + "src/quic/core/quic_coalesced_packet.cc", + "src/quic/core/quic_coalesced_packet.h", + "src/quic/core/quic_config.cc", + "src/quic/core/quic_config.h", + "src/quic/core/quic_connection.cc", + "src/quic/core/quic_connection.h", + "src/quic/core/quic_connection_context.cc", + "src/quic/core/quic_connection_context.h", + "src/quic/core/quic_connection_id.cc", + "src/quic/core/quic_connection_id.h", + "src/quic/core/quic_connection_id_manager.cc", + "src/quic/core/quic_connection_id_manager.h", + "src/quic/core/quic_connection_stats.cc", + "src/quic/core/quic_connection_stats.h", + "src/quic/core/quic_constants.cc", + "src/quic/core/quic_constants.h", + "src/quic/core/quic_control_frame_manager.cc", + "src/quic/core/quic_control_frame_manager.h", + "src/quic/core/quic_crypto_client_handshaker.cc", + "src/quic/core/quic_crypto_client_handshaker.h", + "src/quic/core/quic_crypto_client_stream.cc", + "src/quic/core/quic_crypto_client_stream.h", + "src/quic/core/quic_crypto_handshaker.cc", + "src/quic/core/quic_crypto_handshaker.h", + "src/quic/core/quic_crypto_server_stream.cc", + "src/quic/core/quic_crypto_server_stream.h", + "src/quic/core/quic_crypto_server_stream_base.cc", + "src/quic/core/quic_crypto_server_stream_base.h", + "src/quic/core/quic_crypto_stream.cc", + "src/quic/core/quic_crypto_stream.h", + "src/quic/core/quic_data_reader.cc", + "src/quic/core/quic_data_reader.h", + "src/quic/core/quic_data_writer.cc", + "src/quic/core/quic_data_writer.h", + "src/quic/core/quic_datagram_queue.cc", + "src/quic/core/quic_datagram_queue.h", + "src/quic/core/quic_error_codes.cc", + "src/quic/core/quic_error_codes.h", + "src/quic/core/quic_flow_controller.cc", + "src/quic/core/quic_flow_controller.h", + "src/quic/core/quic_framer.cc", + "src/quic/core/quic_framer.h", + "src/quic/core/quic_idle_network_detector.cc", + "src/quic/core/quic_idle_network_detector.h", + "src/quic/core/quic_interval.h", + "src/quic/core/quic_interval_deque.h", + "src/quic/core/quic_interval_set.h", + "src/quic/core/quic_legacy_version_encapsulator.cc", + "src/quic/core/quic_legacy_version_encapsulator.h", + "src/quic/core/quic_lru_cache.h", + "src/quic/core/quic_mtu_discovery.cc", + "src/quic/core/quic_mtu_discovery.h", + "src/quic/core/quic_network_blackhole_detector.cc", + "src/quic/core/quic_network_blackhole_detector.h", + "src/quic/core/quic_one_block_arena.h", + "src/quic/core/quic_packet_creator.cc", + "src/quic/core/quic_packet_creator.h", + "src/quic/core/quic_packet_number.cc", + "src/quic/core/quic_packet_number.h", + "src/quic/core/quic_packet_writer.h", + "src/quic/core/quic_packets.cc", + "src/quic/core/quic_packets.h", + "src/quic/core/quic_path_validator.cc", + "src/quic/core/quic_path_validator.h", + "src/quic/core/quic_protocol_flags_list.h", + "src/quic/core/quic_received_packet_manager.cc", + "src/quic/core/quic_received_packet_manager.h", + "src/quic/core/quic_sent_packet_manager.cc", + "src/quic/core/quic_sent_packet_manager.h", + "src/quic/core/quic_server_id.cc", + "src/quic/core/quic_server_id.h", + "src/quic/core/quic_session.cc", + "src/quic/core/quic_session.h", + "src/quic/core/quic_simple_buffer_allocator.cc", + "src/quic/core/quic_simple_buffer_allocator.h", + "src/quic/core/quic_socket_address_coder.cc", + "src/quic/core/quic_socket_address_coder.h", + "src/quic/core/quic_stream.cc", + "src/quic/core/quic_stream.h", + "src/quic/core/quic_stream_frame_data_producer.h", + "src/quic/core/quic_stream_id_manager.cc", + "src/quic/core/quic_stream_id_manager.h", + "src/quic/core/quic_stream_send_buffer.cc", + "src/quic/core/quic_stream_send_buffer.h", + "src/quic/core/quic_stream_sequencer.cc", + "src/quic/core/quic_stream_sequencer.h", + "src/quic/core/quic_stream_sequencer_buffer.cc", + "src/quic/core/quic_stream_sequencer_buffer.h", + "src/quic/core/quic_sustained_bandwidth_recorder.cc", + "src/quic/core/quic_sustained_bandwidth_recorder.h", + "src/quic/core/quic_tag.cc", + "src/quic/core/quic_tag.h", + "src/quic/core/quic_time.cc", + "src/quic/core/quic_time.h", + "src/quic/core/quic_time_accumulator.h", + "src/quic/core/quic_transmission_info.cc", + "src/quic/core/quic_transmission_info.h", + "src/quic/core/quic_types.cc", + "src/quic/core/quic_types.h", + "src/quic/core/quic_unacked_packet_map.cc", + "src/quic/core/quic_unacked_packet_map.h", + "src/quic/core/quic_utils.cc", + "src/quic/core/quic_utils.h", + "src/quic/core/quic_version_manager.cc", + "src/quic/core/quic_version_manager.h", + "src/quic/core/quic_versions.cc", + "src/quic/core/quic_versions.h", + "src/quic/core/quic_write_blocked_list.cc", + "src/quic/core/quic_write_blocked_list.h", + "src/quic/core/session_notifier_interface.h", + "src/quic/core/stream_delegate_interface.h", + "src/quic/core/tls_client_handshaker.cc", + "src/quic/core/tls_client_handshaker.h", + "src/quic/core/tls_handshaker.cc", + "src/quic/core/tls_handshaker.h", + "src/quic/core/tls_server_handshaker.cc", + "src/quic/core/tls_server_handshaker.h", + "src/quic/core/uber_quic_stream_id_manager.cc", + "src/quic/core/uber_quic_stream_id_manager.h", + "src/quic/core/uber_received_packet_manager.cc", + "src/quic/core/uber_received_packet_manager.h", + "src/quic/core/web_transport_stream_adapter.cc", + "src/quic/core/web_transport_stream_adapter.h", + "src/quic/platform/api/quic_bug_tracker.h", + "src/quic/platform/api/quic_client_stats.h", + "src/quic/platform/api/quic_containers.h", + "src/quic/platform/api/quic_error_code_wrappers.h", + "src/quic/platform/api/quic_export.h", + "src/quic/platform/api/quic_exported_stats.h", + "src/quic/platform/api/quic_flag_utils.h", + "src/quic/platform/api/quic_flags.h", + "src/quic/platform/api/quic_hostname_utils.cc", + "src/quic/platform/api/quic_hostname_utils.h", + "src/quic/platform/api/quic_iovec.h", + "src/quic/platform/api/quic_ip_address.cc", + "src/quic/platform/api/quic_ip_address.h", + "src/quic/platform/api/quic_ip_address_family.h", + "src/quic/platform/api/quic_logging.h", + "src/quic/platform/api/quic_mem_slice.h", + "src/quic/platform/api/quic_mem_slice_span.h", + "src/quic/platform/api/quic_mem_slice_storage.h", + "src/quic/platform/api/quic_mutex.cc", + "src/quic/platform/api/quic_mutex.h", + "src/quic/platform/api/quic_reference_counted.h", + "src/quic/platform/api/quic_server_stats.h", + "src/quic/platform/api/quic_sleep.h", + "src/quic/platform/api/quic_socket_address.cc", + "src/quic/platform/api/quic_socket_address.h", + "src/quic/platform/api/quic_stack_trace.h", + "src/quic/platform/api/quic_thread.h", + "src/quic/quic_transport/quic_transport_client_session.cc", + "src/quic/quic_transport/quic_transport_client_session.h", + "src/quic/quic_transport/quic_transport_protocol.h", + "src/quic/quic_transport/quic_transport_server_session.cc", + "src/quic/quic_transport/quic_transport_server_session.h", + "src/quic/quic_transport/quic_transport_session_interface.h", + "src/quic/quic_transport/quic_transport_stream.cc", + "src/quic/quic_transport/quic_transport_stream.h", + "src/quic/quic_transport/web_transport_fingerprint_proof_verifier.cc", + "src/quic/quic_transport/web_transport_fingerprint_proof_verifier.h", + "src/spdy/core/hpack/hpack_constants.cc", + "src/spdy/core/hpack/hpack_constants.h", + "src/spdy/core/hpack/hpack_decoder_adapter.cc", + "src/spdy/core/hpack/hpack_decoder_adapter.h", + "src/spdy/core/hpack/hpack_encoder.cc", + "src/spdy/core/hpack/hpack_encoder.h", + "src/spdy/core/hpack/hpack_entry.cc", + "src/spdy/core/hpack/hpack_entry.h", + "src/spdy/core/hpack/hpack_header_table.cc", + "src/spdy/core/hpack/hpack_header_table.h", + "src/spdy/core/hpack/hpack_output_stream.cc", + "src/spdy/core/hpack/hpack_output_stream.h", + "src/spdy/core/hpack/hpack_static_table.cc", + "src/spdy/core/hpack/hpack_static_table.h", + "src/spdy/core/http2_frame_decoder_adapter.cc", + "src/spdy/core/http2_frame_decoder_adapter.h", + "src/spdy/core/recording_headers_handler.cc", + "src/spdy/core/recording_headers_handler.h", + "src/spdy/core/spdy_alt_svc_wire_format.cc", + "src/spdy/core/spdy_alt_svc_wire_format.h", + "src/spdy/core/spdy_bitmasks.h", + "src/spdy/core/spdy_frame_builder.cc", + "src/spdy/core/spdy_frame_builder.h", + "src/spdy/core/spdy_frame_reader.cc", + "src/spdy/core/spdy_frame_reader.h", + "src/spdy/core/spdy_framer.cc", + "src/spdy/core/spdy_framer.h", + "src/spdy/core/spdy_header_block.cc", + "src/spdy/core/spdy_header_block.h", + "src/spdy/core/spdy_header_storage.cc", + "src/spdy/core/spdy_header_storage.h", + "src/spdy/core/spdy_headers_handler_interface.h", + "src/spdy/core/spdy_intrusive_list.h", + "src/spdy/core/spdy_no_op_visitor.cc", + "src/spdy/core/spdy_no_op_visitor.h", + "src/spdy/core/spdy_pinnable_buffer_piece.cc", + "src/spdy/core/spdy_pinnable_buffer_piece.h", + "src/spdy/core/spdy_prefixed_buffer_reader.cc", + "src/spdy/core/spdy_prefixed_buffer_reader.h", + "src/spdy/core/spdy_protocol.cc", + "src/spdy/core/spdy_protocol.h", + "src/spdy/core/spdy_simple_arena.cc", + "src/spdy/core/spdy_simple_arena.h", + "src/spdy/core/zero_copy_output_buffer.h", + ] + deps = [ "//net:net_deps" ] + public_deps = [ "//net:net_public_deps" ] +} diff --git a/third_party/quiche/README.chromium b/third_party/quiche/README.chromium new file mode 100644 index 00000000..065860ff --- /dev/null +++ b/third_party/quiche/README.chromium @@ -0,0 +1,9 @@ +Name: QUICHE +URL: https://quiche.googlesource.com/quiche +Version: git +License: BSD +License File: src/LICENSE +Security Critical: yes + +Description: +This is QUICHE, Google's implementation of QUIC, HTTP/2 and SPDY protocols. diff --git a/third_party/valijson/BUILD.gn b/third_party/valijson/BUILD.gn index 8df84598..b7286470 100644 --- a/third_party/valijson/BUILD.gn +++ b/third_party/valijson/BUILD.gn @@ -22,7 +22,7 @@ if (!build_with_chromium) { # We only need the adapter for JsonCpp. "src/include/valijson/adapters/jsoncpp_adapter.hpp", - "src/include/valijson/constraints_builder.hpp", + "src/include/valijson/constraint_builder.hpp", "src/include/valijson/internal/custom_allocator.hpp", "src/include/valijson/internal/debug.hpp", "src/include/valijson/internal/json_pointer.hpp", diff --git a/tools/cddl/cddl.py b/tools/cddl/cddl.py index 28436f27..8625a9ae 100644 --- a/tools/cddl/cddl.py +++ b/tools/cddl/cddl.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + # Copyright 2018 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/tools/cddl/codegen.cc b/tools/cddl/codegen.cc index ac58eed9..a763b721 100644 --- a/tools/cddl/codegen.cc +++ b/tools/cddl/codegen.cc @@ -7,6 +7,7 @@ #include <cinttypes> #include <iostream> #include <limits> +#include <memory> #include <set> #include <sstream> #include <string> @@ -381,7 +382,7 @@ bool EnsureDependentTypeDefinitionsWritten(int fd, case CppType::Which::kVector: { return EnsureDependentTypeDefinitionsWritten( fd, *cpp_type.vector_type.element_type, defs); - } break; + } case CppType::Which::kEnum: { if (defs->find(cpp_type.name) != defs->end()) return true; @@ -406,7 +407,7 @@ bool EnsureDependentTypeDefinitionsWritten(int fd, case CppType::Which::kOptional: { return EnsureDependentTypeDefinitionsWritten(fd, *cpp_type.optional_type, defs); - } break; + } case CppType::Which::kDiscriminatedUnion: { for (const auto* x : cpp_type.discriminated_union.members) if (!EnsureDependentTypeDefinitionsWritten(fd, *x, defs)) @@ -559,12 +560,10 @@ bool WriteEncoder(int fd, } return true; } - break; case CppType::Which::kUint64: dprintf(fd, " CBOR_RETURN_ON_ERROR(cbor_encode_uint(&encoder%d, %s));\n", encoder_depth, ToUnderscoreId(name).c_str()); return true; - break; case CppType::Which::kString: { std::string cid = ToUnderscoreId(name); dprintf(fd, " if (!IsValidUtf8(%s)) {\n", cid.c_str()); @@ -575,7 +574,7 @@ bool WriteEncoder(int fd, "%s.c_str(), %s.size()));\n", encoder_depth, cid.c_str(), cid.c_str()); return true; - } break; + } case CppType::Which::kBytes: { std::string cid = ToUnderscoreId(name); dprintf(fd, @@ -584,7 +583,7 @@ bool WriteEncoder(int fd, "%s.size()));\n", encoder_depth, cid.c_str(), cid.c_str()); return true; - } break; + } case CppType::Which::kVector: { std::string cid = ToUnderscoreId(name); dprintf(fd, " {\n"); @@ -619,14 +618,14 @@ bool WriteEncoder(int fd, encoder_depth, encoder_depth + 1); dprintf(fd, " }\n"); return true; - } break; + } case CppType::Which::kEnum: { dprintf(fd, " CBOR_RETURN_ON_ERROR(cbor_encode_uint(&encoder%d, " "static_cast<uint64_t>(%s)));\n", encoder_depth, ToUnderscoreId(name).c_str()); return true; - } break; + } case CppType::Which::kDiscriminatedUnion: { for (const auto* union_member : cpp_type.discriminated_union.members) { switch (union_member->which) { @@ -670,7 +669,7 @@ bool WriteEncoder(int fd, ToCamelCase(cpp_type.name).c_str()); dprintf(fd, " return -CborUnknownError;\n"); return true; - } break; + } case CppType::Which::kTaggedType: { dprintf(fd, " CBOR_RETURN_ON_ERROR(cbor_encode_tag(&encoder%d, %" PRIu64 @@ -681,7 +680,7 @@ bool WriteEncoder(int fd, return false; } return true; - } break; + } default: break; } @@ -1042,7 +1041,7 @@ bool WriteDecoder(int fd, dprintf(fd, " CBOR_RETURN_ON_ERROR(cbor_value_advance_fixed(&it%d));\n", decoder_depth); return true; - } break; + } case CppType::Which::kString: { int temp_length = (*temporary_count)++; dprintf(fd, " size_t length%d = 0;", temp_length); @@ -1073,7 +1072,7 @@ bool WriteDecoder(int fd, dprintf(fd, " CBOR_RETURN_ON_ERROR(cbor_value_advance(&it%d));\n", decoder_depth); return true; - } break; + } case CppType::Which::kBytes: { int temp_length = (*temporary_count)++; dprintf(fd, " size_t length%d = 0;", temp_length); @@ -1110,7 +1109,7 @@ bool WriteDecoder(int fd, dprintf(fd, " CBOR_RETURN_ON_ERROR(cbor_value_advance(&it%d));\n", decoder_depth); return true; - } break; + } case CppType::Which::kVector: { dprintf(fd, " if (cbor_value_get_type(&it%d) != CborArrayType) {\n", decoder_depth); @@ -1157,7 +1156,7 @@ bool WriteDecoder(int fd, decoder_depth, decoder_depth + 1); dprintf(fd, " }\n"); return true; - } break; + } case CppType::Which::kEnum: { dprintf(fd, " CBOR_RETURN_ON_ERROR(cbor_value_get_uint64(&it%d, " @@ -1167,7 +1166,7 @@ bool WriteDecoder(int fd, decoder_depth); // TODO(btolsch): Validate against enum members. return true; - } break; + } case CppType::Which::kStruct: { if (cpp_type.struct_type.key_type == CppType::Struct::KeyType::kMap) { return WriteMapDecoder(fd, name, member_accessor, @@ -1235,7 +1234,7 @@ bool WriteDecoder(int fd, } dprintf(fd, " else { return -1; }\n"); return true; - } break; + } case CppType::Which::kTaggedType: { int temp_tag = (*temporary_count)++; dprintf(fd, " uint64_t tag%d = 0;\n", temp_tag); @@ -1253,7 +1252,7 @@ bool WriteDecoder(int fd, return false; } return true; - } break; + } default: break; } @@ -1586,15 +1585,16 @@ namespace openscreen { namespace msgs { namespace { +/* + * Encoder-specific errors, so it's fine to check these even in the + * parser. + */ #define CBOR_RETURN_WHAT_ON_ERROR(stmt, what) \ { \ CborError error = stmt; \ - /* Encoder-specific errors, so it's fine to check these even in the \ - * parser. \ - */ \ - OSP_DCHECK_NE(error, CborErrorTooFewItems); \ - OSP_DCHECK_NE(error, CborErrorTooManyItems); \ - OSP_DCHECK_NE(error, CborErrorDataTooLarge); \ + OSP_DCHECK_NE(error, CborErrorTooFewItems); \ + OSP_DCHECK_NE(error, CborErrorTooManyItems); \ + OSP_DCHECK_NE(error, CborErrorDataTooLarge); \ if (error != CborNoError && error != CborErrorOutOfMemory) \ return what; \ } diff --git a/tools/cddl/parse.cc b/tools/cddl/parse.cc index eef62976..439df8ff 100644 --- a/tools/cddl/parse.cc +++ b/tools/cddl/parse.cc @@ -10,6 +10,7 @@ #include <iostream> #include <memory> #include <sstream> +#include <utility> #include <vector> #include "absl/strings/ascii.h" @@ -427,7 +428,6 @@ AstNode* ParseGroupChoice(Parser* p) { return nullptr; } } - return nullptr; } AstNode* ParseGroup(Parser* p) { @@ -974,7 +974,7 @@ ParseResult ParseCddl(absl::string_view data) { if (data[0] == 0) { return {nullptr, {}}; } - Parser p{(char*)data.data()}; + Parser p{data.data()}; SkipWhitespace(&p); AstNode* root = nullptr; diff --git a/tools/cddl/sema.cc b/tools/cddl/sema.cc index 2ecb162a..584fe56f 100644 --- a/tools/cddl/sema.cc +++ b/tools/cddl/sema.cc @@ -722,7 +722,7 @@ CppType* MakeCppType(CppSymbolTable* table, cpp_type->discriminated_union.members.push_back(member); } return cpp_type; - } break; + } case CddlType::Which::kTaggedType: { cpp_type = GetCppType(table, name); cpp_type->which = CppType::Which::kTaggedType; diff --git a/tools/convert_to_data_file.py b/tools/convert_to_data_file.py index 05456f08..c6af0da8 100755 --- a/tools/convert_to_data_file.py +++ b/tools/convert_to_data_file.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2020 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/tools/curlish.py b/tools/curlish.py index c0324b02..956c5c43 100755 --- a/tools/curlish.py +++ b/tools/curlish.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2020 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/tools/download-clang-update-script.py b/tools/download-clang-update-script.py index f3534a10..203862ec 100755 --- a/tools/download-clang-update-script.py +++ b/tools/download-clang-update-script.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2020 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. @@ -14,8 +14,8 @@ import argparse import curlish import sys -SCRIPT_DOWNLOAD_URL = ('https://raw.githubusercontent.com/chromium/' + - 'chromium/master/tools/clang/scripts/update.py') +SCRIPT_DOWNLOAD_URL = ('https://raw.githubusercontent.com/chromium/' + 'chromium/main/tools/clang/scripts/update.py') def main(): diff --git a/tools/download-yajsv.py b/tools/download-yajsv.py index d42d3f3a..b127b1ff 100755 --- a/tools/download-yajsv.py +++ b/tools/download-yajsv.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # Copyright 2020 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/util/BUILD.gn b/util/BUILD.gn index 3f97e09e..90a7fe39 100644 --- a/util/BUILD.gn +++ b/util/BUILD.gn @@ -17,31 +17,14 @@ config("trace_logging_config") { } } -source_set("util") { +# The set of util classes which have no dependency on platform:api. +source_set("base") { sources = [ - "alarm.cc", - "alarm.h", "base64.cc", "base64.h", "big_endian.cc", "big_endian.h", "chrono_helpers.h", - "crypto/certificate_utils.cc", - "crypto/certificate_utils.h", - "crypto/digest_sign.cc", - "crypto/digest_sign.h", - "crypto/openssl_util.cc", - "crypto/openssl_util.h", - "crypto/pem_helpers.cc", - "crypto/pem_helpers.h", - "crypto/random_bytes.cc", - "crypto/random_bytes.h", - "crypto/rsa_private_key.cc", - "crypto/rsa_private_key.h", - "crypto/secure_hash.cc", - "crypto/secure_hash.h", - "crypto/sha2.cc", - "crypto/sha2.h", "enum_name_table.h", "flat_map.h", "hashing.h", @@ -59,10 +42,6 @@ source_set("util") { "std_util.h", "stringprintf.cc", "stringprintf.h", - "trace_logging.h", - "trace_logging/macro_support.h", - "trace_logging/scoped_trace_operations.cc", - "trace_logging/scoped_trace_operations.h", "url.cc", "url.h", "weak_ptr.h", @@ -71,14 +50,13 @@ source_set("util") { ] public_deps = [ - "../platform:api", "../platform:base", + "../platform:logging", "../third_party/abseil", "../third_party/jsoncpp", ] deps = [ - "../third_party/boringssl", "../third_party/mozilla", # We do a clone of Chrome's modp_b64 in order to share their BUILD.gn @@ -86,6 +64,45 @@ source_set("util") { "//third_party/modp_b64", ] + public_configs = [ "../build:openscreen_include_dirs" ] +} + +source_set("util") { + sources = [ + "alarm.cc", + "alarm.h", + "crypto/certificate_utils.cc", + "crypto/certificate_utils.h", + "crypto/digest_sign.cc", + "crypto/digest_sign.h", + "crypto/openssl_util.cc", + "crypto/openssl_util.h", + "crypto/pem_helpers.cc", + "crypto/pem_helpers.h", + "crypto/random_bytes.cc", + "crypto/random_bytes.h", + "crypto/rsa_private_key.cc", + "crypto/rsa_private_key.h", + "crypto/secure_hash.cc", + "crypto/secure_hash.h", + "crypto/sha2.cc", + "crypto/sha2.h", + "trace_logging.h", + "trace_logging/macro_support.h", + "trace_logging/scoped_trace_operations.cc", + "trace_logging/scoped_trace_operations.h", + ] + + public_deps = [ + ":base", + "../platform:api", + "../platform:base", + "../third_party/abseil", + "../third_party/jsoncpp", + ] + + deps = [ "../third_party/boringssl" ] + public_configs = [ "../build:openscreen_include_dirs", ":trace_logging_config", diff --git a/util/base64.cc b/util/base64.cc index 06e120e0..64e34175 100644 --- a/util/base64.cc +++ b/util/base64.cc @@ -6,6 +6,10 @@ #include <stddef.h> +#include <string> +#include <utility> +#include <vector> + #include "third_party/modp_b64/modp_b64.h" #include "util/osp_logging.h" #include "util/std_util.h" @@ -33,20 +37,18 @@ std::string Encode(absl::string_view input) { return out; } -bool Decode(absl::string_view input, std::string* output) { - std::string out; - out.resize(modp_b64_decode_len(input.size())); +bool Decode(absl::string_view input, std::vector<uint8_t>* output) { + std::vector<uint8_t> out(modp_b64_decode_len(input.size())); - // We don't null terminate the result since it is binary data. - const size_t output_size = - modp_b64_decode(data(out), input.data(), input.size()); + const size_t output_size = modp_b64_decode( + reinterpret_cast<char*>(out.data()), input.data(), input.size()); if (output_size == MODP_B64_ERROR) { return false; } // The output size from decode_len is generally larger than needed. out.resize(output_size); - output->swap(out); + *output = std::move(out); return true; } diff --git a/util/base64.h b/util/base64.h index b24c3b3f..a7af9eca 100644 --- a/util/base64.h +++ b/util/base64.h @@ -6,6 +6,7 @@ #define UTIL_BASE64_H_ #include <string> +#include <vector> #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -23,7 +24,7 @@ std::string Encode(absl::string_view input); // Decodes the base64 input string. Returns true if successful and false // otherwise. The output string is only modified if successful. The decoding can // be done in-place. -bool Decode(absl::string_view input, std::string* output); +bool Decode(absl::string_view input, std::vector<uint8_t>* output); } // namespace base64 } // namespace openscreen diff --git a/util/base64_unittest.cc b/util/base64_unittest.cc index 28d4fb1d..873b5656 100644 --- a/util/base64_unittest.cc +++ b/util/base64_unittest.cc @@ -4,6 +4,9 @@ #include "util/base64.h" +#include <string> +#include <vector> + #include "gtest/gtest.h" namespace openscreen { @@ -14,13 +17,21 @@ namespace { constexpr char kText[] = "hello world"; constexpr char kBase64Text[] = "aGVsbG8gd29ybGQ="; +// More sophisticated comparisons here, such as EXPECT_STREQ, may +// cause memory failures on some platforms (e.g. ASAN) due to mismatched +// lengths. +void CheckEquals(const char* expected, const std::vector<uint8_t>& actual) { + EXPECT_EQ(0, std::memcmp(actual.data(), expected, actual.size())); +} + void CheckEncodeDecode(const char* to_encode, const char* encode_expected) { std::string encoded = Encode(to_encode); EXPECT_EQ(encode_expected, encoded); - std::string decoded; + std::vector<uint8_t> decoded; EXPECT_TRUE(Decode(encoded, &decoded)); - EXPECT_EQ(to_encode, decoded); + + CheckEquals(to_encode, decoded); } } // namespace @@ -52,8 +63,9 @@ TEST(Base64Test, InPlace) { text = Encode(text); EXPECT_EQ(kBase64Text, text); - EXPECT_TRUE(Decode(text, &text)); - EXPECT_EQ(text, kText); + std::vector<uint8_t> out; + EXPECT_TRUE(Decode(text, &out)); + CheckEquals(kText, out); } } // namespace base64 diff --git a/util/json/json_helpers.h b/util/json/json_helpers.h index 1943973d..ebd25add 100644 --- a/util/json/json_helpers.h +++ b/util/json/json_helpers.h @@ -16,6 +16,7 @@ #include "json/value.h" #include "platform/base/error.h" #include "util/chrono_helpers.h" +#include "util/json/json_serialization.h" #include "util/simple_fraction.h" // This file contains helper methods for parsing JSON, in an attempt to @@ -23,53 +24,7 @@ namespace openscreen { namespace json { -// TODO(jophba): remove these methods after refactoring offer messaging. -inline Error CreateParseError(const std::string& type) { - return Error(Error::Code::kJsonParseError, "Failed to parse " + type); -} - -inline Error CreateParameterError(const std::string& type) { - return Error(Error::Code::kParameterInvalid, "Invalid parameter: " + type); -} - -inline ErrorOr<bool> ParseBool(const Json::Value& parent, - const std::string& field) { - const Json::Value& value = parent[field]; - if (!value.isBool()) { - return CreateParseError("bool field " + field); - } - return value.asBool(); -} - -inline ErrorOr<int> ParseInt(const Json::Value& parent, - const std::string& field) { - const Json::Value& value = parent[field]; - if (!value.isInt()) { - return CreateParseError("integer field: " + field); - } - return value.asInt(); -} - -inline ErrorOr<uint32_t> ParseUint(const Json::Value& parent, - const std::string& field) { - const Json::Value& value = parent[field]; - if (!value.isUInt()) { - return CreateParseError("unsigned integer field: " + field); - } - return value.asUInt(); -} - -inline ErrorOr<std::string> ParseString(const Json::Value& parent, - const std::string& field) { - const Json::Value& value = parent[field]; - if (!value.isString()) { - return CreateParseError("string field: " + field); - } - return value.asString(); -} - -// TODO(jophba): offer messaging should use these methods instead. -inline bool ParseBool(const Json::Value& value, bool* out) { +inline bool TryParseBool(const Json::Value& value, bool* out) { if (!value.isBool()) { return false; } @@ -80,9 +35,9 @@ inline bool ParseBool(const Json::Value& value, bool* out) { // A general note about parsing primitives. "Validation" in this context // generally means ensuring that the values are non-negative, excepting doubles // which may be negative in some cases. -inline bool ParseAndValidateDouble(const Json::Value& value, - double* out, - bool allow_negative = false) { +inline bool TryParseDouble(const Json::Value& value, + double* out, + bool allow_negative = false) { if (!value.isDouble()) { return false; } @@ -97,7 +52,7 @@ inline bool ParseAndValidateDouble(const Json::Value& value, return true; } -inline bool ParseAndValidateInt(const Json::Value& value, int* out) { +inline bool TryParseInt(const Json::Value& value, int* out) { if (!value.isInt()) { return false; } @@ -109,7 +64,7 @@ inline bool ParseAndValidateInt(const Json::Value& value, int* out) { return true; } -inline bool ParseAndValidateUint(const Json::Value& value, uint32_t* out) { +inline bool TryParseUint(const Json::Value& value, uint32_t* out) { if (!value.isUInt()) { return false; } @@ -117,7 +72,7 @@ inline bool ParseAndValidateUint(const Json::Value& value, uint32_t* out) { return true; } -inline bool ParseAndValidateString(const Json::Value& value, std::string* out) { +inline bool TryParseString(const Json::Value& value, std::string* out) { if (!value.isString()) { return false; } @@ -128,8 +83,8 @@ inline bool ParseAndValidateString(const Json::Value& value, std::string* out) { // We want to be more robust when we parse fractions then just // allowing strings, this will parse numeral values such as // value: 50 as well as value: "50" and value: "100/2". -inline bool ParseAndValidateSimpleFraction(const Json::Value& value, - SimpleFraction* out) { +inline bool TryParseSimpleFraction(const Json::Value& value, + SimpleFraction* out) { if (value.isInt()) { int parsed = value.asInt(); if (parsed < 0) { @@ -155,10 +110,9 @@ inline bool ParseAndValidateSimpleFraction(const Json::Value& value, return false; } -inline bool ParseAndValidateMilliseconds(const Json::Value& value, - milliseconds* out) { +inline bool TryParseMilliseconds(const Json::Value& value, milliseconds* out) { int out_ms; - if (!ParseAndValidateInt(value, &out_ms) || out_ms < 0) { + if (!TryParseInt(value, &out_ms) || out_ms < 0) { return false; } *out = milliseconds(out_ms); @@ -171,9 +125,9 @@ using Parser = std::function<bool(const Json::Value&, T*)>; // NOTE: array parsing methods reset the output vector to an empty vector in // any error case. This is especially useful for optional arrays. template <typename T> -bool ParseAndValidateArray(const Json::Value& value, - Parser<T> parser, - std::vector<T>* out) { +bool TryParseArray(const Json::Value& value, + Parser<T> parser, + std::vector<T>* out) { out->clear(); if (!value.isArray() || value.empty()) { return false; @@ -192,19 +146,18 @@ bool ParseAndValidateArray(const Json::Value& value, return true; } -inline bool ParseAndValidateIntArray(const Json::Value& value, - std::vector<int>* out) { - return ParseAndValidateArray<int>(value, ParseAndValidateInt, out); +inline bool TryParseIntArray(const Json::Value& value, std::vector<int>* out) { + return TryParseArray<int>(value, TryParseInt, out); } -inline bool ParseAndValidateUintArray(const Json::Value& value, - std::vector<uint32_t>* out) { - return ParseAndValidateArray<uint32_t>(value, ParseAndValidateUint, out); +inline bool TryParseUintArray(const Json::Value& value, + std::vector<uint32_t>* out) { + return TryParseArray<uint32_t>(value, TryParseUint, out); } -inline bool ParseAndValidateStringArray(const Json::Value& value, - std::vector<std::string>* out) { - return ParseAndValidateArray<std::string>(value, ParseAndValidateString, out); +inline bool TryParseStringArray(const Json::Value& value, + std::vector<std::string>* out) { + return TryParseArray<std::string>(value, TryParseString, out); } } // namespace json diff --git a/util/json/json_helpers_unittest.cc b/util/json/json_helpers_unittest.cc index c461cf93..eb05d3f6 100644 --- a/util/json/json_helpers_unittest.cc +++ b/util/json/json_helpers_unittest.cc @@ -26,9 +26,9 @@ struct Dummy { } }; -bool ParseAndValidateDummy(const Json::Value& value, Dummy* out) { +bool TryParseDummy(const Json::Value& value, Dummy* out) { int value_out; - if (!ParseAndValidateInt(value, &value_out)) { + if (!TryParseInt(value, &value_out)) { return false; } *out = Dummy{value_out}; @@ -37,7 +37,7 @@ bool ParseAndValidateDummy(const Json::Value& value, Dummy* out) { } // namespace -TEST(ParsingHelpersTest, ParseAndValidateDouble) { +TEST(ParsingHelpersTest, TryParseDouble) { const Json::Value kValid = 13.37; const Json::Value kNotDouble = "coffee beans"; const Json::Value kNegativeDouble = -4.2; @@ -45,62 +45,62 @@ TEST(ParsingHelpersTest, ParseAndValidateDouble) { const Json::Value kNanDouble = std::nan(""); double out; - EXPECT_TRUE(ParseAndValidateDouble(kValid, &out)); + EXPECT_TRUE(TryParseDouble(kValid, &out)); EXPECT_DOUBLE_EQ(13.37, out); - EXPECT_TRUE(ParseAndValidateDouble(kZeroDouble, &out)); + EXPECT_TRUE(TryParseDouble(kZeroDouble, &out)); EXPECT_DOUBLE_EQ(0.0, out); - EXPECT_FALSE(ParseAndValidateDouble(kNotDouble, &out)); - EXPECT_FALSE(ParseAndValidateDouble(kNegativeDouble, &out)); - EXPECT_FALSE(ParseAndValidateDouble(kNone, &out)); - EXPECT_FALSE(ParseAndValidateDouble(kNanDouble, &out)); + EXPECT_FALSE(TryParseDouble(kNotDouble, &out)); + EXPECT_FALSE(TryParseDouble(kNegativeDouble, &out)); + EXPECT_FALSE(TryParseDouble(kNone, &out)); + EXPECT_FALSE(TryParseDouble(kNanDouble, &out)); } -TEST(ParsingHelpersTest, ParseAndValidateInt) { +TEST(ParsingHelpersTest, TryParseInt) { const Json::Value kValid = 1337; const Json::Value kNotInt = "cold brew"; const Json::Value kNegativeInt = -42; const Json::Value kZeroInt = 0; int out; - EXPECT_TRUE(ParseAndValidateInt(kValid, &out)); + EXPECT_TRUE(TryParseInt(kValid, &out)); EXPECT_EQ(1337, out); - EXPECT_TRUE(ParseAndValidateInt(kZeroInt, &out)); + EXPECT_TRUE(TryParseInt(kZeroInt, &out)); EXPECT_EQ(0, out); - EXPECT_FALSE(ParseAndValidateInt(kNone, &out)); - EXPECT_FALSE(ParseAndValidateInt(kNotInt, &out)); - EXPECT_FALSE(ParseAndValidateInt(kNegativeInt, &out)); + EXPECT_FALSE(TryParseInt(kNone, &out)); + EXPECT_FALSE(TryParseInt(kNotInt, &out)); + EXPECT_FALSE(TryParseInt(kNegativeInt, &out)); } -TEST(ParsingHelpersTest, ParseAndValidateUint) { +TEST(ParsingHelpersTest, TryParseUint) { const Json::Value kValid = 1337u; const Json::Value kNotUint = "espresso"; const Json::Value kZeroUint = 0u; uint32_t out; - EXPECT_TRUE(ParseAndValidateUint(kValid, &out)); + EXPECT_TRUE(TryParseUint(kValid, &out)); EXPECT_EQ(1337u, out); - EXPECT_TRUE(ParseAndValidateUint(kZeroUint, &out)); + EXPECT_TRUE(TryParseUint(kZeroUint, &out)); EXPECT_EQ(0u, out); - EXPECT_FALSE(ParseAndValidateUint(kNone, &out)); - EXPECT_FALSE(ParseAndValidateUint(kNotUint, &out)); + EXPECT_FALSE(TryParseUint(kNone, &out)); + EXPECT_FALSE(TryParseUint(kNotUint, &out)); } -TEST(ParsingHelpersTest, ParseAndValidateString) { +TEST(ParsingHelpersTest, TryParseString) { const Json::Value kValid = "macchiato"; const Json::Value kNotString = 42; std::string out; - EXPECT_TRUE(ParseAndValidateString(kValid, &out)); + EXPECT_TRUE(TryParseString(kValid, &out)); EXPECT_EQ("macchiato", out); - EXPECT_TRUE(ParseAndValidateString(kEmptyString, &out)); + EXPECT_TRUE(TryParseString(kEmptyString, &out)); EXPECT_EQ("", out); - EXPECT_FALSE(ParseAndValidateString(kNone, &out)); - EXPECT_FALSE(ParseAndValidateString(kNotString, &out)); + EXPECT_FALSE(TryParseString(kNone, &out)); + EXPECT_FALSE(TryParseString(kNotString, &out)); } // Simple fraction validity is tested extensively in its unit tests, so we // just check the major cases here. -TEST(ParsingHelpersTest, ParseAndValidateSimpleFraction) { +TEST(ParsingHelpersTest, TryParseSimpleFraction) { const Json::Value kValid = "42/30"; const Json::Value kValidNumber = "42"; const Json::Value kUndefined = "5/0"; @@ -111,22 +111,22 @@ TEST(ParsingHelpersTest, ParseAndValidateSimpleFraction) { const Json::Value kNegativeInteger = -5000; SimpleFraction out; - EXPECT_TRUE(ParseAndValidateSimpleFraction(kValid, &out)); + EXPECT_TRUE(TryParseSimpleFraction(kValid, &out)); EXPECT_EQ((SimpleFraction{42, 30}), out); - EXPECT_TRUE(ParseAndValidateSimpleFraction(kValidNumber, &out)); + EXPECT_TRUE(TryParseSimpleFraction(kValidNumber, &out)); EXPECT_EQ((SimpleFraction{42, 1}), out); - EXPECT_TRUE(ParseAndValidateSimpleFraction(kInteger, &out)); + EXPECT_TRUE(TryParseSimpleFraction(kInteger, &out)); EXPECT_EQ((SimpleFraction{123, 1}), out); - EXPECT_FALSE(ParseAndValidateSimpleFraction(kUndefined, &out)); - EXPECT_FALSE(ParseAndValidateSimpleFraction(kNegative, &out)); - EXPECT_FALSE(ParseAndValidateSimpleFraction(kInvalidNumber, &out)); - EXPECT_FALSE(ParseAndValidateSimpleFraction(kNotSimpleFraction, &out)); - EXPECT_FALSE(ParseAndValidateSimpleFraction(kNone, &out)); - EXPECT_FALSE(ParseAndValidateSimpleFraction(kEmptyString, &out)); - EXPECT_FALSE(ParseAndValidateSimpleFraction(kNegativeInteger, &out)); + EXPECT_FALSE(TryParseSimpleFraction(kUndefined, &out)); + EXPECT_FALSE(TryParseSimpleFraction(kNegative, &out)); + EXPECT_FALSE(TryParseSimpleFraction(kInvalidNumber, &out)); + EXPECT_FALSE(TryParseSimpleFraction(kNotSimpleFraction, &out)); + EXPECT_FALSE(TryParseSimpleFraction(kNone, &out)); + EXPECT_FALSE(TryParseSimpleFraction(kEmptyString, &out)); + EXPECT_FALSE(TryParseSimpleFraction(kNegativeInteger, &out)); } -TEST(ParsingHelpersTest, ParseAndValidateMilliseconds) { +TEST(ParsingHelpersTest, TryParseMilliseconds) { const Json::Value kValid = 1000; const Json::Value kValidFloat = 500.0; const Json::Value kNegativeNumber = -120; @@ -134,18 +134,18 @@ TEST(ParsingHelpersTest, ParseAndValidateMilliseconds) { const Json::Value kNotNumber = "affogato"; milliseconds out; - EXPECT_TRUE(ParseAndValidateMilliseconds(kValid, &out)); + EXPECT_TRUE(TryParseMilliseconds(kValid, &out)); EXPECT_EQ(milliseconds(1000), out); - EXPECT_TRUE(ParseAndValidateMilliseconds(kValidFloat, &out)); + EXPECT_TRUE(TryParseMilliseconds(kValidFloat, &out)); EXPECT_EQ(milliseconds(500), out); - EXPECT_TRUE(ParseAndValidateMilliseconds(kZeroNumber, &out)); + EXPECT_TRUE(TryParseMilliseconds(kZeroNumber, &out)); EXPECT_EQ(milliseconds(0), out); - EXPECT_FALSE(ParseAndValidateMilliseconds(kNone, &out)); - EXPECT_FALSE(ParseAndValidateMilliseconds(kNegativeNumber, &out)); - EXPECT_FALSE(ParseAndValidateMilliseconds(kNotNumber, &out)); + EXPECT_FALSE(TryParseMilliseconds(kNone, &out)); + EXPECT_FALSE(TryParseMilliseconds(kNegativeNumber, &out)); + EXPECT_FALSE(TryParseMilliseconds(kNotNumber, &out)); } -TEST(ParsingHelpersTest, ParseAndValidateArray) { +TEST(ParsingHelpersTest, TryParseArray) { Json::Value valid_dummy_array; valid_dummy_array[0] = 123; valid_dummy_array[1] = 456; @@ -155,16 +155,13 @@ TEST(ParsingHelpersTest, ParseAndValidateArray) { invalid_dummy_array[1] = 456; std::vector<Dummy> out; - EXPECT_TRUE(ParseAndValidateArray<Dummy>(valid_dummy_array, - ParseAndValidateDummy, &out)); + EXPECT_TRUE(TryParseArray<Dummy>(valid_dummy_array, TryParseDummy, &out)); EXPECT_THAT(out, ElementsAre(Dummy{123}, Dummy{456})); - EXPECT_FALSE(ParseAndValidateArray<Dummy>(invalid_dummy_array, - ParseAndValidateDummy, &out)); - EXPECT_FALSE( - ParseAndValidateArray<Dummy>(kEmptyArray, ParseAndValidateDummy, &out)); + EXPECT_FALSE(TryParseArray<Dummy>(invalid_dummy_array, TryParseDummy, &out)); + EXPECT_FALSE(TryParseArray<Dummy>(kEmptyArray, TryParseDummy, &out)); } -TEST(ParsingHelpersTest, ParseAndValidateIntArray) { +TEST(ParsingHelpersTest, TryParseIntArray) { Json::Value valid_int_array; valid_int_array[0] = 123; valid_int_array[1] = 456; @@ -174,13 +171,13 @@ TEST(ParsingHelpersTest, ParseAndValidateIntArray) { invalid_int_array[1] = 456; std::vector<int> out; - EXPECT_TRUE(ParseAndValidateIntArray(valid_int_array, &out)); + EXPECT_TRUE(TryParseIntArray(valid_int_array, &out)); EXPECT_THAT(out, ElementsAre(123, 456)); - EXPECT_FALSE(ParseAndValidateIntArray(invalid_int_array, &out)); - EXPECT_FALSE(ParseAndValidateIntArray(kEmptyArray, &out)); + EXPECT_FALSE(TryParseIntArray(invalid_int_array, &out)); + EXPECT_FALSE(TryParseIntArray(kEmptyArray, &out)); } -TEST(ParsingHelpersTest, ParseAndValidateUintArray) { +TEST(ParsingHelpersTest, TryParseUintArray) { Json::Value valid_uint_array; valid_uint_array[0] = 123u; valid_uint_array[1] = 456u; @@ -190,13 +187,13 @@ TEST(ParsingHelpersTest, ParseAndValidateUintArray) { invalid_uint_array[1] = 456u; std::vector<uint32_t> out; - EXPECT_TRUE(ParseAndValidateUintArray(valid_uint_array, &out)); + EXPECT_TRUE(TryParseUintArray(valid_uint_array, &out)); EXPECT_THAT(out, ElementsAre(123u, 456u)); - EXPECT_FALSE(ParseAndValidateUintArray(invalid_uint_array, &out)); - EXPECT_FALSE(ParseAndValidateUintArray(kEmptyArray, &out)); + EXPECT_FALSE(TryParseUintArray(invalid_uint_array, &out)); + EXPECT_FALSE(TryParseUintArray(kEmptyArray, &out)); } -TEST(ParsingHelpersTest, ParseAndValidateStringArray) { +TEST(ParsingHelpersTest, TryParseStringArray) { Json::Value valid_string_array; valid_string_array[0] = "nitro cold brew"; valid_string_array[1] = "doppio espresso"; @@ -206,10 +203,10 @@ TEST(ParsingHelpersTest, ParseAndValidateStringArray) { invalid_string_array[1] = 456; std::vector<std::string> out; - EXPECT_TRUE(ParseAndValidateStringArray(valid_string_array, &out)); + EXPECT_TRUE(TryParseStringArray(valid_string_array, &out)); EXPECT_THAT(out, ElementsAre("nitro cold brew", "doppio espresso")); - EXPECT_FALSE(ParseAndValidateStringArray(invalid_string_array, &out)); - EXPECT_FALSE(ParseAndValidateStringArray(kEmptyArray, &out)); + EXPECT_FALSE(TryParseStringArray(invalid_string_array, &out)); + EXPECT_FALSE(TryParseStringArray(kEmptyArray, &out)); } } // namespace json diff --git a/util/simple_fraction.cc b/util/simple_fraction.cc index a98d825c..46d2e585 100644 --- a/util/simple_fraction.cc +++ b/util/simple_fraction.cc @@ -33,37 +33,14 @@ ErrorOr<SimpleFraction> SimpleFraction::FromString(absl::string_view value) { } } - return SimpleFraction{numerator, denominator}; + return SimpleFraction(numerator, denominator); } std::string SimpleFraction::ToString() const { - if (denominator == 1) { - return std::to_string(numerator); + if (denominator_ == 1) { + return std::to_string(numerator_); } - return absl::StrCat(numerator, "/", denominator); -} - -bool SimpleFraction::operator==(const SimpleFraction& other) const { - return numerator == other.numerator && denominator == other.denominator; -} - -bool SimpleFraction::operator!=(const SimpleFraction& other) const { - return !(*this == other); -} - -bool SimpleFraction::is_defined() const { - return denominator != 0; -} - -bool SimpleFraction::is_positive() const { - return is_defined() && (numerator >= 0) && (denominator > 0); -} - -SimpleFraction::operator double() const { - if (denominator == 0) { - return nan(""); - } - return static_cast<double>(numerator) / static_cast<double>(denominator); + return absl::StrCat(numerator_, "/", denominator_); } } // namespace openscreen diff --git a/util/simple_fraction.h b/util/simple_fraction.h index f8ab5083..2df45e24 100644 --- a/util/simple_fraction.h +++ b/util/simple_fraction.h @@ -5,6 +5,8 @@ #ifndef UTIL_SIMPLE_FRACTION_H_ #define UTIL_SIMPLE_FRACTION_H_ +#include <cmath> +#include <limits> #include <string> #include "absl/strings/string_view.h" @@ -14,30 +16,56 @@ namespace openscreen { // SimpleFraction is used to represent simple (or "common") fractions, composed // of a rational number written a/b where a and b are both integers. - -// Note: Since SimpleFraction is a trivial type, it comes with a -// default constructor and is copyable, as well as allowing static -// initialization. - // Some helpful notes on SimpleFraction assumptions/limitations: // 1. SimpleFraction does not perform reductions. 2/4 != 1/2, and -1/-1 != 1/1. // 2. denominator = 0 is considered undefined. // 3. numerator = saturates range to int min or int max // 4. A SimpleFraction is "positive" if and only if it is defined and at least // equal to zero. Since reductions are not performed, -1/-1 is negative. -struct SimpleFraction { +class SimpleFraction { + public: static ErrorOr<SimpleFraction> FromString(absl::string_view value); std::string ToString() const; - bool operator==(const SimpleFraction& other) const; - bool operator!=(const SimpleFraction& other) const; + constexpr SimpleFraction() = default; + constexpr SimpleFraction(int numerator) // NOLINT + : numerator_(numerator) {} + constexpr SimpleFraction(int numerator, int denominator) + : numerator_(numerator), denominator_(denominator) {} + + constexpr SimpleFraction(const SimpleFraction&) = default; + constexpr SimpleFraction(SimpleFraction&&) noexcept = default; + constexpr SimpleFraction& operator=(const SimpleFraction&) = default; + constexpr SimpleFraction& operator=(SimpleFraction&&) = default; + ~SimpleFraction() = default; + + constexpr bool operator==(const SimpleFraction& other) const { + return numerator_ == other.numerator_ && denominator_ == other.denominator_; + } + + constexpr bool operator!=(const SimpleFraction& other) const { + return !(*this == other); + } + + constexpr bool is_defined() const { return denominator_ != 0; } + + constexpr bool is_positive() const { + return (numerator_ >= 0) && (denominator_ > 0); + } + + constexpr explicit operator double() const { + if (denominator_ == 0) { + return nan(""); + } + return static_cast<double>(numerator_) / static_cast<double>(denominator_); + } - bool is_defined() const; - bool is_positive() const; - explicit operator double() const; + constexpr int numerator() const { return numerator_; } + constexpr int denominator() const { return denominator_; } - int numerator = 0; - int denominator = 0; + private: + int numerator_ = 0; + int denominator_ = 1; }; } // namespace openscreen diff --git a/util/stringprintf.cc b/util/stringprintf.cc index 2d9bba22..49c29dc8 100644 --- a/util/stringprintf.cc +++ b/util/stringprintf.cc @@ -32,11 +32,11 @@ std::string StringPrintf(const char* format, ...) { return result; } -std::string HexEncode(absl::Span<const uint8_t> bytes) { +std::string HexEncode(const uint8_t* bytes, std::size_t len) { std::ostringstream hex_dump; hex_dump << std::setfill('0') << std::hex; - for (const uint8_t byte : bytes) { - hex_dump << std::setw(2) << static_cast<int>(byte); + for (std::size_t i = 0; i < len; i++) { + hex_dump << std::setw(2) << static_cast<int>(bytes[i]); } return hex_dump.str(); } diff --git a/util/stringprintf.h b/util/stringprintf.h index 0de394ea..23f07fea 100644 --- a/util/stringprintf.h +++ b/util/stringprintf.h @@ -10,12 +10,6 @@ #include <ostream> #include <string> -// TODO: This header is included in the openscreen discovery public headers (dns_sd_instance.h), -// which exposes this abseil header. Need to figure out a way to hide it. -#if 0 -#include "absl/types/span.h" -#endif - namespace openscreen { // Enable compile-time checking of the printf format argument, if available. @@ -58,10 +52,8 @@ void PrettyPrintAsciiHex(std::ostream& os, It first, It last) { } } -#if 0 // Returns a hex string representation of the given |bytes|. -std::string HexEncode(absl::Span<const uint8_t> bytes); -#endif +std::string HexEncode(const uint8_t* bytes, std::size_t len); } // namespace openscreen diff --git a/util/stringprintf_unittest.cc b/util/stringprintf_unittest.cc index e37e7cb6..bf882163 100644 --- a/util/stringprintf_unittest.cc +++ b/util/stringprintf_unittest.cc @@ -20,13 +20,13 @@ TEST(StringPrintf, ProducesFormattedStrings) { TEST(HexEncode, ProducesEmptyStringFromEmptyByteArray) { const uint8_t kSomeMemoryLocation = 0; - EXPECT_EQ("", HexEncode(absl::Span<const uint8_t>(&kSomeMemoryLocation, 0))); + EXPECT_EQ("", HexEncode(&kSomeMemoryLocation, 0)); } TEST(HexEncode, ProducesHexStringsFromBytes) { const uint8_t kMessage[] = "Hello world!"; const char kMessageInHex[] = "48656c6c6f20776f726c642100"; - EXPECT_EQ(kMessageInHex, HexEncode(kMessage)); + EXPECT_EQ(kMessageInHex, HexEncode(kMessage, sizeof(kMessage))); } } // namespace |