diff options
author | Adrian Dole <adriandole@google.com> | 2022-06-17 21:46:03 +0000 |
---|---|---|
committer | Chromeos LUCI <chromeos-scoped@luci-project-accounts.iam.gserviceaccount.com> | 2022-06-27 16:50:27 +0000 |
commit | d31ac0b99b083f51dbb88031e8329f336d0a666b (patch) | |
tree | b125a9707551273736823425cf880244eb2ecd71 | |
parent | d0fe2198f2c4bb8aa769c26209d9423fb40b32ed (diff) | |
download | toolchain-utils-d31ac0b99b083f51dbb88031e8329f336d0a666b.tar.gz |
get_upstream_patch: Validate patch application
Currently, get_upstream_patch does not validate that a patch applies to
the current LLVM state.
Add validation before modifying PATCHES.json.
Move several functions into patch_utils to avoid depending on
patch_manager.
BUG=b:227216280
TEST=./get_upstream_patch.py --platform chromiumos --sha [patch SHA]
Change-Id: I97e7d401e7f8fc6d85dbfb9a310e4a77205ef444
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/third_party/toolchain-utils/+/3711269
Reviewed-by: Adrian Dole <adriandole@google.com>
Commit-Queue: Adrian Dole <adriandole@google.com>
Auto-Submit: Adrian Dole <adriandole@google.com>
Reviewed-by: Jordan Abrahams-Whitehead <ajordanr@google.com>
Tested-by: Adrian Dole <adriandole@google.com>
-rwxr-xr-x | llvm_tools/get_upstream_patch.py | 37 | ||||
-rwxr-xr-x | llvm_tools/llvm_patch_management.py | 6 | ||||
-rwxr-xr-x | llvm_tools/llvm_patch_management_unittest.py | 4 | ||||
-rwxr-xr-x | llvm_tools/patch_manager.py | 184 | ||||
-rwxr-xr-x | llvm_tools/patch_manager_unittest.py | 29 | ||||
-rw-r--r-- | llvm_tools/patch_utils.py | 161 | ||||
-rwxr-xr-x | llvm_tools/patch_utils_unittest.py | 24 |
7 files changed, 248 insertions, 197 deletions
diff --git a/llvm_tools/get_upstream_patch.py b/llvm_tools/get_upstream_patch.py index a2327c4d..b5b61153 100755 --- a/llvm_tools/get_upstream_patch.py +++ b/llvm_tools/get_upstream_patch.py @@ -12,6 +12,7 @@ from datetime import datetime import json import logging import os +from pathlib import Path import shlex import subprocess import sys @@ -21,6 +22,7 @@ import chroot import get_llvm_hash import git import git_llvm_rev +import patch_utils import update_chromeos_llvm_hash @@ -39,6 +41,36 @@ class CherrypickVersionError(ValueError): """A ValueError that highlights the cherry-pick is before the start_sha""" +class PatchApplicationError(ValueError): + """A ValueError indicating that a test patch application was unsuccessful""" + + +def validate_patch_application(llvm_dir: Path, svn_version: int, + patches_json_fp: Path, patch_props): + + start_sha = get_llvm_hash.GetGitHashFrom(llvm_dir, svn_version) + subprocess.run(['git', '-C', llvm_dir, 'checkout', start_sha], check=True) + + predecessor_apply_results = patch_utils.apply_all_from_json( + svn_version, llvm_dir, patches_json_fp, continue_on_failure=True) + + if predecessor_apply_results.failed_patches: + logging.error('Failed to apply patches from PATCHES.json:') + for p in predecessor_apply_results.failed_patches: + logging.error(f'Patch title: {p.title()}') + raise PatchApplicationError('Failed to apply patch from PATCHES.json') + + patch_entry = patch_utils.PatchEntry.from_dict(patches_json_fp.parent, + patch_props) + test_apply_result = patch_entry.test_apply(Path(llvm_dir)) + + if not test_apply_result: + logging.error('Could not apply requested patch') + logging.error(test_apply_result.failure_info()) + raise PatchApplicationError( + f'Failed to apply patch: {patch_props["metadata"]["title"]}') + + def add_patch(patches_json_path: str, patches_dir: str, relative_patches_dir: str, start_version: git_llvm_rev.Rev, llvm_dir: str, rev: t.Union[git_llvm_rev.Rev, str], sha: str, @@ -121,6 +153,11 @@ def add_patch(patches_json_path: str, patches_dir: str, 'until': end_vers, }, } + + with patch_utils.git_clean_context(Path(llvm_dir)): + validate_patch_application(Path(llvm_dir), start_version.number, + Path(patches_json_path), patch_props) + patches_json.append(patch_props) temp_file = patches_json_path + '.tmp' diff --git a/llvm_tools/llvm_patch_management.py b/llvm_tools/llvm_patch_management.py index b7ac1973..46ddb867 100755 --- a/llvm_tools/llvm_patch_management.py +++ b/llvm_tools/llvm_patch_management.py @@ -13,12 +13,14 @@ from __future__ import print_function import argparse import os -from failure_modes import FailureModes import chroot +from failure_modes import FailureModes import get_llvm_hash import patch_manager +import patch_utils import subprocess_helpers + # If set to `True`, then the contents of `stdout` after executing a command will # be displayed to the terminal. verbose = False @@ -228,7 +230,7 @@ def UpdatePackagesPatchMetadataFile(chroot_path, svn_version, # Make sure the patch metadata path is valid. _CheckPatchMetadataPath(patch_metadata_path) - patch_manager.CleanSrcTree(src_path) + patch_utils.clean_src_tree(src_path) # Get the patch results for the current package. patches_info = patch_manager.HandlePatches(svn_version, diff --git a/llvm_tools/llvm_patch_management_unittest.py b/llvm_tools/llvm_patch_management_unittest.py index 78d55259..52117c93 100755 --- a/llvm_tools/llvm_patch_management_unittest.py +++ b/llvm_tools/llvm_patch_management_unittest.py @@ -9,6 +9,7 @@ """Unit tests when creating the arguments for the patch manager.""" from __future__ import print_function + from collections import namedtuple import os import unittest @@ -18,6 +19,7 @@ from failure_modes import FailureModes import get_llvm_hash import llvm_patch_management import patch_manager +import patch_utils import subprocess_helpers @@ -217,7 +219,7 @@ class LlvmPatchManagementTest(unittest.TestCase): # Simulate `CleanSrcTree()` when successfully removed changes from the # worktree. - @mock.patch.object(patch_manager, 'CleanSrcTree') + @mock.patch.object(patch_utils, 'clean_src_tree') # Simulate `GetGitHashFrom()` when successfully retrieved the git hash # of the version passed in. @mock.patch.object(get_llvm_hash, diff --git a/llvm_tools/patch_manager.py b/llvm_tools/patch_manager.py index 82ba65d1..056757fe 100755 --- a/llvm_tools/patch_manager.py +++ b/llvm_tools/patch_manager.py @@ -6,15 +6,13 @@ """A manager for patches.""" import argparse -import contextlib -import dataclasses import enum import json import os from pathlib import Path import subprocess import sys -from typing import Any, Dict, IO, Iterable, List, Optional, Tuple +from typing import Any, Dict, IO, Iterable, List, Tuple from failure_modes import FailureModes import get_llvm_hash @@ -23,26 +21,6 @@ from subprocess_helpers import check_call from subprocess_helpers import check_output -@dataclasses.dataclass(frozen=True) -class PatchInfo: - """Holds info for a round of patch applications.""" - # str types are legacy. Patch lists should - # probably be PatchEntries, - applied_patches: List[patch_utils.PatchEntry] - failed_patches: List[patch_utils.PatchEntry] - # Can be deleted once legacy code is removed. - non_applicable_patches: List[str] - # Can be deleted once legacy code is removed. - disabled_patches: List[str] - # Can be deleted once legacy code is removed. - removed_patches: List[str] - # Can be deleted once legacy code is removed. - modified_metadata: Optional[str] - - def _asdict(self): - return dataclasses.asdict(self) - - class GitBisectionCode(enum.IntEnum): """Git bisection exit codes. @@ -382,18 +360,6 @@ def PerformBisection(src_path, good_commit, bad_commit, svn_version, return version -def CleanSrcTree(src_path): - """Cleans the source tree of the changes made in 'src_path'.""" - - reset_src_tree_cmd = ['git', '-C', src_path, 'reset', 'HEAD', '--hard'] - - check_output(reset_src_tree_cmd) - - clean_src_tree_cmd = ['git', '-C', src_path, 'clean', '-fd'] - - check_output(clean_src_tree_cmd) - - def SaveSrcTreeState(src_path): """Stashes the changes made so far to the source tree.""" @@ -414,81 +380,6 @@ def RestoreSrcTreeState(src_path, bad_commit_hash): check_output(get_changes_cmd) -def ApplyAllFromJson(svn_version: int, - llvm_src_dir: Path, - patches_json_fp: Path, - continue_on_failure: bool = False) -> PatchInfo: - """Attempt to apply some patches to a given LLVM source tree. - - This relies on a PATCHES.json file to be the primary way - the patches are applied. - - Args: - svn_version: LLVM Subversion revision to patch. - llvm_src_dir: llvm-project root-level source directory to patch. - patches_json_fp: Filepath to the PATCHES.json file. - continue_on_failure: Skip any patches which failed to apply, - rather than throw an Exception. - """ - with patches_json_fp.open(encoding='utf-8') as f: - patches = patch_utils.json_to_patch_entries(patches_json_fp.parent, f) - skipped_patches = [] - failed_patches = [] - applied_patches = [] - for pe in patches: - applied, failed_hunks = ApplySinglePatchEntry(svn_version, llvm_src_dir, - pe) - if applied: - applied_patches.append(pe) - continue - if failed_hunks is not None: - if continue_on_failure: - failed_patches.append(pe) - continue - else: - _PrintFailedPatch(pe, failed_hunks) - raise RuntimeError('failed to apply patch ' - f'{pe.patch_path()}: {pe.title()}') - # Didn't apply, didn't fail, it was skipped. - skipped_patches.append(pe) - return PatchInfo( - non_applicable_patches=skipped_patches, - applied_patches=applied_patches, - failed_patches=failed_patches, - disabled_patches=[], - removed_patches=[], - modified_metadata=None, - ) - - -def ApplySinglePatchEntry( - svn_version: int, - llvm_src_dir: Path, - pe: patch_utils.PatchEntry, - ignore_version_range: bool = False -) -> Tuple[bool, Optional[Dict[str, List[patch_utils.Hunk]]]]: - """Try to apply a single PatchEntry object. - - Returns: - Tuple where the first element indicates whether the patch applied, - and the second element is a faild hunk mapping from file name to lists of - hunks (if the patch didn't apply). - """ - # Don't apply patches outside of the version range. - if not ignore_version_range and not pe.can_patch_version(svn_version): - return False, None - # Test first to avoid making changes. - test_application = pe.test_apply(llvm_src_dir) - if not test_application: - return False, test_application.failed_hunks - # Now actually make changes. - application_result = pe.apply(llvm_src_dir) - if not application_result: - # This should be very rare/impossible. - return False, application_result.failed_hunks - return True, None - - def RemoveOldPatches(svn_version: int, llvm_src_dir: Path, patches_json_fp: Path): """Remove patches that don't and will never apply for the future. @@ -536,7 +427,7 @@ def UpdateVersionRanges(svn_version: int, llvm_src_dir: Path, f, ) modified_entries: List[patch_utils.PatchEntry] = [] - with _GitCleanContext(llvm_src_dir): + with patch_utils.git_clean_context(llvm_src_dir): for pe in patch_entries: test_result = pe.test_apply(llvm_src_dir) if not test_result: @@ -556,18 +447,6 @@ def UpdateVersionRanges(svn_version: int, llvm_src_dir: Path, f'for r{svn_version}') -def IsGitDirty(git_root_dir: Path) -> bool: - """Return whether the given git directory has uncommitted changes.""" - if not git_root_dir.is_dir(): - raise ValueError(f'git_root_dir {git_root_dir} is not a directory') - cmd = ['git', 'ls-files', '-m', '--other', '--exclude-standard'] - return (subprocess.run(cmd, - stdout=subprocess.PIPE, - check=True, - cwd=git_root_dir, - encoding='utf-8').stdout != '') - - def CheckPatchApplies(svn_version: int, llvm_src_dir: Path, patches_json_fp: Path, rel_patch_path: str) -> GitBisectionCode: @@ -590,7 +469,7 @@ def CheckPatchApplies(svn_version: int, llvm_src_dir: Path, patches_json_fp.parent, f, ) - with _GitCleanContext(llvm_src_dir): + with patch_utils.git_clean_context(llvm_src_dir): success, _, failed_patches = ApplyPatchAndPrior( svn_version, llvm_src_dir, @@ -635,7 +514,7 @@ def ApplyPatchAndPrior( # as patches can stack. for pe in patch_entries: is_patch_of_interest = pe.rel_patch_path == rel_patch_path - applied, failed_hunks = ApplySinglePatchEntry( + applied, failed_hunks = patch_utils.apply_single_patch_entry( svn_version, src_dir, pe, ignore_version_range=is_patch_of_interest) meant_to_apply = bool(failed_hunks) or is_patch_of_interest if is_patch_of_interest: @@ -658,36 +537,6 @@ def ApplyPatchAndPrior( 'Does it exist?') -def _PrintFailedPatch(pe: patch_utils.PatchEntry, - failed_hunks: Dict[str, List[patch_utils.Hunk]]): - """Print information about a single failing PatchEntry. - - Args: - pe: A PatchEntry that failed. - failed_hunks: Hunks for pe which failed as dict: - filepath: [Hunk...] - """ - print(f'Could not apply {pe.rel_patch_path}: {pe.title()}', file=sys.stderr) - for fp, hunks in failed_hunks.items(): - print(f'{fp}:', file=sys.stderr) - for h in hunks: - print( - f'- {pe.rel_patch_path} ' - f'l:{h.patch_hunk_lineno_begin}...{h.patch_hunk_lineno_end}', - file=sys.stderr) - - -@contextlib.contextmanager -def _GitCleanContext(git_root_dir: Path): - """Cleans up a git directory when the context exits.""" - if IsGitDirty(git_root_dir): - raise RuntimeError('Cannot setup clean context; git_root_dir is dirty') - try: - yield - finally: - CleanSrcTree(git_root_dir) - - def HandlePatches(svn_version, patch_metadata_file, filesdir_path, @@ -871,7 +720,7 @@ def HandlePatches(svn_version, # Need a clean source tree for `git bisect run` to avoid unnecessary # fails for patches. - CleanSrcTree(src_path) + patch_utils.clean_src_tree(src_path) print('\nStarting to bisect patch %s for SVN version %d:\n' % (os.path.basename( @@ -904,7 +753,7 @@ def HandlePatches(svn_version, UpdatePatchMetadataFile(patch_metadata_file, patch_file_contents) # Clear the changes made to the source tree by `git bisect run`. - CleanSrcTree(src_path) + patch_utils.clean_src_tree(src_path) if not continue_bisection: # Exiting program early because 'continue_bisection' is not set. @@ -952,7 +801,7 @@ def HandlePatches(svn_version, # Changes to the source tree need to be removed, otherwise some # patches may fail when applying the patch to the source tree when # `git bisect run` calls this script again. - CleanSrcTree(src_path) + patch_utils.clean_src_tree(src_path) # The last patch in the interval [0, N] failed to apply, so let # `git bisect run` know that the last patch (the patch that failed @@ -974,17 +823,18 @@ def HandlePatches(svn_version, # complain that the changes would need to be 'stashed' or 'removed' in # order to reset HEAD back to the bad commit's git hash, so HEAD will remain # on the last git hash used by `git bisect run`. - CleanSrcTree(src_path) + patch_utils.clean_src_tree(src_path) # NOTE: Exit code 0 is similar to `git bisect good`. sys.exit(0) - patch_info = PatchInfo(applied_patches=applied_patches, - failed_patches=failed_patches, - non_applicable_patches=non_applicable_patches, - disabled_patches=disabled_patches, - removed_patches=removed_patches, - modified_metadata=modified_metadata) + patch_info = patch_utils.PatchInfo( + applied_patches=applied_patches, + failed_patches=failed_patches, + non_applicable_patches=non_applicable_patches, + disabled_patches=disabled_patches, + removed_patches=removed_patches, + modified_metadata=modified_metadata) # Determine post actions after iterating through the patches. if mode == FailureModes.REMOVE_PATCHES: @@ -1007,7 +857,7 @@ def HandlePatches(svn_version, return patch_info -def PrintPatchResults(patch_info: PatchInfo): +def PrintPatchResults(patch_info: patch_utils.PatchInfo): """Prints the results of handling the patches of a package. Args: @@ -1049,7 +899,7 @@ def main(): args_output = GetCommandLineArgs() def _apply_all(args): - result = ApplyAllFromJson( + result = patch_utils.apply_all_from_json( svn_version=args.svn_version, llvm_src_dir=Path(args.src_path), patches_json_fp=Path(args.patch_metadata_file), diff --git a/llvm_tools/patch_manager_unittest.py b/llvm_tools/patch_manager_unittest.py index b77c3022..f74480c2 100755 --- a/llvm_tools/patch_manager_unittest.py +++ b/llvm_tools/patch_manager_unittest.py @@ -245,31 +245,8 @@ class PatchManagerTest(unittest.TestCase): for r, a in cases: _t(dirname, r, a) - def testIsGitDirty(self): - """Test if a git directory has uncommitted changes.""" - with tempfile.TemporaryDirectory( - prefix='patch_manager_unittest') as dirname: - dirpath = Path(dirname) - - def _run_h(cmd): - subprocess.run(cmd, cwd=dirpath, stdout=subprocess.DEVNULL, check=True) - - _run_h(['git', 'init']) - self.assertFalse(patch_manager.IsGitDirty(dirpath)) - test_file = dirpath / 'test_file' - test_file.touch() - self.assertTrue(patch_manager.IsGitDirty(dirpath)) - _run_h(['git', 'add', '.']) - _run_h(['git', 'commit', '-m', 'test']) - self.assertFalse(patch_manager.IsGitDirty(dirpath)) - test_file.touch() - self.assertFalse(patch_manager.IsGitDirty(dirpath)) - with test_file.open('w', encoding='utf-8'): - test_file.write_text('abc') - self.assertTrue(patch_manager.IsGitDirty(dirpath)) - @mock.patch('builtins.print') - @mock.patch.object(patch_manager, '_GitCleanContext') + @mock.patch.object(patch_utils, 'git_clean_context') def testCheckPatchApplies(self, _, mock_git_clean_context): """Tests whether we can apply a single patch for a given svn_version.""" mock_git_clean_context.return_value = mock.MagicMock() @@ -334,8 +311,8 @@ class PatchManagerTest(unittest.TestCase): def _harness2(version: int, application_func: Callable, expected: patch_manager.GitBisectionCode): with mock.patch.object( - patch_manager, - 'ApplySinglePatchEntry', + patch_utils, + 'apply_single_patch_entry', application_func, ): result = patch_manager.CheckPatchApplies( diff --git a/llvm_tools/patch_utils.py b/llvm_tools/patch_utils.py index cdf9f215..003990be 100644 --- a/llvm_tools/patch_utils.py +++ b/llvm_tools/patch_utils.py @@ -12,7 +12,7 @@ from pathlib import Path import re import subprocess import sys -from typing import Any, Dict, IO, List, Optional, Union +from typing import Any, Dict, IO, List, Optional, Tuple, Union CHECKED_FILE_RE = re.compile(r'^checking file\s+(.*)$') @@ -139,6 +139,17 @@ class PatchResult: def __bool__(self): return self.succeeded + def failure_info(self) -> str: + if self.succeeded: + return '' + s = '' + for file, hunks in self.failed_hunks.items(): + s += f'{file}:\n' + for h in hunks: + s += f'Lines {h.orig_start} to {h.orig_start + h.orig_hunk_len}\n' + s += '--------------------\n' + return s + @dataclasses.dataclass class PatchEntry: @@ -250,6 +261,26 @@ class PatchEntry: return self.metadata.get('title', '') +@dataclasses.dataclass(frozen=True) +class PatchInfo: + """Holds info for a round of patch applications.""" + # str types are legacy. Patch lists should + # probably be PatchEntries, + applied_patches: List[PatchEntry] + failed_patches: List[PatchEntry] + # Can be deleted once legacy code is removed. + non_applicable_patches: List[str] + # Can be deleted once legacy code is removed. + disabled_patches: List[str] + # Can be deleted once legacy code is removed. + removed_patches: List[str] + # Can be deleted once legacy code is removed. + modified_metadata: Optional[str] + + def _asdict(self): + return dataclasses.asdict(self) + + def json_to_patch_entries(workdir: Path, json_fd: IO[str]) -> List[PatchEntry]: """Convert a json IO object to List[PatchEntry]. @@ -258,3 +289,131 @@ def json_to_patch_entries(workdir: Path, json_fd: IO[str]) -> List[PatchEntry]: >>> patch_entries = json_to_patch_entries(Path(), f) """ return [PatchEntry.from_dict(workdir, d) for d in json.load(json_fd)] + + +def _print_failed_patch(pe: PatchEntry, failed_hunks: Dict[str, List[Hunk]]): + """Print information about a single failing PatchEntry. + + Args: + pe: A PatchEntry that failed. + failed_hunks: Hunks for pe which failed as dict: + filepath: [Hunk...] + """ + print(f'Could not apply {pe.rel_patch_path}: {pe.title()}', file=sys.stderr) + for fp, hunks in failed_hunks.items(): + print(f'{fp}:', file=sys.stderr) + for h in hunks: + print( + f'- {pe.rel_patch_path} ' + f'l:{h.patch_hunk_lineno_begin}...{h.patch_hunk_lineno_end}', + file=sys.stderr) + + +def apply_all_from_json(svn_version: int, + llvm_src_dir: Path, + patches_json_fp: Path, + continue_on_failure: bool = False) -> PatchInfo: + """Attempt to apply some patches to a given LLVM source tree. + + This relies on a PATCHES.json file to be the primary way + the patches are applied. + + Args: + svn_version: LLVM Subversion revision to patch. + llvm_src_dir: llvm-project root-level source directory to patch. + patches_json_fp: Filepath to the PATCHES.json file. + continue_on_failure: Skip any patches which failed to apply, + rather than throw an Exception. + """ + with patches_json_fp.open(encoding='utf-8') as f: + patches = json_to_patch_entries(patches_json_fp.parent, f) + skipped_patches = [] + failed_patches = [] + applied_patches = [] + for pe in patches: + applied, failed_hunks = apply_single_patch_entry(svn_version, llvm_src_dir, + pe) + if applied: + applied_patches.append(pe) + continue + if failed_hunks is not None: + if continue_on_failure: + failed_patches.append(pe) + continue + else: + _print_failed_patch(pe, failed_hunks) + raise RuntimeError('failed to apply patch ' + f'{pe.patch_path()}: {pe.title()}') + # Didn't apply, didn't fail, it was skipped. + skipped_patches.append(pe) + return PatchInfo( + non_applicable_patches=skipped_patches, + applied_patches=applied_patches, + failed_patches=failed_patches, + disabled_patches=[], + removed_patches=[], + modified_metadata=None, + ) + + +def apply_single_patch_entry( + svn_version: int, + llvm_src_dir: Path, + pe: PatchEntry, + ignore_version_range: bool = False +) -> Tuple[bool, Optional[Dict[str, List[Hunk]]]]: + """Try to apply a single PatchEntry object. + + Returns: + Tuple where the first element indicates whether the patch applied, + and the second element is a faild hunk mapping from file name to lists of + hunks (if the patch didn't apply). + """ + # Don't apply patches outside of the version range. + if not ignore_version_range and not pe.can_patch_version(svn_version): + return False, None + # Test first to avoid making changes. + test_application = pe.test_apply(llvm_src_dir) + if not test_application: + return False, test_application.failed_hunks + # Now actually make changes. + application_result = pe.apply(llvm_src_dir) + if not application_result: + # This should be very rare/impossible. + return False, application_result.failed_hunks + return True, None + + +def is_git_dirty(git_root_dir: Path) -> bool: + """Return whether the given git directory has uncommitted changes.""" + if not git_root_dir.is_dir(): + raise ValueError(f'git_root_dir {git_root_dir} is not a directory') + cmd = ['git', 'ls-files', '-m', '--other', '--exclude-standard'] + return (subprocess.run(cmd, + stdout=subprocess.PIPE, + check=True, + cwd=git_root_dir, + encoding='utf-8').stdout != '') + + +def clean_src_tree(src_path): + """Cleans the source tree of the changes made in 'src_path'.""" + + reset_src_tree_cmd = ['git', '-C', src_path, 'reset', 'HEAD', '--hard'] + + subprocess.run(reset_src_tree_cmd, check=True) + + clean_src_tree_cmd = ['git', '-C', src_path, 'clean', '-fd'] + + subprocess.run(clean_src_tree_cmd, check=True) + + +@contextlib.contextmanager +def git_clean_context(git_root_dir: Path): + """Cleans up a git directory when the context exits.""" + if is_git_dirty(git_root_dir): + raise RuntimeError('Cannot setup clean context; git_root_dir is dirty') + try: + yield + finally: + clean_src_tree(git_root_dir) diff --git a/llvm_tools/patch_utils_unittest.py b/llvm_tools/patch_utils_unittest.py index 3a6409b9..f73ee751 100755 --- a/llvm_tools/patch_utils_unittest.py +++ b/llvm_tools/patch_utils_unittest.py @@ -7,6 +7,7 @@ import io from pathlib import Path +import subprocess import tempfile import unittest import unittest.mock as mock @@ -159,6 +160,29 @@ Hunk #1 SUCCEEDED at 96 with fuzz 1. self.assertEqual(result['x/y/z.h'], [4]) self.assertNotIn('works.cpp', result) + def test_is_git_dirty(self): + """Test if a git directory has uncommitted changes.""" + with tempfile.TemporaryDirectory( + prefix='patch_utils_unittest') as dirname: + dirpath = Path(dirname) + + def _run_h(cmd): + subprocess.run(cmd, cwd=dirpath, stdout=subprocess.DEVNULL, check=True) + + _run_h(['git', 'init']) + self.assertFalse(pu.is_git_dirty(dirpath)) + test_file = dirpath / 'test_file' + test_file.touch() + self.assertTrue(pu.is_git_dirty(dirpath)) + _run_h(['git', 'add', '.']) + _run_h(['git', 'commit', '-m', 'test']) + self.assertFalse(pu.is_git_dirty(dirpath)) + test_file.touch() + self.assertFalse(pu.is_git_dirty(dirpath)) + with test_file.open('w', encoding='utf-8'): + test_file.write_text('abc') + self.assertTrue(pu.is_git_dirty(dirpath)) + @staticmethod def _default_json_dict(): return { |