diff options
Diffstat (limited to 'catapult/common/py_utils/py_utils/discover.py')
-rw-r--r-- | catapult/common/py_utils/py_utils/discover.py | 191 |
1 files changed, 191 insertions, 0 deletions
diff --git a/catapult/common/py_utils/py_utils/discover.py b/catapult/common/py_utils/py_utils/discover.py new file mode 100644 index 00000000..09d5c5e2 --- /dev/null +++ b/catapult/common/py_utils/py_utils/discover.py @@ -0,0 +1,191 @@ +# Copyright 2012 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +import fnmatch +import importlib +import inspect +import os +import re +import sys + +from py_utils import camel_case + + +def DiscoverModules(start_dir, top_level_dir, pattern='*'): + """Discover all modules in |start_dir| which match |pattern|. + + Args: + start_dir: The directory to recursively search. + top_level_dir: The top level of the package, for importing. + pattern: Unix shell-style pattern for filtering the filenames to import. + + Returns: + list of modules. + """ + # start_dir and top_level_dir must be consistent with each other. + start_dir = os.path.realpath(start_dir) + top_level_dir = os.path.realpath(top_level_dir) + + modules = [] + sub_paths = list(os.walk(start_dir)) + # We sort the directories & file paths to ensure a deterministic ordering when + # traversing |top_level_dir|. + sub_paths.sort(key=lambda paths_tuple: paths_tuple[0]) + for dir_path, _, filenames in sub_paths: + # Sort the directories to walk recursively by the directory path. + filenames.sort() + for filename in filenames: + # Filter out unwanted filenames. + if filename.startswith('.') or filename.startswith('_'): + continue + if os.path.splitext(filename)[1] != '.py': + continue + if not fnmatch.fnmatch(filename, pattern): + continue + + # Find the module. + module_rel_path = os.path.relpath( + os.path.join(dir_path, filename), top_level_dir) + module_name = re.sub(r'[/\\]', '.', os.path.splitext(module_rel_path)[0]) + + # Import the module. + try: + # Make sure that top_level_dir is the first path in the sys.path in case + # there are naming conflict in module parts. + original_sys_path = sys.path[:] + sys.path.insert(0, top_level_dir) + module = importlib.import_module(module_name) + modules.append(module) + finally: + sys.path = original_sys_path + return modules + + +def AssertNoKeyConflicts(classes_by_key_1, classes_by_key_2): + for k in classes_by_key_1: + if k in classes_by_key_2: + assert classes_by_key_1[k] is classes_by_key_2[k], ( + 'Found conflicting classes for the same key: ' + 'key=%s, class_1=%s, class_2=%s' % ( + k, classes_by_key_1[k], classes_by_key_2[k])) + + +# TODO(dtu): Normalize all discoverable classes to have corresponding module +# and class names, then always index by class name. +def DiscoverClasses(start_dir, + top_level_dir, + base_class, + pattern='*', + index_by_class_name=True, + directly_constructable=False): + """Discover all classes in |start_dir| which subclass |base_class|. + + Base classes that contain subclasses are ignored by default. + + Args: + start_dir: The directory to recursively search. + top_level_dir: The top level of the package, for importing. + base_class: The base class to search for. + pattern: Unix shell-style pattern for filtering the filenames to import. + index_by_class_name: If True, use class name converted to + lowercase_with_underscores instead of module name in return dict keys. + directly_constructable: If True, will only return classes that can be + constructed without arguments + + Returns: + dict of {module_name: class} or {underscored_class_name: class} + """ + modules = DiscoverModules(start_dir, top_level_dir, pattern) + classes = {} + for module in modules: + new_classes = DiscoverClassesInModule( + module, base_class, index_by_class_name, directly_constructable) + # TODO(nednguyen): we should remove index_by_class_name once + # benchmark_smoke_unittest in chromium/src/tools/perf no longer relied + # naming collisions to reduce the number of smoked benchmark tests. + # crbug.com/548652 + if index_by_class_name: + AssertNoKeyConflicts(classes, new_classes) + classes = dict(classes.items() + new_classes.items()) + return classes + + +# TODO(nednguyen): we should remove index_by_class_name once +# benchmark_smoke_unittest in chromium/src/tools/perf no longer relied +# naming collisions to reduce the number of smoked benchmark tests. +# crbug.com/548652 +def DiscoverClassesInModule(module, + base_class, + index_by_class_name=False, + directly_constructable=False): + """Discover all classes in |module| which subclass |base_class|. + + Base classes that contain subclasses are ignored by default. + + Args: + module: The module to search. + base_class: The base class to search for. + index_by_class_name: If True, use class name converted to + lowercase_with_underscores instead of module name in return dict keys. + + Returns: + dict of {module_name: class} or {underscored_class_name: class} + """ + classes = {} + for _, obj in inspect.getmembers(module): + # Ensure object is a class. + if not inspect.isclass(obj): + continue + # Include only subclasses of base_class. + if not issubclass(obj, base_class): + continue + # Exclude the base_class itself. + if obj is base_class: + continue + # Exclude protected or private classes. + if obj.__name__.startswith('_'): + continue + # Include only the module in which the class is defined. + # If a class is imported by another module, exclude those duplicates. + if obj.__module__ != module.__name__: + continue + + if index_by_class_name: + key_name = camel_case.ToUnderscore(obj.__name__) + else: + key_name = module.__name__.split('.')[-1] + if not directly_constructable or IsDirectlyConstructable(obj): + if key_name in classes and index_by_class_name: + assert classes[key_name] is obj, ( + 'Duplicate key_name with different objs detected: ' + 'key=%s, obj1=%s, obj2=%s' % (key_name, classes[key_name], obj)) + else: + classes[key_name] = obj + + return classes + + +def IsDirectlyConstructable(cls): + """Returns True if instance of |cls| can be construct without arguments.""" + assert inspect.isclass(cls) + if not hasattr(cls, '__init__'): + # Case |class A: pass|. + return True + if cls.__init__ is object.__init__: + # Case |class A(object): pass|. + return True + # Case |class (object):| with |__init__| other than |object.__init__|. + args, _, _, defaults = inspect.getargspec(cls.__init__) + if defaults is None: + defaults = () + # Return true if |self| is only arg without a default. + return len(args) == len(defaults) + 1 + + +_COUNTER = [0] + + +def _GetUniqueModuleName(): + _COUNTER[0] += 1 + return "module_" + str(_COUNTER[0]) |