Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/tl_templates/cuda/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using int4_t = int4;
#define hexp cutlass::fast_exp
#define hlog cutlass::fast_log
#define hsqrt cutlass::fast_sqrt
#define hsin cutlass::fast_sin
#define hcos cutlass::fast_cos
Comment on lines +21 to +22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While using macros is consistent with the existing definitions in this file, it's generally safer to use inline functions in C++. This avoids potential issues with macro expansion, improves type safety, and respects namespaces. Since the project uses C++17 features (like std::is_same_v), you can use auto function parameters (a C++14 feature) to create generic wrappers concisely. This would be a good opportunity to start migrating away from macros for these function aliases.

TL_DEVICE auto hsin(auto x) { return cutlass::fast_sin(x); }
TL_DEVICE auto hcos(auto x) { return cutlass::fast_cos(x); }

Comment on lines +21 to +22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

🧩 Analysis chain

Adds fast trig macros consistently; verify CUTLASS support for all target types/versions.

hsin/hcos align with existing hexp/hlog/hsqrt patterns. Please confirm that cutlass::fast_sin / cutlass::fast_cos are available (and device-callable) for your targeted CUTLASS/CUDA versions and support the input types you pass (e.g., half_t, bfloat16_t, float). If bfloat16_t isn’t covered (similar to the custom bfloat16_t fast_exp wrapper below), consider adding analogous wrappers or fallbacks.

Run the following to check for potential collisions and usage sites:

Optionally, verify availability in docs:


🏁 Script executed:

#!/bin/bash
# 1) Find any other definitions/usages of hsin/hcos to avoid macro conflicts.
rg -n -C2 -P '(^\s*#\s*define\s+hsin\b)|\bhsin\s*\('
rg -n -C2 -P '(^\s*#\s*define\s+hcos\b)|\bhcos\s*\('

# 2) Confirm CUTLASS fast trig references in the repo (sanity check).
rg -n -C2 -P '\bcutlass::fast_(sin|cos)\b'

# 3) Optional: locate bfloat16 sin/cos usage that might require wrappers.
rg -n -C2 -P '\b(bfloat16_t|half_t)\b.*\b(hsin|hcos)\s*\('

Length of output: 1025


🌐 Web query:

Does CUTLASS provide cutlass::fast_sin and cutlass::fast_cos in cutlass/fast_math.h, and what types are supported (float, half_t, bfloat16_t)?

💡 Result:

Short answer: No — CUTLASS’s fast_math.h does not expose cutlass::fast_sin / cutlass::fast_cos. The fast_math header in CUTLASS implements a set of “fast_…” helpers (e.g. fast_tanh, fast_exp, etc.) with overloads for float / double and some device half support, but you won’t find cutlass::fast_sin/fast_cos there. The header and related GitHub issues show the library provides fast variants for operations like tanh (including float/double/half overloads) but not sin/cos or bfloat16 fast trig helpers. [1][2]

Sources:

  • CUTLASS fast_math.h (repository header listing of the fast_math utilities). [1]
  • CUTLASS GitHub issue showing fast_* functions (fast_tanh) and type overloads discussion. [2]

References:
[1] cutlass/include/cutlass/fast_math.h — NVIDIA/cutlass (GitHub).
[2] NVIDIA/cutlass issue discussing fast_tanh overloads in fast_math.h.


Remove/guard hsin/hcos — cutlass::fast_sin / cutlass::fast_cos not provided

CUTLASS’s fast_math.h does not expose cutlass::fast_sin/cutlass::fast_cos; the macros in src/tl_templates/cuda/common.h (lines 21–22) will reference undefined symbols and likely break builds.

  • Replace the macros with device-callable alternatives (sinf/cosf or CUDA device intrinsics) or add explicit device wrappers that call a supported fast-trig implementation.
  • Add overloads/fallbacks for half_t and bfloat16_t (convert to float → trig → convert back) — CUTLASS lacks bfloat16 fast-trig helpers.
  • Ensure wrappers are annotated for device/host as needed.

src/tl_templates/cuda/common.h:21-22

🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 21-22, the macros #define hsin
cutlass::fast_sin and #define hcos cutlass::fast_cos reference non-existent
CUTLASS symbols; replace them with guarded, device-callable wrappers that call
supported trig functions (e.g., sinf/cosf or CUDA device intrinsics) and mark
them __host__ __device__; provide overloads/fallbacks for half_t and bfloat16_t
by converting to float, performing the trig operation, then converting back;
protect the replacements with #ifdef/#else to use cutlass implementations if
available, and ensure proper includes and namespace qualification so builds
won’t reference undefined symbols.

#define htanh cutlass::fast_tanh
#define hpow powf

Expand Down
Loading