diff options
author | Yu Shan <shanyu@google.com> | 2018-01-04 06:25:27 +0000 |
---|---|---|
committer | android-build-merger <android-build-merger@google.com> | 2018-01-04 06:25:27 +0000 |
commit | c8280e30c836c7b44f01368f5165ad81c5fa5c9e (patch) | |
tree | 3a40fb070bf6a7a5892e48ae814e2199091a9b8d | |
parent | 88a5a46939994161c9a3fe623e261e60f23d096d (diff) | |
parent | cbfd415593d4fb9161dd1b932dbf912064c482a5 (diff) | |
download | attestation-c8280e30c836c7b44f01368f5165ad81c5fa5c9e.tar.gz |
[AT-Factory-Tool] Fix USB Location bug.
am: cbfd415593
Change-Id: I535d1a2a6d54087aa668f333d1df4be6473cee37
-rw-r--r-- | at-factory-tool/atftman.py | 28 | ||||
-rw-r--r-- | at-factory-tool/atftman_unittest.py | 35 | ||||
-rw-r--r-- | at-factory-tool/serialmapperlinux.py | 27 | ||||
-rw-r--r-- | at-factory-tool/serialmapperwin.py | 25 |
4 files changed, 77 insertions, 38 deletions
diff --git a/at-factory-tool/atftman.py b/at-factory-tool/atftman.py index 1a1fadc..e952763 100644 --- a/at-factory-tool/atftman.py +++ b/at-factory-tool/atftman.py @@ -399,11 +399,9 @@ class AtftManager(object): atfa_serial = self._atfa_dev_setting.serial_number if atfa_serial in self.stable_serials: # We found the ATFA device again. - serial_location_map = self._serial_mapper.get_serial_map() + self._serial_mapper.refresh_serial_map() controller = self._fastboot_device_controller(atfa_serial) - location = None - if atfa_serial in serial_location_map: - location = serial_location_map[atfa_serial] + location = self._serial_mapper.get_location(atfa_serial) self.atfa_dev = DeviceInfo(controller, atfa_serial, location) # Clean the state @@ -443,25 +441,21 @@ class AtftManager(object): common_serials = [device.serial_number for device in self.target_devs] # Create new device object for newly added devices. - serial_location_map = self._serial_mapper.get_serial_map() + self._serial_mapper.refresh_serial_map() for serial in new_targets: if serial not in common_serials: - self._CreateNewTargetDevice(serial, serial_location_map) + self._CreateNewTargetDevice(serial) - def _CreateNewTargetDevice( - self, serial, serial_location_map, check_status=True): + def _CreateNewTargetDevice(self, serial, check_status=True): """Create a new target device object. Args: serial: The serial number for the new target device. - serial_location_map: The serial location map. check_status: Whether to check provision status for the target device. """ try: controller = self._fastboot_device_controller(serial) - location = None - if serial in serial_location_map: - location = serial_location_map[serial] + location = self._serial_mapper.get_location(serial) new_target_dev = DeviceInfo(controller, serial, location) if check_status: @@ -482,11 +476,9 @@ class AtftManager(object): Args: atfa_serial: The serial number of the ATFA device to be added. """ + self._serial_mapper.refresh_serial_map() controller = self._fastboot_device_controller(atfa_serial) - serial_location_map = self._serial_mapper.get_serial_map() - location = None - if atfa_serial in serial_location_map: - location = serial_location_map[atfa_serial] + location = self._serial_mapper.get_location(atfa_serial) if self._atfa_reboot_lock.acquire(False): # If there's not an atfa setting os already happening self._atfa_dev_setting = DeviceInfo(controller, atfa_serial, location) @@ -845,8 +837,8 @@ class AtftManager(object): self.target_devs.remove(rebooting_dev) del rebooting_dev if success: - serial_location_map = self._serial_mapper.get_serial_map() - self._CreateNewTargetDevice(serial, serial_location_map, True) + self._serial_mapper.refresh_serial_map() + self._CreateNewTargetDevice(serial, True) self.GetTargetDevice(serial).provision_status = ( ProvisionStatus.REBOOT_SUCCESS) callback() diff --git a/at-factory-tool/atftman_unittest.py b/at-factory-tool/atftman_unittest.py index a9744bf..6ea5f1b 100644 --- a/at-factory-tool/atftman_unittest.py +++ b/at-factory-tool/atftman_unittest.py @@ -411,16 +411,30 @@ class AtftManTest(unittest.TestCase): # Nothing appears twice. self.assertEqual(0, len(atft_manager.target_devs)) + def mockSetSerialMapper(self, serial_map): + self.serial_map = {} + for serial in serial_map: + self.serial_map[serial.lower()] = serial_map[serial] + + def mockGetLocation(self, serial): + serial_lower = serial.lower() + if serial_lower in self.serial_map: + return self.serial_map[serial_lower] + return None + @patch('threading.Timer') def testListDevicesLocation(self, mock_create_timer): mock_create_timer.side_effect = self.MockCreateInstantTimer mock_serial_mapper = MagicMock() - mock_serial_instance = MagicMock() - mock_serial_mapper.return_value = mock_serial_instance - mock_serial_instance.get_serial_map.return_value = { + smap = { self.ATFA_TEST_SERIAL: self.TEST_LOCATION, self.TEST_SERIAL: self.TEST_LOCATION2 } + mock_serial_instance = MagicMock() + mock_serial_mapper.return_value = mock_serial_instance + mock_serial_instance.refresh_serial_map.side_effect = ( + lambda serial_map=smap: self.mockSetSerialMapper(serial_map)) + mock_serial_instance.get_location.side_effect = self.mockGetLocation mock_fastboot = MagicMock() mock_fastboot.side_effect = self.MockInit atft_manager = atftman.AtftManager( @@ -440,11 +454,14 @@ class AtftManTest(unittest.TestCase): mock_serial_mapper = MagicMock() mock_serial_instance = MagicMock() mock_serial_mapper.return_value = mock_serial_instance - mock_serial_instance.get_serial_map.return_value = { + smap = { self.TEST_SERIAL: self.TEST_LOCATION, self.TEST_SERIAL2: self.TEST_LOCATION2, self.TEST_SERIAL3: self.TEST_LOCATION3 } + mock_serial_instance.refresh_serial_map.side_effect = ( + lambda serial_map=smap: self.mockSetSerialMapper(serial_map)) + mock_serial_instance.get_location.side_effect = self.mockGetLocation mock_fastboot = MagicMock() mock_fastboot.side_effect = self.MockInit atft_manager = atftman.AtftManager( @@ -463,11 +480,14 @@ class AtftManTest(unittest.TestCase): mock_serial_mapper = MagicMock() mock_serial_instance = MagicMock() mock_serial_mapper.return_value = mock_serial_instance - mock_serial_instance.get_serial_map.return_value = { + smap = { self.TEST_SERIAL: self.TEST_LOCATION, self.TEST_SERIAL2: self.TEST_LOCATION2, self.TEST_SERIAL3: self.TEST_LOCATION3 } + mock_serial_instance.refresh_serial_map.side_effect = ( + lambda serial_map=smap: self.mockSetSerialMapper(serial_map)) + mock_serial_instance.get_location.side_effect = self.mockGetLocation mock_fastboot = MagicMock() mock_fastboot.side_effect = self.MockInit atft_manager = atftman.AtftManager( @@ -486,11 +506,14 @@ class AtftManTest(unittest.TestCase): mock_serial_mapper = MagicMock() mock_serial_instance = MagicMock() mock_serial_mapper.return_value = mock_serial_instance - mock_serial_instance.get_serial_map.return_value = { + smap = { self.TEST_SERIAL: self.TEST_LOCATION, self.TEST_SERIAL2: self.TEST_LOCATION2, self.TEST_SERIAL3: self.TEST_LOCATION3 } + mock_serial_instance.refresh_serial_map.side_effect = ( + lambda serial_map=smap: self.mockSetSerialMapper(serial_map)) + mock_serial_instance.get_location.side_effect = self.mockGetLocation mock_fastboot = MagicMock() mock_fastboot_instance = MagicMock() mock_fastboot.side_effect = self.MockInit diff --git a/at-factory-tool/serialmapperlinux.py b/at-factory-tool/serialmapperlinux.py index 7407338..e32b73c 100644 --- a/at-factory-tool/serialmapperlinux.py +++ b/at-factory-tool/serialmapperlinux.py @@ -25,16 +25,16 @@ class SerialMapper(object): USB_DEVICES_PATH = '/sys/bus/usb/devices/' - def get_serial_map(self): - """Get the serial_number -> USB location map. + def __init__(self): + self.serial_map = {} - Returns: - A Dictionary of {serial_number: USB location} + def refresh_serial_map(self): + """Refresh the serial_number -> USB location map. """ serial_to_location_map = {} # check if sysfs is mounted. if not os.path.exists(self.USB_DEVICES_PATH): - return serial_to_location_map + return for device_folder_name in os.listdir(self.USB_DEVICES_PATH): device_folder = os.path.join(self.USB_DEVICES_PATH, device_folder_name) @@ -47,7 +47,20 @@ class SerialMapper(object): serial_path = os.path.join(device_folder, 'serial') if os.path.isfile(serial_path): with open(serial_path) as f: - serial = f.readline().rstrip('\n') + serial = f.readline().rstrip('\n').lower() serial_to_location_map[serial] = device_folder_name - return serial_to_location_map + self.serial_map = serial_to_location_map + + def get_location(self, serial): + """Get the USB location according to the serial number. + + Args: + serial: The serial number for the device. + Returns: + The USB physical location for the device. + """ + serial_lower = serial.lower() + if serial_lower in self.serial_map: + return self.serial_map[serial_lower] + return None diff --git a/at-factory-tool/serialmapperwin.py b/at-factory-tool/serialmapperwin.py index d902974..ff5ae52 100644 --- a/at-factory-tool/serialmapperwin.py +++ b/at-factory-tool/serialmapperwin.py @@ -63,12 +63,10 @@ class SerialMapper(object): def __init__(self): self.setupapi = ctypes.WinDLL('setupapi') + self.serial_map = {} - def get_serial_map(self): - """Get the serial_number -> USB location map. - - Returns: - A Dictionary of {serial_number: USB location} + def refresh_serial_map(self): + """Refresh the serial_number -> USB location map. """ serial_map = {} device_inf_set = None @@ -131,12 +129,25 @@ class SerialMapper(object): instance_id = device_instance_id_buffer.value instance_parts = instance_id.split('\\') if instance_parts: - serial = instance_parts.pop() + serial = instance_parts.pop().lower() serial_map[serial] = location i += 1 # Destroy the device information set if device_inf_set is not None: SetupDiDestroyDeviceInfoList(device_inf_set) - return serial_map + self.serial_map = serial_map + + def get_location(self, serial): + """Get the USB location according to the serial number. + + Args: + serial: The serial number for the device. + Returns: + The USB physical location for the device. + """ + serial_lower = serial.lower() + if serial_lower in self.serial_map: + return self.serial_map[serial_lower] + return None |