diff options
-rw-r--r-- | crosperf/suite_runner.py | 43 | ||||
-rwxr-xr-x | crosperf/suite_runner_unittest.py | 13 |
2 files changed, 53 insertions, 3 deletions
diff --git a/crosperf/suite_runner.py b/crosperf/suite_runner.py index dfd14b21..0e4ba045 100644 --- a/crosperf/suite_runner.py +++ b/crosperf/suite_runner.py @@ -6,15 +6,29 @@ """SuiteRunner defines the interface from crosperf to test script.""" +import contextlib import json import os +from pathlib import Path import pipes +import random import shlex +import subprocess import time from cros_utils import command_executer +SSHWATCHER = [ + "go", + "run", + str( + Path( + __file__, + "../../../../platform/dev/contrib/sshwatcher/sshwatcher.go", + ).resolve() + ), +] TEST_THAT_PATH = "/usr/bin/test_that" TAST_PATH = "/usr/bin/tast" CROSFLEET_PATH = "crosfleet" @@ -53,6 +67,32 @@ def GetDutConfigArgs(dut_config): return f"dut_config={pipes.quote(json.dumps(dut_config))}" +@contextlib.contextmanager +def ssh_tunnel(machinename): + """Context manager that forwards a TCP port over SSH while active. + + This class is used to set up port forwarding before entering the + chroot, so that the forwarded port can be used from inside + the chroot. + + The value yielded by ssh_tunnel is a host:port string. + """ + # We have to tell sshwatcher which port we want to use. + # We pick a port that is likely to be available. + port = random.randrange(4096, 32768) + cmd = SSHWATCHER + [machinename, str(port)] + # Pylint wants us to use subprocess.Popen as a context manager, + # but we don't, so that we can ask sshwatcher to terminate and + # limit the time we wait for it to do so. + # pylint: disable=consider-using-with + proc = subprocess.Popen(cmd) + try: + yield f"localhost:{port}" + finally: + proc.terminate() + proc.wait(timeout=5) + + class SuiteRunner(object): """This defines the interface from crosperf to test script.""" @@ -83,7 +123,8 @@ class SuiteRunner(object): ) else: if benchmark.suite == "tast": - ret_tup = self.Tast_Run(machine_name, label, benchmark) + with ssh_tunnel(machine_name) as hostport: + ret_tup = self.Tast_Run(hostport, label, benchmark) else: ret_tup = self.Test_That_Run( machine_name, label, benchmark, test_args, profiler_args diff --git a/crosperf/suite_runner_unittest.py b/crosperf/suite_runner_unittest.py index 69476f37..cc96ee4a 100755 --- a/crosperf/suite_runner_unittest.py +++ b/crosperf/suite_runner_unittest.py @@ -8,6 +8,7 @@ """Unittest for suite_runner.""" +import contextlib import json import unittest import unittest.mock as mock @@ -118,7 +119,14 @@ class SuiteRunnerTest(unittest.TestCase): res = suite_runner.GetDutConfigArgs(dut_config) self.assertEqual(res, output_str) - def test_run(self): + @mock.patch("suite_runner.ssh_tunnel") + def test_run(self, ssh_tunnel): + @contextlib.contextmanager + def mock_ssh_tunnel(_host): + yield "fakelocalhost:1234" + + ssh_tunnel.side_effect = mock_ssh_tunnel + def reset(): self.test_that_args = [] self.crosfleet_run_args = [] @@ -254,7 +262,8 @@ class SuiteRunnerTest(unittest.TestCase): self.assertFalse(self.call_test_that_run) self.assertFalse(self.call_crosfleet_run) self.assertEqual( - self.tast_args, ["fake_machine", self.mock_label, self.tast_bench] + self.tast_args, + ["fakelocalhost:1234", self.mock_label, self.tast_bench], ) def test_gen_test_args(self): |