diff --git a/synthtool/gcp/common.py b/synthtool/gcp/common.py index ff102a484..4f4532031 100644 --- a/synthtool/gcp/common.py +++ b/synthtool/gcp/common.py @@ -369,6 +369,17 @@ def _load_repo_metadata(metadata_file: str = "./.repo-metadata.json") -> Dict: def _get_default_branch_name(repository_name: str) -> str: + """Read the default branch name from the environment. + + First checks environment variable DEFAULT_BRANCH_PATH. If found, it + reads the contents of the file at DEFAULT_BRANCH_PATH and returns it. + + Then checks environment varabile DEFAULT_BRANCH, and returns it if found. + """ + default_branch_path = os.getenv("DEFAULT_BRANCH_PATH") + if default_branch_path: + return Path(default_branch_path).read_text().strip() + # This default should be switched to "main" once we've migrated # the majority of our repositories: return os.getenv("DEFAULT_BRANCH", "master") diff --git a/tests/test_common.py b/tests/test_common.py index 0e75185ca..446ec7c30 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synthtool.gcp.common import decamelize, _get_default_branch_name +import os +import tempfile from pathlib import Path +from unittest import mock + from pytest import raises -import os + import synthtool as s +from synthtool.gcp.common import _get_default_branch_name, decamelize + from . import util -from unittest import mock MOCK = Path(__file__).parent / "generationmock" template_dir = Path(__file__).parent.parent / "synthtool/gcp/templates" @@ -42,8 +46,17 @@ def test_handles_empty_string(): def test_get_default_branch(): - with mock.patch.dict(os.environ, {"DEFAULT_BRANCH": "main"}): - assert _get_default_branch_name("repo_name") == "main" + with mock.patch.dict(os.environ, {"DEFAULT_BRANCH": "chickens"}): + assert _get_default_branch_name("repo_name") == "chickens" + + +def test_get_default_branch_path(): + f = tempfile.NamedTemporaryFile("wt", delete=False) + fname = f.name + f.write("ducks\n") + f.close() + with mock.patch.dict(os.environ, {"DEFAULT_BRANCH_PATH": fname}): + assert _get_default_branch_name("repo_name") == "ducks" def test_py_samples_clientlib():