|
| 1 | +"""Repository rule to setup a torch repo.""" |
| 2 | + |
| 3 | +_BUILD_TEMPLATE = """ |
| 4 | +
|
| 5 | +load("@//bazel:torch_targets.bzl", "define_torch_targets") |
| 6 | +
|
| 7 | +package( |
| 8 | + default_visibility = [ |
| 9 | + "//visibility:public", |
| 10 | + ], |
| 11 | +) |
| 12 | +
|
| 13 | +define_torch_targets() |
| 14 | +""" |
| 15 | + |
| 16 | +def _get_url_basename(url): |
| 17 | + basename = url.rpartition("/")[2] |
| 18 | + |
| 19 | + # Starlark doesn't have any URL decode functions, so just approximate |
| 20 | + # one with the cases we see. |
| 21 | + return basename.replace("%2B", "+") |
| 22 | + |
| 23 | +def _torch_repo_impl(rctx): |
| 24 | + rctx.file("BUILD.bazel", _BUILD_TEMPLATE) |
| 25 | + |
| 26 | + env_torch_whl = rctx.os.environ.get("TORCH_WHL", "") |
| 27 | + |
| 28 | + urls = None |
| 29 | + local_path = None |
| 30 | + if env_torch_whl: |
| 31 | + if env_torch_whl.startswith("http"): |
| 32 | + urls = [env_torch_whl] |
| 33 | + else: |
| 34 | + local_path = rctx.path(env_torch_whl) |
| 35 | + else: |
| 36 | + root_workspace = rctx.path(Label("@@//:WORKSPACE")).dirname |
| 37 | + dist_dir = rctx.workspace_root.get_child(rctx.attr.dist_dir) |
| 38 | + |
| 39 | + if dist_dir.exists: |
| 40 | + for child in dist_dir.readdir(): |
| 41 | + # For lack of a better option, take the first match |
| 42 | + if child.basename.endswith(".whl"): |
| 43 | + local_path = child |
| 44 | + break |
| 45 | + |
| 46 | + if not local_path and not urls: |
| 47 | + fail(( |
| 48 | + "No torch wheel source configured:\n" + |
| 49 | + "* Set TORCH_WHL environment variable to a local path or URL.\n" + |
| 50 | + "* Or ensure the {dist_dir} directory is present with a torch wheel." + |
| 51 | + "\n" |
| 52 | + ).format( |
| 53 | + dist_dir = dist_dir, |
| 54 | + )) |
| 55 | + |
| 56 | + if local_path: |
| 57 | + whl_path = local_path |
| 58 | + if not whl_path.exists: |
| 59 | + fail("File not found: {}".format(whl_path)) |
| 60 | + |
| 61 | + # The dist/ directory is necessary for XLA's python_init_repositories |
| 62 | + # to discover the wheel and add it to requirements.txt |
| 63 | + rctx.symlink(whl_path, "dist/{}".format(whl_path.basename)) |
| 64 | + elif urls: |
| 65 | + whl_basename = _get_url_basename(urls[0]) |
| 66 | + |
| 67 | + # The dist/ directory is necessary for XLA's python_init_repositories |
| 68 | + # to discover the wheel and add it to requirements.txt |
| 69 | + whl_path = rctx.path("dist/{}".format(whl_basename)) |
| 70 | + result = rctx.download( |
| 71 | + url = urls, |
| 72 | + output = whl_path, |
| 73 | + ) |
| 74 | + if not result.success: |
| 75 | + fail("Failed to download: {}", urls) |
| 76 | + |
| 77 | + # Extract into the repo root. Also use .zip as the extension so that extract |
| 78 | + # recognizes the file type. |
| 79 | + # Use the whl basename so progress messages are more informative. |
| 80 | + whl_zip = whl_path.basename.replace(".whl", ".zip") |
| 81 | + rctx.symlink(whl_path, whl_zip) |
| 82 | + rctx.extract(whl_zip) |
| 83 | + rctx.delete(whl_zip) |
| 84 | + |
| 85 | +torch_repo = repository_rule( |
| 86 | + implementation = _torch_repo_impl, |
| 87 | + doc = """ |
| 88 | +Creates a repository with torch headers, shared libraries, and wheel |
| 89 | +for integration with Bazel. |
| 90 | +""", |
| 91 | + attrs = { |
| 92 | + "dist_dir": attr.string( |
| 93 | + doc = "Directory with a prebuilt torch wheel. Typically points to " + |
| 94 | + "a source checkout that built a torch wheel.", |
| 95 | + ), |
| 96 | + }, |
| 97 | + environ = ["TORCH_WHL"], |
| 98 | +) |
0 commit comments