aboutsummaryrefslogtreecommitdiff
path: root/catapult/common/py_utils/py_utils/discover.py
blob: a9333e2d7dd02d0061b49857e4c059790e1fed6f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# Copyright 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import fnmatch
import importlib
import inspect
import os
import re
import sys

from py_utils import camel_case


def DiscoverModules(start_dir, top_level_dir, pattern='*'):
  """Discover all modules in |start_dir| which match |pattern|.

  Recursively walks |start_dir| and imports every '.py' file whose name does
  not start with '.' or '_' and matches |pattern|. Each module is imported
  under its dotted path relative to |top_level_dir|, with |top_level_dir|
  temporarily placed first on sys.path.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing. Module names
        are derived from file paths relative to this directory.
    pattern: Unix shell-style pattern (fnmatch) for filtering the filenames
        to import.

  Returns:
    list of modules, ordered by directory path and then by filename.
  """
  # start_dir and top_level_dir must be consistent with each other, so both
  # are resolved to real paths before computing relative module paths.
  start_dir = os.path.realpath(start_dir)
  top_level_dir = os.path.realpath(top_level_dir)

  modules = []
  # os.walk order depends on the filesystem; sort by directory path so the
  # traversal of |start_dir| (and thus the import order) is deterministic.
  sub_paths = sorted(os.walk(start_dir),
                     key=lambda paths_tuple: paths_tuple[0])
  for dir_path, _, filenames in sub_paths:
    # Sort the filenames within each directory for the same reason.
    for filename in sorted(filenames):
      # Filter out hidden/private files, non-Python files and non-matches.
      if filename.startswith(('.', '_')):
        continue
      if os.path.splitext(filename)[1] != '.py':
        continue
      if not fnmatch.fnmatch(filename, pattern):
        continue

      # Find the module.
      module_rel_path = os.path.relpath(
          os.path.join(dir_path, filename), top_level_dir)
      module_name = re.sub(r'[/\\]', '.', os.path.splitext(module_rel_path)[0])

      # Import the module. Make sure that top_level_dir is the first path in
      # sys.path in case there are naming conflicts in module parts.
      # Restore sys.path afterwards by contents AND by identity: other code
      # may hold a reference to the original list object, so mutate it back
      # in place instead of rebinding sys.path to a copy.
      saved_path_list = sys.path
      saved_path_entries = saved_path_list[:]
      try:
        sys.path.insert(0, top_level_dir)
        module = importlib.import_module(module_name)
        modules.append(module)
      finally:
        saved_path_list[:] = saved_path_entries
        sys.path = saved_path_list
  return modules


def AssertNoKeyConflicts(classes_by_key_1, classes_by_key_2):
  """Asserts that keys shared by both dicts map to the very same class.

  Keys present in only one of the dicts are ignored. Identity (`is`) is used,
  so two distinct but equal-looking classes still count as a conflict.

  Args:
    classes_by_key_1: dict of {key: class}.
    classes_by_key_2: dict of {key: class}.

  Raises:
    AssertionError: on the first shared key (in |classes_by_key_1| order)
        whose classes differ, when assertions are enabled.
  """
  for key, first_cls in classes_by_key_1.items():
    if key not in classes_by_key_2:
      continue
    second_cls = classes_by_key_2[key]
    assert first_cls is second_cls, (
        'Found conflicting classes for the same key: '
        'key=%s, class_1=%s, class_2=%s' % (key, first_cls, second_cls))


# TODO(dtu): Normalize all discoverable classes to have corresponding module
# and class names, then always index by class name.
def DiscoverClasses(start_dir,
                    top_level_dir,
                    base_class,
                    pattern='*',
                    index_by_class_name=True,
                    directly_constructable=False):
  """Discover all classes in |start_dir| which subclass |base_class|.

  Modules are found with DiscoverModules() and each one is searched with
  DiscoverClassesInModule(); |base_class| itself is never returned.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    base_class: The base class to search for.
    pattern: Unix shell-style pattern for filtering the filenames to import.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
    directly_constructable: If True, will only return classes that can be
        constructed without arguments.

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}. When
    keyed by module name, a later module with the same short name replaces
    an earlier one.

  Raises:
    AssertionError: if |index_by_class_name| is True and two different
        classes map to the same key (only when assertions are enabled).
  """
  modules = DiscoverModules(start_dir, top_level_dir, pattern)
  classes = {}
  for module in modules:
    new_classes = DiscoverClassesInModule(
        module, base_class, index_by_class_name, directly_constructable)
    # TODO(crbug.com/548652): we should remove index_by_class_name once
    # benchmark_smoke_unittest in chromium/src/tools/perf no longer relies on
    # naming collisions to reduce the number of smoked benchmark tests.
    if index_by_class_name:
      AssertNoKeyConflicts(classes, new_classes)
    # Merge in place (later modules win on a shared key) instead of rebuilding
    # the whole dict for every module.
    classes.update(new_classes)
  return classes


# TODO(crbug.com/548652): we should remove index_by_class_name once
# benchmark_smoke_unittest in chromium/src/tools/perf no longer relies on
# naming collisions to reduce the number of smoked benchmark tests.
def DiscoverClassesInModule(module,
                            base_class,
                            index_by_class_name=False,
                            directly_constructable=False):
  """Discover all classes in |module| which subclass |base_class|.

  Only classes defined in |module| itself (not imported into it), whose
  __name__ does not start with '_', and that are not |base_class| itself
  are considered.

  Args:
    module: The module to search.
    base_class: The base class to search for.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
    directly_constructable: If True, will only return classes that can be
        constructed without arguments (see IsDirectlyConstructable).

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}. When
    keyed by module name every match shares the same key, so only the last
    match in inspect.getmembers() order (sorted by attribute name) is kept.

  Raises:
    AssertionError: if |index_by_class_name| is True and two different
        classes map to the same underscored name (only when assertions are
        enabled).
  """
  # Last dotted component of the module name; the key used when not indexing
  # by class name. Loop-invariant, so computed once.
  module_key = module.__name__.split('.')[-1]
  classes = {}
  for _, obj in inspect.getmembers(module):
    # Ensure object is a class.
    if not inspect.isclass(obj):
      continue
    # Include only strict subclasses of base_class (exclude base_class itself).
    if not issubclass(obj, base_class) or obj is base_class:
      continue
    # Exclude protected or private classes.
    if obj.__name__.startswith('_'):
      continue
    # Include only the module in which the class is defined.
    # If a class is imported by another module, exclude those duplicates.
    if obj.__module__ != module.__name__:
      continue

    if index_by_class_name:
      key_name = camel_case.ToUnderscore(obj.__name__)
    else:
      key_name = module_key
    if directly_constructable and not IsDirectlyConstructable(obj):
      continue
    if key_name in classes and index_by_class_name:
      assert classes[key_name] is obj, (
          'Duplicate key_name with different objs detected: '
          'key=%s, obj1=%s, obj2=%s' % (key_name, classes[key_name], obj))
    else:
      # When keyed by module name this overwrites earlier matches.
      classes[key_name] = obj

  return classes


def IsDirectlyConstructable(cls):
  """Returns True if an instance of |cls| can be constructed without arguments.

  A class qualifies when its __init__ takes no required arguments besides
  |self|: every other positional parameter and every keyword-only parameter
  has a default (*args and **kwargs are allowed).
  """
  assert inspect.isclass(cls)
  if not hasattr(cls, '__init__'):
    # Case |class A: pass| (a Python 2 old-style class without __init__).
    return True
  if cls.__init__ is object.__init__:
    # Case |class A(object): pass|.
    return True
  # Case |class A(object):| with |__init__| other than |object.__init__|.
  # inspect.getargspec() rejects annotations and keyword-only arguments and
  # was removed in Python 3.11, so prefer getfullargspec() when it exists and
  # only fall back to getargspec() on Python 2.
  getfullargspec = getattr(inspect, 'getfullargspec', None)
  if getfullargspec is None:
    args, _, _, defaults = inspect.getargspec(cls.__init__)
    kwonlyargs, kwonlydefaults = (), {}
  else:
    spec = getfullargspec(cls.__init__)
    args, defaults = spec.args, spec.defaults
    kwonlyargs, kwonlydefaults = spec.kwonlyargs, spec.kwonlydefaults or {}
  if defaults is None:
    defaults = ()
  # A keyword-only argument without a default must always be supplied.
  if any(name not in kwonlydefaults for name in kwonlyargs):
    return False
  # Return true if |self| is the only positional arg without a default.
  return len(args) == len(defaults) + 1


# Single-element list so the function below can bump it without `global`.
_COUNTER = [0]


def _GetUniqueModuleName():
  """Returns a fresh name 'module_<n>', where n increases by one per call."""
  next_value = _COUNTER[0] + 1
  _COUNTER[0] = next_value
  return 'module_%d' % next_value