diff options
author | David Lord <davidism@gmail.com> | 2020-06-22 10:12:46 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-22 10:12:46 -0700 |
commit | 9b718ed3d26fc2f80ab58f3249b517a65db9cc7b (patch) | |
tree | 350a8732bfa8c222e39c82640197b8f788df7a91 | |
parent | 5eea6e49e09971f844f7a3b6812c7be167ee04c9 (diff) | |
parent | cc792d8c918b44c3c6815cced07b0a334a2fed42 (diff) | |
download | jinja-9b718ed3d26fc2f80ab58f3249b517a65db9cc7b.tar.gz |
Merge pull request #1241 from MLH-Fellowship/macros_with_globals
Fix bug with imported macros and template globals
-rw-r--r-- | CHANGES.rst | 2 | ||||
-rw-r--r-- | src/jinja2/asyncsupport.py | 4 | ||||
-rw-r--r-- | src/jinja2/compiler.py | 4 | ||||
-rw-r--r-- | src/jinja2/environment.py | 19 | ||||
-rw-r--r-- | src/jinja2/runtime.py | 7 | ||||
-rw-r--r-- | tests/test_imports.py | 41 |
6 files changed, 70 insertions, 7 deletions
diff --git a/CHANGES.rst b/CHANGES.rst index 57de4ae5..d42f213c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -11,6 +11,8 @@ Unreleased - Remove code that was marked deprecated. - Use :pep:`451` API to load templates with :class:`~loaders.PackageLoader`. :issue:`1168` +- Fix a bug that caused imported macros to not have access to the + current template's globals. :issue:`688` Version 2.11.2 diff --git a/src/jinja2/asyncsupport.py b/src/jinja2/asyncsupport.py index 3aef7ad2..e46a85a3 100644 --- a/src/jinja2/asyncsupport.py +++ b/src/jinja2/asyncsupport.py @@ -116,10 +116,10 @@ async def get_default_module_async(self): def wrap_default_module(original_default_module): @internalcode - def _get_default_module(self): + def _get_default_module(self, ctx=None): if self.environment.is_async: raise RuntimeError("Template module attribute is unavailable in async mode") - return original_default_module(self) + return original_default_module(self, ctx) return _get_default_module diff --git a/src/jinja2/compiler.py b/src/jinja2/compiler.py index 045a3a88..abdbe6da 100644 --- a/src/jinja2/compiler.py +++ b/src/jinja2/compiler.py @@ -925,7 +925,7 @@ class CodeGenerator(NodeVisitor): elif self.environment.is_async: self.write("_get_default_module_async()") else: - self.write("_get_default_module()") + self.write("_get_default_module(context)") if frame.toplevel and not node.target.startswith("_"): self.writeline(f"context.exported_vars.discard({node.target!r})") @@ -944,7 +944,7 @@ class CodeGenerator(NodeVisitor): elif self.environment.is_async: self.write("_get_default_module_async()") else: - self.write("_get_default_module()") + self.write("_get_default_module(context)") var_names = [] discarded_names = [] diff --git a/src/jinja2/environment.py b/src/jinja2/environment.py index 3c93c484..556f7255 100644 --- a/src/jinja2/environment.py +++ b/src/jinja2/environment.py @@ -1120,7 +1120,24 @@ class Template: ) @internalcode - def _get_default_module(self): + def _get_default_module(self, ctx=None): + """If a context is passed in, this means that the template was + imported. Imported templates have access to the current template's + globals by default, but they can only be accessed via the context + during runtime. + + If there are new globals, we need to create a new + module because the cached module is already rendered and will not have + access to globals from the current context. This new module is not + cached as :attr:`_module` because the template can be imported elsewhere, + and it should have access to only the current template's globals. + """ + if ctx is not None: + globals = { + key: ctx.parent[key] for key in ctx.globals_keys - self.globals.keys() + } + if globals: + return self.make_module(globals) if self._module is not None: return self._module self._module = rv = self.make_module() diff --git a/src/jinja2/runtime.py b/src/jinja2/runtime.py index 00f1f59f..7b5925b1 100644 --- a/src/jinja2/runtime.py +++ b/src/jinja2/runtime.py @@ -97,7 +97,9 @@ def new_context( for key, value in locals.items(): if value is not missing: parent[key] = value - return environment.context_class(environment, parent, template_name, blocks) + return environment.context_class( + environment, parent, template_name, blocks, globals=globals + ) class TemplateReference: @@ -179,13 +181,14 @@ class Context(metaclass=ContextMeta): _legacy_resolve_mode = False _fast_resolve_mode = False - def __init__(self, environment, parent, name, blocks): + def __init__(self, environment, parent, name, blocks, globals=None): self.parent = parent self.vars = {} self.environment = environment self.eval_ctx = EvalContext(self.environment, name) self.exported_vars = set() self.name = name + self.globals_keys = set() if globals is None else set(globals) # create the initial mapping of blocks. Whenever template inheritance # takes place the runtime will update this mapping with the new blocks diff --git a/tests/test_imports.py b/tests/test_imports.py index 7a2bd942..054c9010 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -98,6 +98,47 @@ class TestImports: with pytest.raises(UndefinedError, match="does not export the requested name"): t.render() + def test_import_with_globals(self, test_env): + env = Environment( + loader=DictLoader( + { + "macros": "{% macro test() %}foo: {{ foo }}{% endmacro %}", + "test": "{% import 'macros' as m %}{{ m.test() }}", + "test1": "{% import 'macros' as m %}{{ m.test() }}", + } + ) + ) + tmpl = env.get_template("test", globals={"foo": "bar"}) + assert tmpl.render() == "foo: bar" + + tmpl = env.get_template("test1") + assert tmpl.render() == "foo: " + + def test_import_with_globals_override(self, test_env): + env = Environment( + loader=DictLoader( + { + "macros": "{% set foo = '42' %}{% macro test() %}" + "foo: {{ foo }}{% endmacro %}", + "test": "{% from 'macros' import test %}{{ test() }}", + } + ) + ) + tmpl = env.get_template("test", globals={"foo": "bar"}) + assert tmpl.render() == "foo: 42" + + def test_from_import_with_globals(self, test_env): + env = Environment( + loader=DictLoader( + { + "macros": "{% macro testing() %}foo: {{ foo }}{% endmacro %}", + "test": "{% from 'macros' import testing %}{{ testing() }}", + } + ) + ) + tmpl = env.get_template("test", globals={"foo": "bar"}) + assert tmpl.render() == "foo: bar" + class TestIncludes: def test_context_include(self, test_env): |