summaryrefslogtreecommitdiff
path: root/python/helpers/pycharm/utrunner.py
blob: 1f11206b7d30980495744e09f920e52166e126d1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import sys
import imp
import os

# Directory holding the PyCharm helper modules: $PYCHARM_HELPERS_DIR when set,
# otherwise sys.path[0] (normally the running script's directory). It is put
# first on sys.path so the tcunittest / nose_helper / pycharm_run_utils
# imports below resolve from there.
helpers_dir = os.getenv("PYCHARM_HELPERS_DIR", sys.path[0])
if sys.path[0] != helpers_dir:
    sys.path.insert(0, helpers_dir)

from tcunittest import TeamcityTestRunner
from nose_helper import TestLoader, ContextSuite
from pycharm_run_utils import import_system_module
from pycharm_run_utils import adjust_sys_path
from pycharm_run_utils import debug, getModuleName, PYTHON_VERSION_MAJOR

# Let pycharm_run_utils adjust sys.path (see adjust_sys_path there for the
# exact changes it makes).
adjust_sys_path()

# Re-bind os and re through import_system_module. Presumably this guarantees
# the standard-library modules are used even if the adjusted sys.path holds
# same-named modules - TODO confirm against pycharm_run_utils.
os = import_system_module("os")
re = import_system_module("re")

# Modules loaded by loadSource(), keyed by the module name they were loaded
# under; loadSource() consults it to avoid module-name collisions.
modules = {}

def loadSource(fileName):
  """Load the Python source file *fileName* as a module and register it.

  The module name is the file's base name with its extension dropped. When
  DJANGO_SETTINGS_MODULE is set, a file called ``models`` is named
  ``<parent directory>.models`` instead. When ``sys.argv[1:-1]`` has exactly
  one entry and the name is already present in ``modules``, the first
  unused name produced by getModuleName(name, 2), getModuleName(name, 3), ...
  is used.

  :param fileName: path of the ``.py`` file to load.
  :return: the module returned by ``imp.load_source``; it is also stored in
      the global ``modules`` dict under the name it was loaded as.
  """
  moduleName = os.path.splitext(os.path.basename(fileName))[0]

  # Django relies on module names itself, so a bare "models" module is
  # qualified with the name of the directory that contains it.
  if os.getenv('DJANGO_SETTINGS_MODULE') and moduleName == "models":
    parentDirName = os.path.realpath(fileName).split(os.sep)[-2]
    moduleName = parentDirName + ".models"

  # Don't clobber an already-loaded module of the same name (for example
  # two same-named test files in different sub-folders): try numbered
  # variants starting at 2 until one is free.
  singleTarget = len(sys.argv[1:-1]) == 1
  if singleTarget and moduleName in modules:
    stem = moduleName
    attempt = 2
    while True:
      candidate = getModuleName(stem, attempt)
      if candidate not in modules:
        break
      attempt += 1
    moduleName = candidate

  debug("/ Loading " + fileName + " as " + moduleName)
  loadedModule = imp.load_source(moduleName, fileName)
  modules[moduleName] = loadedModule
  return loadedModule

def walkModules(modulesAndPattern, dirname, names):
  """``os.path.walk`` visitor that loads the matching test modules of a dir.

  :param modulesAndPattern: ``(modules, pattern)`` pair. Loaded modules are
      appended to the ``modules`` list; ``pattern`` is a comma-separated list
      of regular expressions matched (``re.match``) against file names.
  :param dirname: directory currently being visited.
  :param names: entry names inside ``dirname``.
  """
  modules = modulesAndPattern[0]
  pattern = modulesAndPattern[1]
  prog_list = [re.compile(pat.strip()) for pat in pattern.split(',')]
  for name in names:
    if not name.endswith(".py"):
      continue
    # Load each file at most once even when several patterns match it;
    # loading it again would register (and run) its tests a second time.
    # A plain loop + break is used instead of any() to stay friendly to the
    # old Jython interpreters this os.path.walk code path exists for.
    for prog in prog_list:
      if prog.match(name):
        modules.append(loadSource(os.path.join(dirname, name)))
        break

def loadModulesFromFolderRec(folder, pattern = "test.*"):
  """Recursively load the test modules found under *folder*.

  :param folder: root directory to search.
  :param pattern: comma-separated regular expressions; a ``.py`` file is
      loaded when its file name matches (``re.match``) any of them.
  :return: list of modules loaded through loadSource(), in walk order.
  """
  modules = []
  if PYTHON_VERSION_MAJOR == 3:
    prog_list = [re.compile(pat.strip()) for pat in pattern.split(',')]
    for root, dirs, files in os.walk(folder):
      for name in files:
        if not name.endswith(".py"):
          continue
        # Load each file at most once even when several patterns match it;
        # a second load would register (and run) its tests twice.
        for prog in prog_list:
          if prog.match(name):
            modules.append(loadSource(os.path.join(root, name)))
            break
  else:   # actually for jython compatibility
    # os.path.walk only exists on Python 2; walkModules does the matching
    # and appends into the same ``modules`` list.
    os.path.walk(folder, walkModules, (modules, pattern))

  return modules

# Loader and suite used to collect tests. They start as the nose-style
# helpers from nose_helper; __main__ swaps in plain unittest ones when the
# last command-line argument is "true", and setLoader() swaps in unittest2
# ones for modules exposing a ``unittest2`` attribute.
# NOTE(review): ``all`` shadows the builtin, but setLoader() and __main__
# rebind it by this exact name, so it cannot be renamed in isolation.
testLoader = TestLoader()
all = ContextSuite()
# True once __main__ has switched to plain unittest's loader and suite.
pure_unittest = False

def setLoader(module):
  """Switch to unittest2's loader and suite when *module* uses unittest2.

  A module that exposes a ``unittest2`` attribute (typically because it did
  ``import unittest2``) gets the global ``testLoader`` / ``all`` replaced by
  ``unittest2.TestLoader()`` / ``unittest2.TestSuite()``. Modules without
  that attribute, or environments where unittest2 cannot be imported, leave
  the current loader and suite untouched.

  NOTE(review): tests collected into the previous suite before the first
  switch are not carried over into the new unittest2 suite - confirm that
  this is intended.
  """
  global testLoader, all
  if not hasattr(module, 'unittest2'):
    return
  try:
    import unittest2
  except Exception:
    # Best effort: keep the current loader when unittest2 is unusable.
    return

  # Already switched for an earlier argument: creating a fresh suite here
  # would silently drop every test collected since then.
  if isinstance(testLoader, unittest2.TestLoader):
    return

  testLoader = unittest2.TestLoader()
  all = unittest2.TestSuite()

if __name__ == "__main__":
  # A trailing "true" argument selects plain unittest collection instead of
  # the nose-style helpers.
  arg = sys.argv[-1]
  if arg == "true":
    import unittest

    testLoader = unittest.TestLoader()
    all = unittest.TestSuite()
    pure_unittest = True

  # Every argument between the script name and the trailing flag is either a
  # "--option" forwarded to the runner, or a test target:
  #   module.py | folder/ | folder;pattern     -> whole module(s)
  #   module.py::Class                         -> one test class
  #   module.py::Class::method                 -> one test method
  #   module.py::::function                    -> one test function
  options = {}
  for arg in sys.argv[1:-1]:
    arg = arg.strip()
    if not arg:
      continue

    if arg.startswith("--"):
      options[arg[2:]] = True
      continue

    a = arg.split("::")
    if len(a) == 1:
      # From module or folder.
      # This loop runs at module scope, so the result must NOT be stored in
      # ``modules``: that would rebind the global name->module dict that
      # loadSource() indexes by module name and break every later argument.
      a_splitted = a[0].split(";")
      if len(a_splitted) != 1:
        # "<folder>;<pattern>": a pattern only applies to a folder, so walk
        # it whether or not the folder has a trailing separator (otherwise
        # nothing would be assigned for this argument).
        debug("/ from folder " + a_splitted[0] + ". Use pattern: " + a_splitted[1])
        target_modules = loadModulesFromFolderRec(a_splitted[0], a_splitted[1])
      elif a[0].endswith(("/", os.path.sep)):
        # Accept both "/" and the native separator as the folder marker.
        debug("/ from folder " + a[0])
        target_modules = loadModulesFromFolderRec(a[0])
      else:
        debug("/ from module " + a[0])
        target_modules = [loadSource(a[0])]

      for module in target_modules:
        all.addTests(testLoader.loadTestsFromModule(module))

    elif len(a) == 2:
      # From testcase
      debug("/ from testcase " + a[1] + " in " + a[0])
      module = loadSource(a[0])
      setLoader(module)
      testCaseClass = getattr(module, a[1])

      # setLoader() may have switched to unittest2, whose loader (like plain
      # unittest's) offers loadTestsFromTestCase but not the nose-style
      # loadTestsFromTestClass / two-argument addTests.
      if pure_unittest or not hasattr(testLoader, "loadTestsFromTestClass"):
        all.addTests(testLoader.loadTestsFromTestCase(testCaseClass))
      else:
        all.addTests(testLoader.loadTestsFromTestClass(testCaseClass),
          testCaseClass)
    else:
      # From method in class or from function
      debug("/ from method " + a[2] + " in testcase " + a[1] + " in " + a[0])
      module = loadSource(a[0])
      setLoader(module)

      if a[1] == "":
        # test function, not method
        # NOTE(review): makeTest comes from the nose-style loader; the plain
        # unittest / unittest2 loaders do not provide it.
        all.addTest(testLoader.makeTest(getattr(module, a[2])))
      else:
        testCaseClass = getattr(module, a[1])
        try:
          all.addTest(testCaseClass(a[2]))
        except Exception:
          # class is not a testcase inheritor
          all.addTest(
            testLoader.makeTest(getattr(testCaseClass, a[2]), testCaseClass))

  debug("/ Loaded " + str(all.countTestCases()) + " tests")
  TeamcityTestRunner().run(all, **options)