[Bugfix] Fallback torch.accelerator.synchronize() to torch.cuda.synchronize()
#987
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
Replaces the post-invocation synchronization call in tilelang/profiler/bench.py with a version-conditional path: torch.cuda.synchronize() on PyTorch < 2.6.0 and torch.accelerator.synchronize() otherwise.
Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant B as Bench
    participant T as PyTorch
    participant C as CUDA
    B->>T: Run profiled function
    Note right of T: Post-call synchronization
    T->>C: torch.cuda.synchronize()
Estimated code review effort: 🎯 1 (Trivial) | ⏱️ ~2 minutes
Pre-merge checks and finishing touches
✅ Passed checks (3 passed)
✨ Finishing touches
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/profiler/bench.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-metal
tilelang/profiler/bench.py
Outdated
from typing import Callable, List, Literal, Optional, Union

import torch
from packaging import version
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Import looks good, but verify dependency.
The packaging library is commonly used for version parsing and the import is correct.
Run the following script to verify that packaging is listed as a dependency:
🏁 Script executed:
#!/bin/bash
# Check if packaging is listed in project dependencies
fd -t f 'requirements.*\.txt|setup\.py|setup\.cfg|pyproject\.toml' --exec grep -l 'packaging'

Length of output: 123
🏁 Script executed:
#!/bin/bash
grep -R "packaging" setup.py pyproject.toml requirements*.txt || true

Length of output: 224
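As an aside (illustrative only, not part of the diff): packaging.version is the right tool for the cutoff check because it implements PEP 440 ordering, which naive string comparison does not:

```python
from packaging import version

# PEP 440-aware comparisons handle multi-digit components and local build tags.
print(version.parse("2.10.0") > version.parse("2.6.0"))        # True
print(version.parse("2.6.0+cu124") >= version.parse("2.6.0"))  # True (local tag sorts above)
print("2.10.0" > "2.6.0")                                      # False: lexicographic comparison
```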
Ensure device-agnostic synchronization and API compatibility
- Line 161 still uses torch.cuda.synchronize() without a version or device check.
- do_bench hardcodes "cuda" and uses torch.cuda.Event; adapt for MPS and other targets.
- Confirm that torch.accelerator.synchronize() exists in your PyTorch ≥ 2.6.0 build.
🤖 Prompt for AI Agents
tilelang/profiler/bench.py lines 8-8: the file currently imports
packaging.version but later uses hardcoded torch.cuda synchronization and CUDA
events; update to be device-agnostic and API-safe by adding a small helper that
takes a torch.device (or device string) and synchronizes safely: first, if torch
has torch.accelerator.synchronize (check hasattr and ensure torch version ≥2.6
via packaging.version.parse(torch.__version__)), call it; else if device.type ==
"cuda" call torch.cuda.synchronize(); else if device.type == "mps" call
torch.mps.synchronize() only if available (hasattr(torch, "mps") and
hasattr(torch.mps, "synchronize")), otherwise no-op. Modify do_bench to accept a
device parameter (or infer torch.device(device)), use torch.cuda.Event only when
device.type == "cuda" and fall back to time.perf_counter for other devices while
surrounding timing with the helper synchronize calls; ensure all calls guarded
by hasattr/version checks to avoid AttributeError on builds without
accelerator/mps support.
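Taking that prompt at face value, a minimal sketch of such a helper plus a timing fallback might look like the following. The names _sync and _time_one_call are hypothetical, not code from this PR, and every backend-specific call is guarded so it degrades to a no-op rather than raising AttributeError:

```python
import time

import torch
from packaging import version


def _sync(device: torch.device) -> None:
    """Synchronize `device`, guarding each backend call behind version/hasattr checks."""
    if (version.parse(torch.__version__) >= version.parse("2.6.0")
            and hasattr(torch, "accelerator") and torch.accelerator.is_available()):
        torch.accelerator.synchronize()   # accelerator-agnostic path (PyTorch >= 2.6)
    elif device.type == "cuda":
        torch.cuda.synchronize()          # backend-specific CUDA path
    elif device.type == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"):
        torch.mps.synchronize()           # Apple MPS path, only if the build ships it
    # CPU or unknown backends: nothing to synchronize


def _time_one_call(fn, device: torch.device) -> float:
    """Return one call's latency in milliseconds: CUDA events on CUDA, perf_counter elsewhere."""
    if device.type == "cuda":
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        _sync(device)                     # event timings are only valid after synchronization
        return start.elapsed_time(end)
    t0 = time.perf_counter()
    fn()
    _sync(device)                         # wait for async work before reading the clock
    return (time.perf_counter() - t0) * 1e3
```

Centralizing the guards this way would keep both sync sites on one code path instead of repeating the version check at lines 102 and 161.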
tilelang/profiler/bench.py
Outdated
if version.parse(torch.__version__) < version.parse("2.6.0"):
    torch.cuda.synchronize()
else:
    torch.accelerator.synchronize()
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Inconsistent synchronization and missing device type check.
The version-conditional synchronization has several issues:
- Inconsistency: Line 161 in _bench_with_cuda_events() still uses torch.cuda.synchronize() without the version check, creating inconsistent behavior across the benchmarking workflow.
- Missing device type check: The synchronization logic doesn't consider the device type. Lines 58-60 detect whether CUDA or MPS is available, but this new logic at lines 101-104 always calls torch.cuda.synchronize() for versions < 2.6.0, even on non-CUDA devices like MPS.
- API verification needed: Confirm that torch.accelerator.synchronize() is the correct API for PyTorch 2.6.0+ and that 2.6.0 is the correct version cutoff.
To verify the correct API and version, please search for the official PyTorch documentation:
To check for other synchronization calls that need updating, run:
Consider refactoring to:
- Update line 161 to use the same version-conditional logic
- Consider the device type (IS_CUDA) in the synchronization logic
- Extract the synchronization logic into a helper function to ensure consistency
Example refactor:
def _synchronize():
    """Synchronize the current device."""
    if version.parse(torch.__version__) < version.parse("2.6.0"):
        if IS_CUDA:
            torch.cuda.synchronize()
        else:
            torch.mps.synchronize()
    else:
        torch.accelerator.synchronize()

Then use _synchronize() at both lines 102 and 161.
🏁 Script executed:
#!/bin/bash
# Find all torch.cuda.synchronize() calls in the file
rg -n 'torch\.cuda\.synchronize\(\)' tilelang/profiler/bench.py

Length of output: 69
🌐 Web query:
What is the correct API for device synchronization in PyTorch 2.6.0 and later? Is it torch.accelerator.synchronize()?
💡 Result:
Yes — starting in PyTorch 2.6 the accelerator‑agnostic API is torch.accelerator.synchronize(device=None). Use that for device synchronization; older device‑specific calls (e.g. torch.cuda.synchronize or torch.cpu.synchronize) still exist for their backends. [1][2][3]
Sources:
[1] PyTorch 2.6 torch.accelerator.synchronize docs.
[2] PyTorch 2.6 torch.accelerator overview.
[3] PyTorch torch.cuda.synchronize docs.
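For reference, a minimal guarded use of that API (a sketch assuming only the documented PyTorch ≥ 2.6 torch.accelerator entry points; the elif branch is the pre-2.6 fallback this PR is about):

```python
import torch

if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    # Accelerator-agnostic path documented for PyTorch >= 2.6.
    print(torch.accelerator.current_accelerator())  # e.g. device(type='cuda', index=0)
    torch.accelerator.synchronize()                 # block until the current accelerator is idle
elif torch.cuda.is_available():
    torch.cuda.synchronize()                        # backend-specific fallback on older builds
```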
Unify device synchronization into a helper
Both torch.cuda.synchronize() calls at lines 102 and 161 need the same version-conditional, device-aware logic. For PyTorch <2.6.0, call torch.cuda.synchronize() if IS_CUDA else torch.mps.synchronize(); for >=2.6.0, use torch.accelerator.synchronize(). Extract this into a single _synchronize() function and replace both calls with _synchronize().
🤖 Prompt for AI Agents
In tilelang/profiler/bench.py around lines 101-104 and also where the second
sync occurs near line 161, both places duplicate the same PyTorch version- and
device-aware synchronization logic; extract this into a single helper function
_synchronize() that checks torch.__version__ (using version.parse) and calls
torch.cuda.synchronize() when PyTorch <2.6.0 and IS_CUDA is true, calls
torch.mps.synchronize() when PyTorch <2.6.0 and not IS_CUDA, and calls
torch.accelerator.synchronize() for PyTorch >=2.6.0; add the helper near the top
of the file (ensuring version, torch and IS_CUDA are in scope), then replace
both original conditional blocks with a simple call to _synchronize().
[Bugfix] Fallback `torch.accelerator.synchronize()` to `torch.cuda.synchronize()` (tile-ai#987)
* [Refactor]: Add support for torch version lower than 2.6.0
* update
Summary by CodeRabbit