aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDoug Greiman <dgreiman@google.com>2016-01-19 16:53:58 -0800
committerDoug Greiman <dgreiman@google.com>2016-01-19 17:14:57 -0800
commit898bee48b5f8f6f65075d45164195133b4a23988 (patch)
tree31977f162f3f30838985aa157fbffd108f705eec
parent4cc0f9fa73882dac5d0a8908c29e603a6ae91de3 (diff)
downloadportpicker-898bee48b5f8f6f65075d45164195133b4a23988.tar.gz
Fix bug in get_port_for_process
Added unit tests for corner cases where process id == port number.
-rw-r--r--ChangeLog.md1
-rw-r--r--src/portserver.py2
-rw-r--r--src/tests/portserver_test.py22
3 files changed, 24 insertions, 1 deletions
diff --git a/ChangeLog.md b/ChangeLog.md
index d74ea8c..63da2a5 100644
--- a/ChangeLog.md
+++ b/ChangeLog.md
@@ -1,6 +1,7 @@
## 1.1.1
* Changed default port range to 15000-24999 to avoid ephemeral ports.
+* Portserver bugfix.
## 1.1.0
diff --git a/src/portserver.py b/src/portserver.py
index 03beddc..43f5567 100644
--- a/src/portserver.py
+++ b/src/portserver.py
@@ -180,7 +180,7 @@ class _PortPool(object):
check_count += 1
if (candidate.start_time == 0 or
candidate.start_time != _get_process_start_time(candidate.pid)):
- if _is_port_free(candidate.pid):
+ if _is_port_free(candidate.port):
candidate.pid = pid
candidate.start_time = _get_process_start_time(pid)
if not candidate.start_time:
diff --git a/src/tests/portserver_test.py b/src/tests/portserver_test.py
index f0475c3..bd7d61e 100644
--- a/src/tests/portserver_test.py
+++ b/src/tests/portserver_test.py
@@ -192,6 +192,28 @@ class PortPoolTest(unittest.TestCase):
self.assertEqual(2, self.pool.num_ports())
self.assertEqual(2, self.pool.ports_checked_for_last_request)
+ @mock.patch.object(portserver, '_is_port_free')
+ @mock.patch.object(os, 'getpid')
+ def test_get_port_for_process_pid_eq_port(self, mock_getpid, mock_is_port_free):
+ self.pool.add_port_to_free_pool(12345)
+ self.pool.add_port_to_free_pool(12344)
+ mock_is_port_free.side_effect = lambda port: port == os.getpid()
+ mock_getpid.return_value = 12345
+ self.assertEqual(2, self.pool.num_ports())
+ self.assertEqual(12345, self.pool.get_port_for_process(os.getpid()))
+ self.assertEqual(2, self.pool.ports_checked_for_last_request)
+
+ @mock.patch.object(portserver, '_is_port_free')
+ @mock.patch.object(os, 'getpid')
+ def test_get_port_for_process_pid_ne_port(self, mock_getpid, mock_is_port_free):
+ self.pool.add_port_to_free_pool(12344)
+ self.pool.add_port_to_free_pool(12345)
+ mock_is_port_free.side_effect = lambda port: port != os.getpid()
+ mock_getpid.return_value = 12345
+ self.assertEqual(2, self.pool.num_ports())
+ self.assertEqual(12344, self.pool.get_port_for_process(os.getpid()))
+ self.assertEqual(2, self.pool.ports_checked_for_last_request)
+
@mock.patch.object(portserver, '_get_process_command_line')
@mock.patch.object(portserver, '_should_allocate_port')