aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--crosperf/suite_runner.py43
-rwxr-xr-xcrosperf/suite_runner_unittest.py13
2 files changed, 53 insertions, 3 deletions
diff --git a/crosperf/suite_runner.py b/crosperf/suite_runner.py
index dfd14b21..0e4ba045 100644
--- a/crosperf/suite_runner.py
+++ b/crosperf/suite_runner.py
@@ -6,15 +6,29 @@
"""SuiteRunner defines the interface from crosperf to test script."""
+import contextlib
import json
import os
+from pathlib import Path
import pipes
+import random
import shlex
+import subprocess
import time
from cros_utils import command_executer
+SSHWATCHER = [
+ "go",
+ "run",
+ str(
+ Path(
+ __file__,
+ "../../../../platform/dev/contrib/sshwatcher/sshwatcher.go",
+ ).resolve()
+ ),
+]
TEST_THAT_PATH = "/usr/bin/test_that"
TAST_PATH = "/usr/bin/tast"
CROSFLEET_PATH = "crosfleet"
@@ -53,6 +67,32 @@ def GetDutConfigArgs(dut_config):
return f"dut_config={pipes.quote(json.dumps(dut_config))}"
+@contextlib.contextmanager
+def ssh_tunnel(machinename):
+ """Context manager that forwards a TCP port over SSH while active.
+
+ This class is used to set up port forwarding before entering the
+ chroot, so that the forwarded port can be used from inside
+ the chroot.
+
+ The value yielded by ssh_tunnel is a host:port string.
+ """
+ # We have to tell sshwatcher which port we want to use.
+ # We pick a port that is likely to be available.
+ port = random.randrange(4096, 32768)
+ cmd = SSHWATCHER + [machinename, str(port)]
+ # Pylint wants us to use subprocess.Popen as a context manager,
+ # but we don't, so that we can ask sshwatcher to terminate and
+ # limit the time we wait for it to do so.
+ # pylint: disable=consider-using-with
+ proc = subprocess.Popen(cmd)
+ try:
+ yield f"localhost:{port}"
+ finally:
+ proc.terminate()
+ proc.wait(timeout=5)
+
+
class SuiteRunner(object):
"""This defines the interface from crosperf to test script."""
@@ -83,7 +123,8 @@ class SuiteRunner(object):
)
else:
if benchmark.suite == "tast":
- ret_tup = self.Tast_Run(machine_name, label, benchmark)
+ with ssh_tunnel(machine_name) as hostport:
+ ret_tup = self.Tast_Run(hostport, label, benchmark)
else:
ret_tup = self.Test_That_Run(
machine_name, label, benchmark, test_args, profiler_args
diff --git a/crosperf/suite_runner_unittest.py b/crosperf/suite_runner_unittest.py
index 69476f37..cc96ee4a 100755
--- a/crosperf/suite_runner_unittest.py
+++ b/crosperf/suite_runner_unittest.py
@@ -8,6 +8,7 @@
"""Unittest for suite_runner."""
+import contextlib
import json
import unittest
import unittest.mock as mock
@@ -118,7 +119,14 @@ class SuiteRunnerTest(unittest.TestCase):
res = suite_runner.GetDutConfigArgs(dut_config)
self.assertEqual(res, output_str)
- def test_run(self):
+ @mock.patch("suite_runner.ssh_tunnel")
+ def test_run(self, ssh_tunnel):
+ @contextlib.contextmanager
+ def mock_ssh_tunnel(_host):
+ yield "fakelocalhost:1234"
+
+ ssh_tunnel.side_effect = mock_ssh_tunnel
+
def reset():
self.test_that_args = []
self.crosfleet_run_args = []
@@ -254,7 +262,8 @@ class SuiteRunnerTest(unittest.TestCase):
self.assertFalse(self.call_test_that_run)
self.assertFalse(self.call_crosfleet_run)
self.assertEqual(
- self.tast_args, ["fake_machine", self.mock_label, self.tast_bench]
+ self.tast_args,
+ ["fakelocalhost:1234", self.mock_label, self.tast_bench],
)
def test_gen_test_args(self):