diff options
-rw-r--r-- | git_updater.py | 4 | ||||
-rw-r--r-- | git_utils.py | 14 |
2 files changed, 16 insertions, 2 deletions
diff --git a/git_updater.py b/git_updater.py index 4c02821..ef320cb 100644 --- a/git_updater.py +++ b/git_updater.py @@ -66,8 +66,10 @@ class GitUpdater(base_updater.Updater): self._new_ver = updater_utils.get_latest_version(self._old_ver, tags) def _check_head(self) -> None: + branch = git_utils.get_default_branch(self._proj_path, + self.UPSTREAM_REMOTE_NAME) self._new_ver = git_utils.get_sha_for_branch( - self._proj_path, self.UPSTREAM_REMOTE_NAME + '/master') + self._proj_path, self.UPSTREAM_REMOTE_NAME + '/' + branch) def update(self) -> None: """Updates the package. diff --git a/git_utils.py b/git_utils.py index 1825c37..f96a600 100644 --- a/git_utils.py +++ b/git_utils.py @@ -99,7 +99,7 @@ def list_remote_branches(proj_path: Path, remote_name: str) -> List[str]: stripped = [line.strip() for line in lines] remote_path = remote_name + '/' return [ - line.lstrip(remote_path) for line in stripped + line[len(remote_path):] for line in stripped if line.startswith(remote_path) ] @@ -116,6 +116,18 @@ def list_remote_tags(proj_path: Path, remote_name: str) -> List[str]: return list(set(tags)) +def get_default_branch(proj_path: Path, remote_name: str) -> str: + """Gets the name of the upstream branch to use.""" + branches_to_try = ['master', 'main'] + remote_branches = list_remote_branches(proj_path, remote_name) + for branch in branches_to_try: + if branch in remote_branches: + return branch + # We couldn't find any of the branches we expected. + # Default to 'master', although nothing will work well. + return 'master' + + COMMIT_PATTERN = r'^[a-f0-9]{40}$' COMMIT_RE = re.compile(COMMIT_PATTERN) |