Skip to content

Commit 3006c8e

Browse files
committed
Give the lowered fn computation a more meaningful name
1 parent 4020bb5 commit 3006c8e

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

test/test_operations.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2640,12 +2640,14 @@ def test_api(self):
26402640

26412641
result = a + b
26422642

2643-
ctx = torch_xla._XLAC.lowering.LoweringContext()
2643+
ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
26442644
ctx.build([result])
26452645
hlo = ctx.hlo()
26462646
hlo_text = ctx.hlo_text()
2647-
self.assertTrue('opcode: "parameter"' in hlo_text)
2648-
self.assertTrue('opcode: "add"' in hlo_text)
2647+
self.assertIn('MyCustomName', hlo_text)
2648+
self.assertIn('opcode: "parameter"', hlo_text)
2649+
self.assertIn('opcode: "parameter"', hlo_text)
2650+
self.assertIn('opcode: "add"', hlo_text)
26492651
mapping = ctx.parameter_id_tensor_mapping()
26502652
self.assertEqual(len(mapping), 2)
26512653

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -978,10 +978,14 @@ void BuildProfilerSubmodule(py::module* m) {
978978

979979
class PyLoweringContext {
980980
public:
981-
PyLoweringContext() : PyLoweringContext(bridge::GetCurrentDevice()) {}
981+
PyLoweringContext()
982+
: PyLoweringContext("PyLoweringContext", bridge::GetCurrentDevice()) {}
982983

983-
PyLoweringContext(torch::lazy::BackendDevice device)
984-
: lowering_ctx("PyLoweringContext", device) {}
984+
PyLoweringContext(const std::string& name)
985+
: PyLoweringContext(name, bridge::GetCurrentDevice()) {}
986+
987+
PyLoweringContext(const std::string& name, torch::lazy::BackendDevice device)
988+
: lowering_ctx(name, device) {}
985989

986990
// Builds a HLO graph given a set of output tensors.
987991
void Build(std::vector<at::Tensor> tensors) {
@@ -1188,7 +1192,8 @@ void BuildLoweringContextSubmodule(py::module* m) {
11881192
py::class_<PyLoweringContext, std::unique_ptr<PyLoweringContext>>
11891193
lowering_context_class(lowering, "LoweringContext", py::module_local());
11901194

1191-
lowering_context_class.def(py::init<>())
1195+
lowering_context_class.def(py::init())
1196+
.def(py::init<std::string>())
11921197
.def("build", &PyLoweringContext::Build)
11931198
.def("buildforiloop", &PyLoweringContext::BuildForiLoop)
11941199
.def("hlo", &PyLoweringContext::GetHlo)

torch_xla/experimental/scan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
477477
y_len = len(fake_output_y)
478478
fn_outputs = fake_output_carry + fake_output_y
479479

480-
fn_ctx = torch_xla._XLAC.lowering.LoweringContext()
480+
fn_ctx = torch_xla._XLAC.lowering.LoweringContext("FnComputation")
481481
fn_ctx.set_name_string("fn_ctx")
482482
fn_ctx.build(list(fn_outputs))
483483
fn_hlo = fn_ctx.hlo()

0 commit comments

Comments
 (0)