diff options
Diffstat (limited to 'binary_search_tool/binary_search_state.py')
-rwxr-xr-x | binary_search_tool/binary_search_state.py | 598 |
1 files changed, 598 insertions, 0 deletions
diff --git a/binary_search_tool/binary_search_state.py b/binary_search_tool/binary_search_state.py new file mode 100755 index 00000000..a10e90b9 --- /dev/null +++ b/binary_search_tool/binary_search_state.py @@ -0,0 +1,598 @@ +#!/usr/bin/python2 +"""The binary search wrapper.""" + +from __future__ import print_function + +import argparse +import contextlib +import errno +import math +import os +import pickle +import sys +import tempfile +import time + +# Adds cros_utils to PYTHONPATH +import common + +# Now we do import from cros_utils +from cros_utils import command_executer +from cros_utils import logger + +import binary_search_perforce + +GOOD_SET_VAR = 'BISECT_GOOD_SET' +BAD_SET_VAR = 'BISECT_BAD_SET' + +STATE_FILE = '%s.state' % sys.argv[0] +HIDDEN_STATE_FILE = os.path.join( + os.path.dirname(STATE_FILE), '.%s' % os.path.basename(STATE_FILE)) + + +class Error(Exception): + """The general binary search tool error class.""" + pass + + +@contextlib.contextmanager +def SetFile(env_var, items): + """Generate set files that can be used by switch/test scripts. + + Generate temporary set file (good/bad) holding contents of good/bad items for + the current binary search iteration. Store the name of each file as an + environment variable so all child processes can access it. + + This function is a contextmanager, meaning it's meant to be used with the + "with" statement in Python. This is so cleanup and setup happens automatically + and cleanly. Execution of the outer "with" statement happens at the "yield" + statement. + + Args: + env_var: What environment variable to store the file name in. + items: What items are in this set. + """ + with tempfile.NamedTemporaryFile() as f: + os.environ[env_var] = f.name + f.write('\n'.join(items)) + f.flush() + yield + + +class BinarySearchState(object): + """The binary search state class.""" + + def __init__(self, get_initial_items, switch_to_good, switch_to_bad, + test_setup_script, test_script, incremental, prune, iterations, + prune_iterations, verify, file_args, verbose): + """BinarySearchState constructor, see Run for full args documentation.""" + self.get_initial_items = get_initial_items + self.switch_to_good = switch_to_good + self.switch_to_bad = switch_to_bad + self.test_setup_script = test_setup_script + self.test_script = test_script + self.incremental = incremental + self.prune = prune + self.iterations = iterations + self.prune_iterations = prune_iterations + self.verify = verify + self.file_args = file_args + self.verbose = verbose + + self.l = logger.GetLogger() + self.ce = command_executer.GetCommandExecuter() + + self.resumed = False + self.prune_cycles = 0 + self.search_cycles = 0 + self.binary_search = None + self.all_items = None + self.PopulateItemsUsingCommand(self.get_initial_items) + self.currently_good_items = set([]) + self.currently_bad_items = set([]) + self.found_items = set([]) + self.known_good = set([]) + + self.start_time = time.time() + + def SwitchToGood(self, item_list): + """Switch given items to "good" set.""" + if self.incremental: + self.l.LogOutput( + 'Incremental set. Wanted to switch %s to good' % str(item_list), + print_to_console=self.verbose) + incremental_items = [ + item for item in item_list if item not in self.currently_good_items + ] + item_list = incremental_items + self.l.LogOutput( + 'Incremental set. Actually switching %s to good' % str(item_list), + print_to_console=self.verbose) + + if not item_list: + return + + self.l.LogOutput( + 'Switching %s to good' % str(item_list), print_to_console=self.verbose) + self.RunSwitchScript(self.switch_to_good, item_list) + self.currently_good_items = self.currently_good_items.union(set(item_list)) + self.currently_bad_items.difference_update(set(item_list)) + + def SwitchToBad(self, item_list): + """Switch given items to "bad" set.""" + if self.incremental: + self.l.LogOutput( + 'Incremental set. Wanted to switch %s to bad' % str(item_list), + print_to_console=self.verbose) + incremental_items = [ + item for item in item_list if item not in self.currently_bad_items + ] + item_list = incremental_items + self.l.LogOutput( + 'Incremental set. Actually switching %s to bad' % str(item_list), + print_to_console=self.verbose) + + if not item_list: + return + + self.l.LogOutput( + 'Switching %s to bad' % str(item_list), print_to_console=self.verbose) + self.RunSwitchScript(self.switch_to_bad, item_list) + self.currently_bad_items = self.currently_bad_items.union(set(item_list)) + self.currently_good_items.difference_update(set(item_list)) + + def RunSwitchScript(self, switch_script, item_list): + """Pass given items to switch script. + + Args: + switch_script: path to switch script + item_list: list of all items to be switched + """ + if self.file_args: + with tempfile.NamedTemporaryFile() as f: + f.write('\n'.join(item_list)) + f.flush() + command = '%s %s' % (switch_script, f.name) + ret, _, _ = self.ce.RunCommandWExceptionCleanup( + command, print_to_console=self.verbose) + else: + command = '%s %s' % (switch_script, ' '.join(item_list)) + try: + ret, _, _ = self.ce.RunCommandWExceptionCleanup( + command, print_to_console=self.verbose) + except OSError as e: + if e.errno == errno.E2BIG: + raise Error('Too many arguments for switch script! Use --file_args') + else: + raise + assert ret == 0, 'Switch script %s returned %d' % (switch_script, ret) + + def TestScript(self): + """Run test script and return exit code from script.""" + command = self.test_script + ret, _, _ = self.ce.RunCommandWExceptionCleanup(command) + return ret + + def TestSetupScript(self): + """Run test setup script and return exit code from script.""" + if not self.test_setup_script: + return 0 + + command = self.test_setup_script + ret, _, _ = self.ce.RunCommandWExceptionCleanup(command) + return ret + + def DoVerify(self): + """Verify correctness of test environment. + + Verify that a "good" set of items produces a "good" result and that a "bad" + set of items produces a "bad" result. To be run directly before running + DoSearch. If verify is False this step is skipped. + """ + if not self.verify: + return + + self.l.LogOutput('VERIFICATION') + self.l.LogOutput('Beginning tests to verify good/bad sets\n') + + self._OutputProgress('Verifying items from GOOD set\n') + with SetFile(GOOD_SET_VAR, self.all_items), SetFile(BAD_SET_VAR, []): + self.l.LogOutput('Resetting all items to good to verify.') + self.SwitchToGood(self.all_items) + status = self.TestSetupScript() + assert status == 0, 'When reset_to_good, test setup should succeed.' + status = self.TestScript() + assert status == 0, 'When reset_to_good, status should be 0.' + + self._OutputProgress('Verifying items from BAD set\n') + with SetFile(GOOD_SET_VAR, []), SetFile(BAD_SET_VAR, self.all_items): + self.l.LogOutput('Resetting all items to bad to verify.') + self.SwitchToBad(self.all_items) + status = self.TestSetupScript() + # The following assumption is not true; a bad image might not + # successfully push onto a device. + # assert status == 0, 'When reset_to_bad, test setup should succeed.' + if status == 0: + status = self.TestScript() + assert status == 1, 'When reset_to_bad, status should be 1.' + + def DoSearch(self): + """Perform full search for bad items. + + Perform full search until prune_iterations number of bad items are found. + """ + while (True and len(self.all_items) > 1 and + self.prune_cycles < self.prune_iterations): + terminated = self.DoBinarySearch() + self.prune_cycles += 1 + if not terminated: + break + # Prune is set. + prune_index = self.binary_search.current + + # If found item is last item, no new items can be found + if prune_index == len(self.all_items) - 1: + self.l.LogOutput('First bad item is the last item. Breaking.') + self.l.LogOutput('Bad items are: %s' % self.all_items[-1]) + break + + # If already seen item we have no new bad items to find, finish up + if self.all_items[prune_index] in self.found_items: + self.l.LogOutput( + 'Found item already found before: %s.' % + self.all_items[prune_index], + print_to_console=self.verbose) + self.l.LogOutput('No more bad items remaining. Done searching.') + self.l.LogOutput('Bad items are: %s' % ' '.join(self.found_items)) + break + + new_all_items = list(self.all_items) + # Move prune item to the end of the list. + new_all_items.append(new_all_items.pop(prune_index)) + self.found_items.add(new_all_items[-1]) + + # Everything below newly found bad item is now known to be a good item. + # Take these good items out of the equation to save time on the next + # search. We save these known good items so they are still sent to the + # switch_to_good script. + if prune_index: + self.known_good.update(new_all_items[:prune_index]) + new_all_items = new_all_items[prune_index:] + + self.l.LogOutput( + 'Old list: %s. New list: %s' % (str(self.all_items), + str(new_all_items)), + print_to_console=self.verbose) + + if not self.prune: + self.l.LogOutput('Not continuning further, --prune is not set') + break + # FIXME: Do we need to Convert the currently good items to bad + self.PopulateItemsUsingList(new_all_items) + + def DoBinarySearch(self): + """Perform single iteration of binary search.""" + # If in resume mode don't reset search_cycles + if not self.resumed: + self.search_cycles = 0 + else: + self.resumed = False + + terminated = False + while self.search_cycles < self.iterations and not terminated: + self.SaveState() + self.OutputIterationProgress() + + self.search_cycles += 1 + [bad_items, good_items] = self.GetNextItems() + + with SetFile(GOOD_SET_VAR, good_items), SetFile(BAD_SET_VAR, bad_items): + # TODO: bad_items should come first. + self.SwitchToGood(good_items) + self.SwitchToBad(bad_items) + status = self.TestSetupScript() + if status == 0: + status = self.TestScript() + terminated = self.binary_search.SetStatus(status) + + if terminated: + self.l.LogOutput('Terminated!', print_to_console=self.verbose) + if not terminated: + self.l.LogOutput('Ran out of iterations searching...') + self.l.LogOutput(str(self), print_to_console=self.verbose) + return terminated + + def PopulateItemsUsingCommand(self, command): + """Update all_items and binary search logic from executable. + + This method is mainly required for enumerating the initial list of items + from the get_initial_items script. + + Args: + command: path to executable that will enumerate items. + """ + ce = command_executer.GetCommandExecuter() + _, out, _ = ce.RunCommandWExceptionCleanup( + command, return_output=True, print_to_console=self.verbose) + all_items = out.split() + self.PopulateItemsUsingList(all_items) + + def PopulateItemsUsingList(self, all_items): + """Update all_items and binary searching logic from list. + + Args: + all_items: new list of all_items + """ + self.all_items = all_items + self.binary_search = binary_search_perforce.BinarySearcher( + logger_to_set=self.l) + self.binary_search.SetSortedList(self.all_items) + + def SaveState(self): + """Save state to STATE_FILE. + + SaveState will create a new unique, hidden state file to hold data from + object. Then atomically overwrite the STATE_FILE symlink to point to the + new data. + + Raises: + Error if STATE_FILE already exists but is not a symlink. + """ + ce, l = self.ce, self.l + self.ce, self.l, self.binary_search.logger = None, None, None + old_state = None + + _, path = tempfile.mkstemp(prefix=HIDDEN_STATE_FILE, dir='.') + with open(path, 'wb') as f: + pickle.dump(self, f) + + if os.path.exists(STATE_FILE): + if os.path.islink(STATE_FILE): + old_state = os.readlink(STATE_FILE) + else: + raise Error(('%s already exists and is not a symlink!\n' + 'State file saved to %s' % (STATE_FILE, path))) + + # Create new link and atomically overwrite old link + temp_link = '%s.link' % HIDDEN_STATE_FILE + os.symlink(path, temp_link) + os.rename(temp_link, STATE_FILE) + + if old_state: + os.remove(old_state) + + self.ce, self.l, self.binary_search.logger = ce, l, l + + @classmethod + def LoadState(cls): + """Create BinarySearchState object from STATE_FILE.""" + if not os.path.isfile(STATE_FILE): + return None + try: + bss = pickle.load(file(STATE_FILE)) + bss.l = logger.GetLogger() + bss.ce = command_executer.GetCommandExecuter() + bss.binary_search.logger = bss.l + bss.start_time = time.time() + + # Set resumed to be True so we can enter DoBinarySearch without the method + # resetting our current search_cycles to 0. + bss.resumed = True + + # Set currently_good_items and currently_bad_items to empty so that the + # first iteration after resuming will always be non-incremental. This is + # just in case the environment changes, the user makes manual changes, or + # a previous switch_script corrupted the environment. + bss.currently_good_items = set([]) + bss.currently_bad_items = set([]) + + binary_search_perforce.verbose = bss.verbose + return bss + except StandardError: + return None + + def RemoveState(self): + """Remove STATE_FILE and its symlinked data from file system.""" + if os.path.exists(STATE_FILE): + if os.path.islink(STATE_FILE): + real_file = os.readlink(STATE_FILE) + os.remove(real_file) + os.remove(STATE_FILE) + + def GetNextItems(self): + """Get next items for binary search based on result of the last test run.""" + border_item = self.binary_search.GetNext() + index = self.all_items.index(border_item) + + next_bad_items = self.all_items[:index + 1] + next_good_items = self.all_items[index + 1:] + list(self.known_good) + + return [next_bad_items, next_good_items] + + def ElapsedTimeString(self): + """Return h m s format of elapsed time since execution has started.""" + diff = int(time.time() - self.start_time) + seconds = diff % 60 + minutes = (diff / 60) % 60 + hours = diff / (60 * 60) + + seconds = str(seconds).rjust(2) + minutes = str(minutes).rjust(2) + hours = str(hours).rjust(2) + + return '%sh %sm %ss' % (hours, minutes, seconds) + + def _OutputProgress(self, progress_text): + """Output current progress of binary search to console and logs. + + Args: + progress_text: The progress to display to the user. + """ + progress = ('\n***** PROGRESS (elapsed time: %s) *****\n' + '%s' + '************************************************') + progress = progress % (self.ElapsedTimeString(), progress_text) + self.l.LogOutput(progress) + + def OutputIterationProgress(self): + out = ('Search %d of estimated %d.\n' + 'Prune %d of max %d.\n' + 'Current bad items found:\n' + '%s\n') + out = out % (self.search_cycles + 1, + math.ceil(math.log(len(self.all_items), 2)), + self.prune_cycles + 1, self.prune_iterations, + ', '.join(self.found_items)) + self._OutputProgress(out) + + def __str__(self): + ret = '' + ret += 'all: %s\n' % str(self.all_items) + ret += 'currently_good: %s\n' % str(self.currently_good_items) + ret += 'currently_bad: %s\n' % str(self.currently_bad_items) + ret += str(self.binary_search) + return ret + + +class MockBinarySearchState(BinarySearchState): + """Mock class for BinarySearchState.""" + + def __init__(self, **kwargs): + # Initialize all arguments to None + default_kwargs = { + 'get_initial_items': 'echo "1"', + 'switch_to_good': None, + 'switch_to_bad': None, + 'test_setup_script': None, + 'test_script': None, + 'incremental': True, + 'prune': False, + 'iterations': 50, + 'prune_iterations': 100, + 'verify': True, + 'file_args': False, + 'verbose': False + } + default_kwargs.update(kwargs) + super(MockBinarySearchState, self).__init__(**default_kwargs) + + +def _CanonicalizeScript(script_name): + """Return canonical path to script. + + Args: + script_name: Relative or absolute path to script + + Returns: + Canonicalized script path + """ + script_name = os.path.expanduser(script_name) + if not script_name.startswith('/'): + return os.path.join('.', script_name) + + +def Run(get_initial_items, + switch_to_good, + switch_to_bad, + test_script, + test_setup_script=None, + iterations=50, + prune=False, + noincremental=False, + file_args=False, + verify=True, + prune_iterations=100, + verbose=False, + resume=False): + """Run binary search tool. Equivalent to running through terminal. + + Args: + get_initial_items: Script to enumerate all items being binary searched + switch_to_good: Script that will take items as input and switch them to good + set + switch_to_bad: Script that will take items as input and switch them to bad + set + test_script: Script that will determine if the current combination of good + and bad items make a "good" or "bad" result. + test_setup_script: Script to do necessary setup (building, compilation, + etc.) for test_script. + iterations: How many binary search iterations to run before exiting. + prune: If False the binary search tool will stop when the first bad item is + found. Otherwise then binary search tool will continue searching + until all bad items are found (or prune_iterations is reached). + noincremental: Whether to send "diffs" of good/bad items to switch scripts. + file_args: If True then arguments to switch scripts will be a file name + containing a newline separated list of the items to switch. + verify: If True, run tests to ensure initial good/bad sets actually + produce a good/bad result. + prune_iterations: Max number of bad items to search for. + verbose: If True will print extra debug information to user. + resume: If True will resume using STATE_FILE. + + Returns: + 0 for success, error otherwise + """ + if resume: + bss = BinarySearchState.LoadState() + if not bss: + logger.GetLogger().LogOutput( + '%s is not a valid binary_search_tool state file, cannot resume!' % + STATE_FILE) + return 1 + else: + switch_to_good = _CanonicalizeScript(switch_to_good) + switch_to_bad = _CanonicalizeScript(switch_to_bad) + if test_setup_script: + test_setup_script = _CanonicalizeScript(test_setup_script) + test_script = _CanonicalizeScript(test_script) + get_initial_items = _CanonicalizeScript(get_initial_items) + incremental = not noincremental + + binary_search_perforce.verbose = verbose + + bss = BinarySearchState(get_initial_items, switch_to_good, switch_to_bad, + test_setup_script, test_script, incremental, prune, + iterations, prune_iterations, verify, file_args, + verbose) + bss.DoVerify() + + try: + bss.DoSearch() + bss.RemoveState() + logger.GetLogger().LogOutput('Total execution time: %s' % + bss.ElapsedTimeString()) + except Error as e: + logger.GetLogger().LogError(e) + return 1 + + return 0 + + +def Main(argv): + """The main function.""" + # Common initializations + + parser = argparse.ArgumentParser() + common.BuildArgParser(parser) + logger.GetLogger().LogOutput(' '.join(argv)) + options = parser.parse_args(argv) + + if not (options.get_initial_items and options.switch_to_good and + options.switch_to_bad and options.test_script) and not options.resume: + parser.print_help() + return 1 + + if options.resume: + logger.GetLogger().LogOutput('Resuming from %s' % STATE_FILE) + if len(argv) > 1: + logger.GetLogger().LogOutput(('Note: resuming from previous state, ' + 'ignoring given options and loading saved ' + 'options instead.')) + + # Get dictionary of all options + args = vars(options) + return Run(**args) + + +if __name__ == '__main__': + sys.exit(Main(sys.argv[1:])) |