aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorwbond <will@wbond.net>2019-12-29 01:30:34 -0500
committerwbond <will@wbond.net>2019-12-29 01:30:34 -0500
commit4137cda82528d36dd8813bfd58e9092a4c14b45b (patch)
tree413f993d3c8a73f516d08033958d87366c2327f6
parent012cdc8554cadc55b17e3eea8cc7982b1269cbba (diff)
downloadasn1crypto-4137cda82528d36dd8813bfd58e9092a4c14b45b.tar.gz
Switch to reusable run.py
-rw-r--r--dev/__init__.py2
-rw-r--r--dev/_import.py12
-rw-r--r--dev/_task.py163
-rw-r--r--dev/tests.py13
-rw-r--r--dev/version.py8
-rw-r--r--run.py67
6 files changed, 197 insertions, 68 deletions
diff --git a/dev/__init__.py b/dev/__init__.py
index 02e9c6c..a4b5cb4 100644
--- a/dev/__init__.py
+++ b/dev/__init__.py
@@ -15,6 +15,8 @@ other_packages = [
"ocspbuilder"
]
+task_keyword_args = []
+
requires_oscrypto = False
has_tests_package = True
diff --git a/dev/_import.py b/dev/_import.py
index 1041968..680e7d1 100644
--- a/dev/_import.py
+++ b/dev/_import.py
@@ -13,7 +13,7 @@ else:
getcwd = os.getcwd
-def _import_from(mod, path, mod_dir=None):
+def _import_from(mod, path, mod_dir=None, allow_error=False):
"""
Imports a module from a specific path
@@ -27,23 +27,29 @@ def _import_from(mod, path, mod_dir=None):
If the sub directory of "path" is different than the "mod" name,
pass the sub directory as a unicode string
+ :param allow_error:
+ If an ImportError should be raised when the module can't be imported
+
:return:
None if not loaded, otherwise the module
"""
if mod_dir is None:
- mod_dir = mod
+ mod_dir = mod.replace('.', os.sep)
if not os.path.exists(path):
return None
- if not os.path.exists(os.path.join(path, mod_dir)):
+ if not os.path.exists(os.path.join(path, mod_dir)) \
+ and not os.path.exists(os.path.join(path, mod_dir + '.py')):
return None
try:
mod_info = imp.find_module(mod_dir, [path])
return imp.load_module(mod, *mod_info)
except ImportError:
+ if allow_error:
+ raise
return None
diff --git a/dev/_task.py b/dev/_task.py
new file mode 100644
index 0000000..5cc257a
--- /dev/null
+++ b/dev/_task.py
@@ -0,0 +1,163 @@
+# coding: utf-8
+from __future__ import unicode_literals, division, absolute_import, print_function
+
+import ast
+import _ast
+import os
+import sys
+
+from . import package_root, task_keyword_args
+from ._import import _import_from
+
+
+if sys.version_info < (3,):
+ byte_cls = str
+else:
+ byte_cls = bytes
+
+
+def _list_tasks():
+ """
+ Fetches a list of all valid tasks that may be run, and the args they
+ accept. Does not actually import the task module to prevent errors if a
+ user does not have the dependencies installed for every task.
+
+ :return:
+ A list of 2-element tuples:
+ 0: a unicode string of the task name
+ 1: a list of dicts containing the parameter definitions
+ """
+
+ out = []
+ dev_path = os.path.join(package_root, 'dev')
+ for fname in sorted(os.listdir(dev_path)):
+ if fname.startswith('.') or fname.startswith('_'):
+ continue
+ if not fname.endswith('.py'):
+ continue
+ name = fname[:-3]
+ args = ()
+
+ full_path = os.path.join(package_root, 'dev', fname)
+ with open(full_path, 'rb') as f:
+ full_code = f.read()
+ if sys.version_info >= (3,):
+ full_code = full_code.decode('utf-8')
+
+ task_node = ast.parse(full_code, filename=full_path)
+ for node in ast.iter_child_nodes(task_node):
+ if isinstance(node, _ast.Assign):
+ if len(node.targets) == 1 \
+ and isinstance(node.targets[0], _ast.Name) \
+ and node.targets[0].id == 'run_args':
+ args = ast.literal_eval(node.value)
+ break
+
+ out.append((name, args))
+ return out
+
+
+def show_usage():
+ """
+ Prints to stderr the valid options for invoking tasks
+ """
+
+ valid_tasks = []
+ for task in _list_tasks():
+ usage = task[0]
+ for run_arg in task[1]:
+ usage += ' '
+ name = run_arg.get('name', '')
+ if run_arg.get('required', False):
+ usage += '{%s}' % name
+ else:
+ usage += '[%s]' % name
+ valid_tasks.append(usage)
+
+ out = 'Usage: run.py'
+ for karg in task_keyword_args:
+ out += ' [%s=%s]' % (karg['name'], karg['placeholder'])
+ out += ' (%s)' % ' | '.join(valid_tasks)
+
+ print(out, file=sys.stderr)
+ sys.exit(1)
+
+
+def _get_arg(num):
+ """
+ :return:
+ A unicode string of the requested command line arg
+ """
+
+ if len(sys.argv) < num + 1:
+ return None
+ arg = sys.argv[num]
+ if isinstance(arg, byte_cls):
+ arg = arg.decode('utf-8')
+ return arg
+
+
+def run_task():
+ """
+ Parses the command line args, invoking the requested task
+ """
+
+ arg_num = 1
+ task = None
+ args = []
+ kwargs = {}
+
+ # We look for the task name, processing any global task keyword args
+ # by setting the appropriate env var
+ while True:
+ val = _get_arg(arg_num)
+ if val is None:
+ break
+
+ next_arg = False
+ for karg in task_keyword_args:
+ if val.startswith(karg['name'] + '='):
+ os.environ[karg['env_var']] = val[len(karg['name']) + 1:]
+ next_arg = True
+ break
+
+ if next_arg:
+ arg_num += 1
+ continue
+
+ task = val
+ break
+
+ if task is None:
+ show_usage()
+
+ task_mod = _import_from('dev.%s' % task, package_root, allow_error=True)
+ if task_mod is None:
+ show_usage()
+
+ run_args = task_mod.__dict__.get('run_args', [])
+ max_args = arg_num + 1 + len(run_args)
+
+ if len(sys.argv) > max_args:
+ show_usage()
+
+ for i, run_arg in enumerate(run_args):
+ val = _get_arg(arg_num + 1 + i)
+ if val is None:
+ if run_arg.get('required', False):
+ show_usage()
+ break
+
+ if run_arg.get('cast') == 'int' and val.isdigit():
+ val = int(val)
+
+ kwarg = run_arg.get('kwarg')
+ if kwarg:
+ kwargs[kwarg] = val
+ else:
+ args.append(val)
+
+ run = task_mod.__dict__.get('run')
+
+ result = run(*args, **kwargs)
+ sys.exit(int(not result))
diff --git a/dev/tests.py b/dev/tests.py
index 5deb8cc..101c691 100644
--- a/dev/tests.py
+++ b/dev/tests.py
@@ -18,6 +18,19 @@ else:
from io import StringIO
+run_args = [
+ {
+ 'name': 'regex',
+ 'kwarg': 'matcher',
+ },
+ {
+ 'name': 'repeat_count',
+ 'kwarg': 'repeat',
+ 'cast': 'int',
+ },
+]
+
+
def run(matcher=None, repeat=1, ci=False):
"""
Runs the tests
diff --git a/dev/version.py b/dev/version.py
index 3027431..fe37d3d 100644
--- a/dev/version.py
+++ b/dev/version.py
@@ -8,6 +8,14 @@ import re
from . import package_root, package_name, has_tests_package
+run_args = [
+ {
+ 'name': 'pep440_version',
+ 'required': True
+ },
+]
+
+
def run(new_version):
"""
Updates the package version in the various locations
diff --git a/run.py b/run.py
index 64666d9..2f53221 100644
--- a/run.py
+++ b/run.py
@@ -2,70 +2,7 @@
# coding: utf-8
from __future__ import unicode_literals, division, absolute_import, print_function
-import sys
+from dev._task import run_task
-if sys.version_info < (3,):
- byte_cls = str
-else:
- byte_cls = bytes
-
-def show_usage():
- print('Usage: run.py (lint | tests [regex] | coverage | deps | ci | version {pep440_version} | build | release)', file=sys.stderr)
- sys.exit(1)
-
-
-def get_arg(num):
- if len(sys.argv) < num + 1:
- return None, num
- arg = sys.argv[num]
- if isinstance(arg, byte_cls):
- arg = arg.decode('utf-8')
- return arg, num + 1
-
-
-if len(sys.argv) < 2 or len(sys.argv) > 3:
- show_usage()
-
-task, next_arg = get_arg(1)
-
-if task not in set(['lint', 'tests', 'coverage', 'deps', 'ci', 'version', 'build', 'release']):
- show_usage()
-
-if task != 'tests' and task != 'version' and len(sys.argv) == 3:
- show_usage()
-
-params = []
-if task == 'lint':
- from dev.lint import run
-
-elif task == 'tests':
- from dev.tests import run
- matcher, next_arg = get_arg(next_arg)
- if matcher:
- params.append(matcher)
-
-elif task == 'coverage':
- from dev.coverage import run
-
-elif task == 'deps':
- from dev.deps import run
-
-elif task == 'ci':
- from dev.ci import run
-
-elif task == 'version':
- from dev.version import run
- if len(sys.argv) != 3:
- show_usage()
- pep440_version, next_arg = get_arg(next_arg)
- params.append(pep440_version)
-
-elif task == 'build':
- from dev.build import run
-
-elif task == 'release':
- from dev.release import run
-
-result = run(*params)
-sys.exit(int(not result))
+run_task()