
Commit f53e1ea

Revert "Update PyTorch and XLA pin. (#9668)"

This reverts commit 11590c1.

1 parent d291621, commit f53e1ea
18 files changed: +82, -1390 lines

.github/workflows/_tpu_ci.yml

Lines changed: 13 additions & 5 deletions
@@ -51,14 +51,10 @@ jobs:
           pip install fsspec
           pip install rich
 
-          # Test dependencies
-          pip install --upgrade protobuf
-          pip install flax
-
           # PyTorch/XLA Optional Dependencies
           # =================================
           #
-          # Install `jax` and `libtpu` dependencies for pallas and TPU tests.
+          # Install `JAX` and `libtpu` dependencies for pallas and TPU tests.
           #
           # Note that we might need to install pre-release versions of both, in
           # external artifact repositories.
@@ -74,6 +70,18 @@ jobs:
           pip install "$WHL[pallas]" --pre --index-url $INDEX --find-links $LINKS
           pip install "$WHL[tpu]" --pre --index-url $INDEX --find-links $LINKS
 
+          pip install --upgrade protobuf
+
+          # Flax Pin
+          # ========
+          #
+          # Be careful when bumping the `flax` version, since it can cause tests that
+          # depend on `jax` to start breaking.
+          #
+          # Newer `flax` versions might pull newer `jax` versions, which might be incompatible
+          # with the current version of PyTorch/XLA.
+          pip install flax==0.11.2
+
       - name: Run Tests (${{ matrix.test_script }})
         if: inputs.has_code_changes == 'true'
         env:
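
A note on the Flax pin above: a floating `flax` requirement can pull in a `jax` release that the current PyTorch/XLA wheel was not built or tested against, which is why the workflow now installs `flax==0.11.2` after the jax/libtpu wheels. As a hedged illustration (not part of this workflow), a post-install step could assert that pip resolved the expected pair; the version prefixes below are simply the ones this commit pins:

# Illustrative post-install sanity check (not part of this commit): fail fast
# if pip resolved a jax/flax pair other than the one the CI pins expect.
from importlib.metadata import version

EXPECTED_PREFIXES = {"flax": "0.11.2", "jax": "0.7.1"}  # taken from this commit's pins

def check_pins() -> None:
  for pkg, want in EXPECTED_PREFIXES.items():
    got = version(pkg)
    if not got.startswith(want):
      raise RuntimeError(f"{pkg}=={got} installed, expected {want}*")

if __name__ == "__main__":
  check_pins()
  print("flax/jax pins look consistent")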

.torch_commit

Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-# 2025-09-29
-21fec65781bebe867faf209f89bb687ffd236ca4
+# 2025-09-17
+928ac57c2ab03f9f79376f9995553eea2e6f4ca8

BUILD

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
+load("@python//:defs.bzl", "compile_pip_requirements")
 load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
-load("@rules_python//python:pip.bzl", "compile_pip_requirements")
 
 compile_pip_requirements(
     name = "requirements",

WORKSPACE

Lines changed: 1 addition & 2 deletions
@@ -52,7 +52,7 @@ new_local_repository(
 
 # To build PyTorch/XLA with a new revison of OpenXLA, update the xla_hash to
 # the openxla git commit hash and note the date of the commit.
-xla_hash = '9a9aa0e11e4fcda8d6a9c3267dca6776ddbdb0ca'  # Committed on 2025-10-01.
+xla_hash = '92f7b5952dd585c5be17c9a5caad27407005b513'  # Committed on 2025-08-15.
 
 http_archive(
     name = "xla",
@@ -63,7 +63,6 @@ http_archive(
     patch_tool = "patch",
     patches = [
         "//openxla_patches:no_fortify.diff",
-        "//openxla_patches:if_constexpr_static_assert.diff",
     ],
     strip_prefix = "xla-" + xla_hash,
     urls = [

openxla_patches/if_constexpr_static_assert.diff

Lines changed: 0 additions & 40 deletions
This file was deleted.

setup.py

Lines changed: 5 additions & 5 deletions
@@ -112,12 +112,12 @@
 
 USE_NIGHTLY = True  # Whether to use nightly or stable libtpu and JAX.
 
-_libtpu_version = '0.0.24'
-_libtpu_date = '20250929'
+_libtpu_version = '0.0.21'
+_libtpu_date = '20250813'
 
-_jax_version = '0.8.0'
-_jaxlib_version = '0.8.0'
-_jax_date = '20251001'  # Date for jax and jaxlib.
+_jax_version = '0.7.1'
+_jaxlib_version = '0.7.1'
+_jax_date = '20250813'  # Date for jax and jaxlib.
 
 _torchax_version = '0.0.7'  # likely stay the same
 
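
`USE_NIGHTLY`, together with the version/date pins above, controls whether the optional `tpu` and `pallas` extras resolve to stable releases or to dated pre-release builds of `libtpu` and JAX. The sketch below only illustrates that switch; the real requirement-string format in setup.py is not shown in this diff, so the `.dev{date}` suffix and the helper name `_pin` are assumptions:

# Illustration only: how a nightly/stable switch over these pins might be wired.
# The ".dev{date}" pre-release suffix is an assumed format, not copied from setup.py.
USE_NIGHTLY = True

_libtpu_version, _libtpu_date = '0.0.21', '20250813'
_jax_version, _jax_date = '0.7.1', '20250813'

def _pin(name: str, ver: str, date: str) -> str:
  return f"{name}=={ver}.dev{date}" if USE_NIGHTLY else f"{name}=={ver}"

tpu_extras = [_pin('libtpu', _libtpu_version, _libtpu_date)]
pallas_extras = [_pin('jax', _jax_version, _jax_date)]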

test/spmd/test_fsdp_v2.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def test_fsdp_v2_basic(self):
     # Make sure optimization barrier is applied.
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        'opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.2',
+        'opt-barrier.38 = (f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[16,64]{1,0}) %tuple.37',
         hlo)
 
     # Make sure the model can execute without error.
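
The only change in this test is the instruction numbering (`opt-barrier.1`/`%tuple.2` becomes `opt-barrier.38`/`%tuple.37`): the ids shift whenever the pinned compiler emits a different number of instructions before the barrier. A hedged alternative, not part of this commit, would be to match the barrier by its operand shapes instead of by id:

# Not part of this commit: an id-agnostic assertion sketch that survives
# instruction renumbering across PyTorch/XLA pin updates.
import re

OPT_BARRIER_RE = re.compile(
    r'opt-barrier\.\d+ = \(f32\[1,64\]\{0,1\}, f32\[1\]\{0\}, f32\[16,64\]\{1,0\}\) '
    r'opt-barrier\(')

def assert_opt_barrier_applied(test_case, hlo: str):
  test_case.assertRegex(hlo, OPT_BARRIER_RE)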

test/spmd/test_xla_sharding.py

Lines changed: 13 additions & 14 deletions
@@ -638,7 +638,7 @@ def test_inplace_add_with_sharding(self):
     self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xt])
     self.assertIn(
-        '%custom-call.1 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.1), custom_call_target="Sharding", sharding=',
+        '%custom-call.7 = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %add.6), custom_call_target="Sharding", sharding=',
         hlo)
 
     # avoid calling xr.addressable_device_count here otherwise it will init the test
@@ -738,8 +738,7 @@ def test_xla_sharded_hlo_dump(self):
         partition_spec)
     xst2 = xst1 + 5
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xst2.global_tensor])
-    print(hlo)
-    self.assertIn('%p1.1 = f32[1,8]{1,0} parameter(1), sharding', hlo)
+    self.assertIn('%p1.3 = f32[1,8]{1,0} parameter(1), sharding', hlo)
     if torch_xla._XLAC._xla_get_auto_sharding():
       # scalar 5 should be implicitly replicated, so the pre-optimization HLO
       # shouldn't mark it with sharding.
@@ -854,13 +853,13 @@ def test_mark_sharding_ir(self):
         (0, 1))
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%custom-call.1 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.1), custom_call_target="Sharding", sharding=',
+        '%custom-call.7 = f32[1,128]{1,0} custom-call(f32[1,128]{1,0} %add.6), custom_call_target="Sharding", sharding=',
         hlo)
 
     actual += 0
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([actual.global_tensor])
     self.assertIn(
-        '%add.3 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.1, f32[1,128]{1,0} %broadcast.3)',
+        '%add.12 = f32[1,128]{1,0} add(f32[1,128]{1,0} %custom-call.9, f32[1,128]{1,0} %broadcast.11)',
         hlo)
 
     self.assertTrue(torch.allclose(expected, actual.cpu()))
@@ -1169,7 +1168,7 @@ def test_backward_optimization_barrier(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
     self.assertIn(
-        '%opt-barrier.1 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.2)',
+        '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
         hlo)
 
   def test_mark_shard_scalar(self):
@@ -1226,7 +1225,7 @@ def test_spmd_full_to_shard_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 8 // self.n_devices))
-    self.assertIn(f'%custom-call.1 = f32[8,{8//self.n_devices}]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.2 = f32[8,{8//self.n_devices}]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1243,7 +1242,7 @@ def test_spmd_full_to_shard_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, (8, 4))
-    self.assertIn(f'%custom-call.1 = f32[8,4]{{1,0}}', hlo)
+    self.assertIn(f'%custom-call.2 = f32[8,4]{{1,0}}', hlo)
     self.assertIn(
         f'custom_call_target="SPMDFullToShardShape", sharding={{manual}}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{manual}")
@@ -1274,7 +1273,7 @@ def test_spmd_shard_to_full_shape(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([xx])
     self.assertEqual(xx.shape, x.shape)
-    self.assertIn('%custom-call.5 = f32[8,8]{1,0}', hlo)
+    self.assertIn('%custom-call.9 = f32[8,8]{1,0}', hlo)
     self.assertIn(
         'custom_call_target="SPMDShardToFullShape", sharding={replicated}', hlo)
     self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xx), "{replicated}")
@@ -1325,7 +1324,7 @@ def test_spmd_reduce_scatter(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.1",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8 // self.n_devices, 8) * self.n_devices
@@ -1346,7 +1345,7 @@ def test_spmd_reduce_scatter_canonical_index(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.1",
+        f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8, 8 // self.n_devices) * self.n_devices
@@ -1366,7 +1365,7 @@ def test_spmd_all_reduce(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8, 8) * self.n_devices
@@ -1387,7 +1386,7 @@ def test_spmd_all_reduce_scale(self):
 
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([x])
     self.assertIn(
-        f"all-reduce(f32[8,8]{{1,0}} %custom-call.3), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.1",
+        f"all-reduce(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, to_apply=%AddComputation.3",
         hlo)
 
     expected_x = torch.ones(8, 8) * int(self.n_devices * scale)
@@ -1741,7 +1740,7 @@ def test_annotate_custom_sharding(self):
         f'%p0.1 = f32[2,4,64,64]{{3,2,1,0}} parameter(0), sharding={original_sharding_spec}',
         hlo)
     self.assertIn(
-        f'%custom-call.1 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
+        f'%custom-call.2 = f32[2,4,64,64]{{3,2,1,0}} custom-call(f32[2,4,64,64]{{3,2,1,0}} %p0.1), custom_call_target="Sharding", sharding={custom_sharding_spec}',
         hlo)
     xm.mark_step()
     # Ensure that the resulting sharding spec is preserved
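
Every expected string in this file embeds compiler-assigned instruction ids (`%custom-call.N`, `%add.N`, `to_apply=%AddComputation.N`), so reverting the pins renumbers them wholesale. A small sketch of how the new ids can be regenerated when updating these assertions; it reuses the `_get_xla_tensors_hlo` debug call already used by the tests, and the tensor shown is an arbitrary example:

# Sketch for regenerating expected HLO snippets after a pin change.
# Uses the same debug API the tests call; the tensor below is only an example.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.ones(2, 2).to(xm.xla_device())
t += 1  # build some pending lazy IR
hlo = torch_xla._XLAC._get_xla_tensors_hlo([t])
print(hlo)  # copy the renumbered '%custom-call.N' / '%add.N' lines into the asserts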

torch_xla/csrc/lowering_context.h

Lines changed: 2 additions & 2 deletions
@@ -124,8 +124,8 @@ class LoweringContext : public torch::lazy::LoweringContext {
   };
 
   // Reports an XLA builder error for the given node.
-  ABSL_ATTRIBUTE_NORETURN void ReportBuilderError(const torch::lazy::Node& node,
-                                                  absl::string_view error_msg);
+  TF_ATTRIBUTE_NORETURN void ReportBuilderError(const torch::lazy::Node& node,
+                                                absl::string_view error_msg);
 
   xla::XlaBuilder builder_;
   std::unordered_map<torch::lazy::BackendData::Handle, Parameter>

torch_xla/csrc/runtime/BUILD

Lines changed: 3 additions & 19 deletions
@@ -382,34 +382,18 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "tsl_platform_logging",
-    srcs = ["tsl_platform_logging.cpp"],
-    hdrs = ["tsl_platform_logging.h"],
-    deps = [
-        "@xla//xla/tsl/platform:env_time",
-        "@xla//xla/tsl/platform:logging",
-        "@xla//xla/tsl/platform:macros",
-        "@xla//xla/tsl/platform:types",
-        "@com_google_absl//absl/base:core_headers",
-        "@com_google_absl//absl/base:log_severity",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/strings:str_format",
-        "@com_google_absl//absl/strings:string_view",
-    ],
-)
-
 cc_library(
     name = "tf_logging",
     srcs = ["tf_logging.cpp"],
     hdrs = ["tf_logging.h"],
     deps = [
-        ":tsl_platform_logging",
         "//torch_xla/csrc:status",
         "@torch//:headers",
         "@torch//:runtime_headers",
+        "@tsl//tsl/platform:stacktrace",
+        "@tsl//tsl/platform:statusor",
+        "@xla//xla/service:platform_util",
         "@com_google_absl//absl/base:log_severity",
-        "@com_google_absl//absl/log:absl_log",
     ],
 )
 