
Commit ed212d7

Bazel torchxla (#4755)
* Migrate bazel torchxla
* Remove sandboxing for coverage execution
* Add missing files
* Remove build_torch_xla_libs mentions.
* Improve cache hits
* Don't separately require configuring test building; they are free now.
* format python
* fix run_tests.sh
* Pass test arguments via bazelrc
* Merge tests into a single target due to grpc address-in-use issues.
* Make testenv consistent for cache hits
* Remove abi logic, it's all in setup.py now
* Write both to log file and to output
* Update deprecated property
* add libpython to libs
* Change test filter flag
* Comment out log file for debugging
* Minimize downloads from cache
* Migrate to new bazel flag for exec properties
* Cache silo for CI
* set python version so that python3-config is found and used on circleci
* use ci cache silos when building
* simplify the silo flag
* improve silos
* Add conda init for tests
* format py
* hide the creds
* remove conda activation
* Setup conda library path
* Try improving conda setup
* Move the setup into bashrc
* common
* revert to old cache silo flag that allows overrides
* format py
* Revert to old style of specifying remote exec params
* Add bes timeout
* remove default silos key
* Rebase on updates
* pass in ld_lib_path to tests
* Propagate XLA_EXPERIMENTAL to bazel
* Support for cuda in tests
* Pass the cuda flag to cpp tests.
* remove cuda from deps of ptxla test since it's already in xla_client, linked via xla_client:computation_client
* Fix multiconfiguration issues for tests
* Don't trim the test config; test_filter remains
* Copy the codegen directory to get the source in docker
* Add libtpu to the wheel, and link accordingly
* include build extensions; that redefines some distutils classes. python sucks.
* Update to cloud builder docker image and pass in the remote bazel flags
* Setup silo and remote cache for cloudbuild
* Set cache silo even with default creds
* fix debug flag
* Allow CXX_ABI flag to be set externally.
* Set instrumentation filter to avoid tests
* Document bazel
* User might be root often, so make sure docs are clear
* format py
* Remove gen_lazy_tensor; now under codegen/
* Update documentation
* add coverage script
* Update docs with remote bazel role in gcp
* Update bazel docs
* Enable remote cache for bazel in ansible.
* Propagate default credentials to docker
* Remove unused rpath settings
* Upstream xla native functions
* Don't make the build DEBUG just for coverage.
* Avoid waiting for bes, which can be flaky
* Remove build-only testing
* Update xla native functions yaml
* Adjust cpp coverage stuff
* Use remote build for building tests.
* Debug mode
* Allow building tests
* Pass the TPU config to bazel tests.
1 parent 74eff29 commit ed212d7

File tree: 272 files changed, +2245 -902 lines


Diff for: .bazelrc (+48, -12)

@@ -1,9 +1,6 @@
 ############################################################################
 # All default build options below.

-# Enable exceptions in C++.
-common --copt=-fexceptions
-
 # Make Bazel print out all options from rc files.
 build --announce_rc

@@ -28,9 +25,9 @@ build -c opt
 build --config=short_logs

 # Force GCC because clang/bazel has issues.
-common --action_env=CC=gcc
-common --action_env=CXX=g++
-common --spawn_strategy=standalone
+build --action_env=CC=gcc
+build --action_env=CXX=g++
+build --spawn_strategy=standalone

 ###########################################################################

@@ -63,7 +60,6 @@ build:acl --define==build_with_acl=true
 build:nonccl --define=no_nccl_support=true

 build:linux --config=posix
-build:linux --copt=-Wno-unknown-warning-option

 # Suppress all warning messages.
 build:short_logs --output_filter=DONT_MATCH_ANYTHING

@@ -75,6 +71,45 @@ build:tpu --define=with_tpu_support=true
 # RBE config options below.
 # Flag to enable remote config
 common --experimental_repo_remote_exec
+
+# Inherit environmental variables that are used in testing.
+test --test_env=TPU_NUM_DEVICES --test_env=GPU_NUM_DEVICES --test_env=CPU_NUM_DEVICES --test_env=XRT_LOCAL_WORKER
+test --test_env=XRT_TPU_CONFIG --test_env=XRT_DEVICE_MAP --test_env=XRT_WORKERS --test_env=XRT_MESH_SERVICE_ADDRESS
+test --test_env=XRT_SHARD_WORLD_SIZE --test_env=XRT_MULTI_PROCESSING_DEVICE --test_env=XRT_HOST_ORDINAL --test_env=XRT_SHARD_ORDINAL
+test --test_env=XRT_START_LOCAL_SERVER --test_env=TPUVM_MODE --test_env=PJRT_DEVICE --test_env=PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS
+test --test_env=PJRT_CPU_ASYNC_CLIENT --test_env=PJRT_GPU_ASYNC_CLIENT --test_env=TPU_LIBRARY_PATH --test_env=PJRT_DIST_SERVICE_ADDR
+test --test_env=PJRT_LOCAL_PROCESS_RANK
+
+# This environmental variable is important for properly integrating with XLA.
+test --test_env=XLA_EXPERIMENTAL
+
+# To find `libpython` that is required to run tests (they run using installed wheels).
+test --test_env=LD_LIBRARY_PATH
+
+# This fixes an issue where targets are configured differently because of `test_filter`.
+# https://github.com/bazelbuild/bazel/issues/6842
+test --notrim_test_configuration
+
+# Stabilize the environmental variables used to minimize cache misses (src and env affects cache keys).
+build --incompatible_strict_action_env
+
+# By default in local builds, do not upload local results to cache.
+build --noremote_upload_local_results
+
+# Remote caching with local builds.
+build:remote_cache --remote_cache=grpcs://remotebuildexecution.googleapis.com
+build:remote_cache --remote_instance_name=projects/tpu-pytorch/instances/default_instance
+build:remote_cache --google_default_credentials
+build:remote_cache --remote_upload_local_results
+build:remote_cache --bes_backend=buildeventservice.googleapis.com
+build:remote_cache --bes_upload_mode=fully_async
+build:remote_cache --bes_results_url="https://source.cloud.google.com/results/invocations"
+build:remote_cache --bes_instance_name="tpu-pytorch"
+build:remote_cache --bes_timeout=600s # On longer builds, BES can cause a non-zero exit from bazel.
+
+# Attempt to minimize the amount of data transfer between bazel and the remote
+# workers:
+build:remote_cache --remote_download_toplevel
 #########################################################################

 # Load rc file with user-specific options.
@@ -84,17 +119,14 @@ try-import %workspace%/.bazelrc.user
 build:compdb --features=-layering_check

 # Compiling tests requires Java.
-common --java_runtime_version=remotejdk_11
+build --java_runtime_version=remotejdk_11

 # Coverage requires Java and GCC.
 coverage --config=coverage
 coverage --build_tests_only
-build:coverage --copt=-DNDEBUG
+coverage --instrumentation_filter="//torch_xla/,//third_party/"
 build:coverage --combined_report=lcov
-build:coverage --strategy=TestRunner=sandboxed,local
 build:coverage --strategy=CoverageReport=sandboxed,local
-build:coverage --experimental_use_llvm_covmap
-build:coverage --collect_code_coverage
 build:coverage --test_tag_filters=-nocoverage

 ############################################################################
@@ -175,3 +207,7 @@ build:linux --copt="-Wswitch"
 build:linux --copt="-Werror=switch"
 # Required for building with clang
 build:linux --copt="-Wno-error=unused-but-set-variable"
+
+# Only include debug info for files in this repository, excluding external deps.
+build:dbg -c dbg
+build:dbg --per_file_copt=+external/.*@-g0,-DNDEBUG
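
For orientation, the configs above can be exercised directly; a minimal sketch of local usage (illustrative, not part of the commit; assumes gcloud application-default credentials are already set up):

  # Build the extension through the shared remote cache defined above.
  bazel build --config=remote_cache //:_XLAC.so
  # Debug build of the C++ tests; PJRT_DEVICE is forwarded into the test via --test_env.
  PJRT_DEVICE=CPU bazel test --config=dbg //test/cpp:all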

Diff for: .circleci/build.sh (+2)

@@ -46,6 +46,8 @@ python setup.py install
 sccache --show-stats

 source $XLA_DIR/xla_env
+export GCLOUD_SERVICE_KEY_FILE="$XLA_DIR/default_credentials.json"
+export SILO_NAME='cache-silo-ci' # cache bucket for CI
 build_torch_xla $XLA_DIR

 popd
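
Outside CI, the same two variables can point at personal credentials and a private silo; a hypothetical local equivalent (both values below are placeholders, not from the commit):

  export GCLOUD_SERVICE_KEY_FILE="$HOME/gcloud_credentials.json" # placeholder path
  export SILO_NAME='cache-silo-<username>' # placeholder silo name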

Diff for: .circleci/common.sh (+14, -23)

@@ -14,6 +14,8 @@ fi
 # 2. CONDA_PREFIX (if it exists)
 # 3. The conda install directory (if it exists)
 export CMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH:-${CONDA_PREFIX:-"$(dirname $(which conda))/../"}}
+export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$(python3-config --prefix)/lib"
+echo $LD_LIBRARY_PATH

 function clone_pytorch() {
   PYTORCH_DIR=$1
@@ -95,21 +97,6 @@ function install_deps_pytorch_xla() {
   if ls $CUBLAS_PATTERN 1> /dev/null 2>&1; then
     sudo ln -s $CUBLAS_PATTERN /usr/local/cuda/include
   fi
-
-  # Use cloud cache to build when available.
-  if [[ "$USE_CACHE" == 1 ]]; then
-    # Install bazels3cache for cloud cache
-    sudo npm install -g n
-    sudo n 16.18.0
-    sudo npm install -g bazels3cache
-    BAZELS3CACHE="$(which /usr/local/bin/bazels3cache)"
-    if [ -z "${BAZELS3CACHE}" ]; then
-      echo "Unable to find bazels3cache..."
-      return 1
-    fi
-    /usr/local/bin/bazels3cache --bucket=${XLA_CLANG_CACHE_S3_BUCKET_NAME} --maxEntrySizeBytes=0 --logging.level=verbose
-    sed -i '/bazel build/ a --remote_http_cache=http://localhost:7777 \\' $XLA_DIR/build_torch_xla_libs.sh
-  fi
 }

 function build_torch_xla() {
@@ -172,18 +159,22 @@ function run_torch_xla_tests() {

   pushd test/cpp
   echo "Running C++ Tests on PJRT"
+  EXTRA_ARGS=""
+  if [ "$USE_COVERAGE" != "0" ]; then
+    EXTRA_ARGS="-C"
+  fi
+  if [ ! -z "$GCLOUD_SERVICE_KEY_FILE" ]; then
+    EXTRA_ARGS="-R"
+  fi
   if [ -x "$(command -v nvidia-smi)" ]; then
-    PJRT_DEVICE=GPU ./run_tests.sh
-    PJRT_DEVICE=GPU ./run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L""
+    PJRT_DEVICE=GPU ./run_tests.sh $EXTRA_ARGS
+    PJRT_DEVICE=GPU ./run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS
   else
-    PJRT_DEVICE=CPU ./run_tests.sh
+    PJRT_DEVICE=CPU ./run_tests.sh $EXTRA_ARGS
   fi
   if [ "$USE_COVERAGE" != "0" ]; then
-    export PATH=$PATH:/usr/lib/llvm-8/bin
-    chmod +x /tmp/pytorch/xla/test/cpp/get_coverage.sh
-    lcov --directory /tmp/pytorch/xla/build/temp.linux-x86_64-cpython-38/torch_xla/csrc --base-directory . --gcov-tool /tmp/pytorch/xla/test/cpp/get_coverage.sh --capture -o cpp_lcov.info
-    genhtml cpp_lcov.info -o ~/htmlcov//cpp/cpp_lcov.info
-    mv cpp_lcov.info ~/htmlcov/cpp_lcov.info
+    genhtml .bazel-out/_coverage/_coverage_report.dat -o ~/htmlcov/cpp/cpp_lcov.info
+    mv ./.bazel-out/_coverage/_coverage_report.dat ~/htmlcov/cpp_lcov.info
   fi
   popd
   popd
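
Taken together, the conditionals above mean the C++ suite is invoked roughly as follows (a sketch; per the logic above, -C is set for coverage runs and is then overwritten by -R whenever GCLOUD_SERVICE_KEY_FILE is present):

  # USE_COVERAGE=1, no service key:
  PJRT_DEVICE=CPU ./run_tests.sh -C
  # service key present (EXTRA_ARGS is overwritten, not appended):
  PJRT_DEVICE=CPU ./run_tests.sh -R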

Diff for: .circleci/config.yml (+3, -2)

@@ -47,8 +47,9 @@ launch_docker_and_build: &launch_docker_and_build
 echo "declare -x CIRCLE_PROJECT_USERNAME=${CIRCLE_PROJECT_USERNAME}" >> /home/circleci/project/env
 echo "declare -x CIRCLE_PROJECT_REPONAME=${CIRCLE_PROJECT_REPONAME}" >> /home/circleci/project/env

-# Set debug so that xla builds with coverage symbols
-echo "declare -x DEBUG=1" >> /home/circleci/project/env
+# Set up remote cache/build authentication.
+echo "declare -x BAZEL_REMOTE_CACHE=1" >> /home/circleci/project/xla_env
+(set +x; echo $GCLOUD_SERVICE_KEY > /home/circleci/project/default_credentials.json; set -x)

 pid=$(docker run -t -d -w $WORKDIR ${GCR_DOCKER_IMAGE})
 docker cp /home/circleci/project/. "$pid:$WORKDIR"

Diff for: .circleci/docker/Dockerfile (+2, -7)

@@ -34,11 +34,6 @@ ENV CUDA_PATH /usr/local/cuda
 ENV CC "${cc}"
 ENV CXX "${cxx}"

-# Whether to build torch and torch_xla libraries with CXX ABI
-ENV _GLIBCXX_USE_CXX11_ABI "${cxx_abi}"
-ENV CFLAGS "${CFLAGS} -D_GLIBCXX_USE_CXX11_ABI=${cxx_abi}"
-ENV CXXFLAGS "${CXXFLAGS} -D_GLIBCXX_USE_CXX11_ABI=${cxx_abi}"
-
 # Whether to build for TPUVM mode
 ENV TPUVM_MODE "${tpuvm}"

@@ -49,8 +44,8 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/
 # Install base system packages
 RUN apt-get clean && apt-get update
 RUN apt-get upgrade -y
-RUN apt-get install --fix-missing -y python3-pip git curl libopenblas-dev vim jq \
-    apt-transport-https ca-certificates procps openssl sudo wget libssl-dev libc6-dbg
+RUN apt-get install --fix-missing -y python-pip python3-pip git curl libopenblas-dev vim jq \
+    apt-transport-https ca-certificates procps openssl sudo wget libssl-dev libc6-dbg

 # Install clang & llvm
 ADD ./install_llvm_clang.sh install_llvm_clang.sh

Diff for: .circleci/test.sh (+2)

@@ -25,4 +25,6 @@ function install_torchvision() {

 install_torchvision

+export GCLOUD_SERVICE_KEY_FILE="$XLA_DIR/default_credentials.json"
+export SILO_NAME='cache-silo-ci' # cache bucket for CI
 run_torch_xla_tests $PYTORCH_DIR $XLA_DIR $USE_COVERAGE

Diff for: .gitignore (+3, -3)

@@ -10,9 +10,6 @@ torch_xla/csrc/version.cpp
 *.pyc
 *.so

-# Directory autogenerated by full_codegen
-torch_xla/csrc/generated/
-
 # BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.)
 #
 # Below files are not deleted by "setup.py clean".
@@ -30,3 +27,6 @@ torch_xla/csrc/generated/

 # Build system temporary files
 /bazel-*
+
+# Clangd cache directory
+.cache/*

Diff for: .vscode/settings.json (+5, -3)

@@ -3,7 +3,9 @@
     "--config=compdb",
   ],
   "bsv.cc.compdb.targets": [
-    "//third_party/xla_client/...",
+    "//third_party/xla_client:all",
+    "//torch_xla/csrc:all",
+    "//test/cpp:all",
   ],
   "coverage-gutters.coverageBaseDir": ".",
   "coverage-gutters.showLineCoverage": false,
@@ -13,8 +15,8 @@
     "./bazel-out/_coverage/_coverage_report.dat"
   ],
   "lcov.path": [
-    "./.bazel-out/_coverage/_coverage_report.dat"
+    "./bazel-out/_coverage/_coverage_report.dat"
  ],
   "python.formatting.provider": "yapf",
   "editor.formatOnSave": true
-}
\ No newline at end of file
+}

Diff for: BUILD (+57)

@@ -0,0 +1,57 @@
+load(
+    "@org_tensorflow//tensorflow:tensorflow.bzl",
+    "tf_cc_shared_object",
+)
+
+tf_cc_shared_object(
+    name = "_XLAC.so",
+    copts = [
+        "-DTORCH_API_INCLUDE_EXTENSION_H",
+        "-DTORCH_EXTENSION_NAME=_XLAC",
+        "-fopenmp",
+        "-fPIC",
+        "-fwrapv",
+    ],
+    linkopts = [
+        "-Wl,-rpath,$$ORIGIN/torch_xla/lib",  # for libtpu
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//third_party/xla_client:computation_client",
+        "//third_party/xla_client:mesh_service",
+        "//third_party/xla_client:metrics",
+        "//third_party/xla_client:metrics_analysis",
+        "//third_party/xla_client:metrics_reader",
+        "//third_party/xla_client:multi_wait",
+        "//third_party/xla_client:profiler",
+        "//third_party/xla_client:record_reader",
+        "//third_party/xla_client:sys_util",
+        "//third_party/xla_client:thread_pool",
+        "//third_party/xla_client:util",
+        "//third_party/xla_client:xla_util",
+        "//torch_xla/csrc:computation",
+        "//torch_xla/csrc:device",
+        "//torch_xla/csrc:init_python_bindings",
+        "//torch_xla/csrc:tensor",
+        "//torch_xla/csrc:version",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:variant",
+        "@org_tensorflow//tensorflow/compiler/xla/python/profiler/internal:traceme_wrapper",
+        "@org_tensorflow//tensorflow/compiler/xla/service:hlo_parser",
+        "@org_tensorflow//tensorflow/compiler/xla/service:hlo_pass_pipeline",
+        "@org_tensorflow//tensorflow/compiler/xla/service:hlo_verifier",
+        "@org_tensorflow//tensorflow/compiler/xla/service:sharding_propagation",
+        "@org_tensorflow//tensorflow/compiler/xla/service/spmd:spmd_partitioner",
+        "@org_tensorflow//tensorflow/core",
+        "@org_tensorflow//tensorflow/core:protos_all_cc",
+        "@org_tensorflow//tensorflow/core/platform:env",
+        "@org_tensorflow//tensorflow/core/profiler/lib:traceme",
+        "@org_tensorflow//tensorflow/python/profiler/internal:profiler_pywrap_impl",
+        "@torch//:headers",
+        "@torch//:libc10",
+        "@torch//:libtorch",
+        "@torch//:libtorch_cpu",
+        "@torch//:libtorch_python",
+    ],
+)
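
The new top-level target can be built on its own; a minimal sketch, assuming the @torch repository paths are configured via bazel/dependencies.bzl (see the WORKSPACE diff below). In the normal flow, setup.py drives this build:

  bazel build //:_XLAC.so # produces the _XLAC Python extension shared object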

Diff for: CODEGEN_MIGRATION_GUIDE.md (+6, -6)

@@ -30,14 +30,14 @@ All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the e
   - Contains all the op XLA supported today. Most of the ops are under the supported category, the goal of this document is to move most of the ops to the full_codegen category.
 - xla/scripts/gen_lazy_tensor.py
   - Provides necessary XLA versions of the codegen Codegen class and calls the upstream codegen API.
-- xla/torch_xla/csrc/generated/XLANativeFunctions.cpp
-  - Result of the full_codegen column of the xla/xla_native_functions.yaml. The op function defined here will implement the op declared in the XLANativeFunctions.h. Each op will take at::tensor and return another at::tensor wrapped around a XLATensor.
-- xla/torch_xla/csrc/generated/LazyIr.h
-  - Result of the full_codegen column of the xla/xla_native_functions.yaml. Defines the IR that is used to construct the full_codegen ops.
+- xla/torch_xla/csrc/XLANativeFunctions.cpp
+  - Result of the full_codegen column of the xla/codegen/xla_native_functions.yaml. The op function defined here will implement the op declared in the XLANativeFunctions.h. Each op will take at::tensor and return another at::tensor wrapped around a XLATensor.
+- xla/torch_xla/csrc/LazyIr.h
+  - Result of the full_codegen column of the xla/codegen/xla_native_functions.yaml. Defines the IR that is used to construct the full_codegen ops.

 ### PyTorch/XLA Old Op Lowering files
 - xla/torch_xla/csrc/generated/aten_xla_type.cpp
-  - Manually implements ops defined in xla/xla_native_functions.yaml. Will be replaced by XLANativeFunctions.cpp
+  - Manually implements ops defined in xla/codegen/xla_native_functions.yaml. Will be replaced by XLANativeFunctions.cpp
 - xla/torch_xla/csrc/generated/tensor.h
   - Defines XLATensor class and XLATensor method declarations. These declarations are usually a one to one mapping of the at::Tensor nodes we declared in XLANativeFunctions.h. XLATensor method will be removed for full_codegen ops
 - xla/torch_xla/csrc/generated/tensor_method.cpp
@@ -76,7 +76,7 @@ at::Tensor XLANativeFunctions::abs(const at::Tensor& self) {
 ```

 ### 2. Codegen the op and inspect the generated file
-Find the op in `xla/xla_native_functions.yaml` and move it to the full_codegen column and run `python setup.py install` under xla directory again. The build will fail (reason explained later in this guide) but you can still see the generated file. The code snippets below uses `abs` as an example.
+Find the op in `xla/codegen/xla_native_functions.yaml` and move it to the full_codegen column and run `python setup.py install` under xla directory again. The build will fail (reason explained later in this guide) but you can still see the generated file. The code snippets below uses `abs` as an example.
 #### XLANativeFunctions.cpp
 ```
 at::Tensor XLANativeFunctions::abs(const at::Tensor & self) {
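
As shell steps, the migration flow the updated guide describes looks roughly like this (illustrative; abs is the guide's own example op):

  # 1. Move the op into the full_codegen column.
  $EDITOR xla/codegen/xla_native_functions.yaml
  # 2. Re-run codegen; the guide notes the build is expected to fail at this stage.
  python setup.py install
  # 3. Inspect the generated files mentioned above.
  less torch_xla/csrc/XLANativeFunctions.cpp torch_xla/csrc/LazyIr.h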

Diff for: OP_LOWERING_GUIDE.md (+1, -1)

@@ -14,7 +14,7 @@ export PJRT_DEVICE=CPU
 You can find the definition of the C++ ATen operations in [native_functions.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml). After you build Pytorch/XLA from source, you will also find our default implementation (a boxed kernel which forwards calls to PyTorch native CPU) in `xla/torch_xla/csrc/aten_cpu_fallback.h/cpp`. Pytorch operations can usually be mapped to [PyTorch tensor api](https://pytorch.org/docs/stable/index.html) easily. If that is not the case searching the PyTorch native implementation under [PyTorch repo](https://github.com/pytorch/pytorch) is recommended. The goal is to lower the PyTorch operations into a sequence of XLA operations defined in [here](https://www.tensorflow.org/xla/operation_semantics).

 ## File structure
-All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the exception of `xla_native_functions.yaml`
+All file mentioned below lives under the `xla/torch_xla/csrc` folder, with the exception of `codegen/xla_native_functions.yaml`

 1. `xla_native_functions.yaml` contains the list of all operators that are lowered. Each operator name must directly match a pytorch operator listed in [native_functions.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml). This file serves as the interface to adding new xla operators, and is an input to PyTorch's [codegen machinery](https://github.com/pytorch/pytorch/blob/main/torchgen/gen_backend_stubs.py). It generates the below 3 files: `XLANativeFunctions.h`, `RegisterXLA.cpp`, and `RegisterAutogradXLA.cpp`
 2. `XLANativeFunctions.h` and `aten_xla_type.cpp` are entry points of PyTorch to the pytorch_xla world, and contain the manually written lowerings to XLA for each operator. `XLANativeFunctions.h` is auto-generated through a combination of `xla_native_functions.yaml` and the PyTorch core `native_functions.yaml` file, and contains declarations for kernels that need to be defined in `aten_xla_type.cpp`. The kernels written here need to construct 'XLATensor' using the input `at::Tensor` and other parameters. The resulting `XLATensor` needs to be converted back to the `at::Tensor` before returning to the PyTorch world.

Diff for: WORKSPACE (+35)

@@ -1,5 +1,30 @@
 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

+################################ Python Setup ################################
+
+# For embedded python interpreter (libpython.so.)
+http_archive(
+    name = "pybind11_bazel",
+    strip_prefix = "pybind11_bazel-fc56ce8a8b51e3dd941139d329b63ccfea1d304b",
+    urls = ["https://github.com/pybind/pybind11_bazel/archive/fc56ce8a8b51e3dd941139d329b63ccfea1d304b.zip"],
+)
+
+http_archive(
+    name = "pybind11",
+    build_file = "@pybind11_bazel//:pybind11.BUILD",
+    strip_prefix = "pybind11-442261da585536521ff459b1457b2904895f23b4",
+    urls = ["https://github.com/pybind/pybind11/archive/442261da585536521ff459b1457b2904895f23b4.tar.gz"],
+)
+
+load("@pybind11_bazel//:python_configure.bzl", "python_configure")
+
+# This is required for setting up the linkopts for -lpython.
+python_configure(
+    name = "local_config_python",
+    python_version = "3",  # required to use `python3-config`
+)
+
+############################# TensorFlow Setup ###############################
 # To update TensorFlow to a new revision,
 # a) update URL and strip_prefix to the new git commit hash
 # b) get the sha256 hash of the commit by running:
@@ -60,3 +85,13 @@ tf_workspace1()
 load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")

 tf_workspace0()
+
+################################ PyTorch Setup ################################
+
+load("//bazel:dependencies.bzl", "PYTORCH_LOCAL_DIR")
+
+new_local_repository(
+    name = "torch",
+    build_file = "//bazel:torch.BUILD",
+    path = PYTORCH_LOCAL_DIR,
+)
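
The local_config_python repository is resolved through python3-config, which is also why .circleci/common.sh above extends LD_LIBRARY_PATH; a quick illustrative check of what it will pick up:

  python3-config --prefix # prefix used to locate headers and libpython
  ls "$(python3-config --prefix)/lib" | grep libpython # the libpython the tests load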
