Integrating fast_hadamard_transform on C++ level #17
@HanGuo97 Hi! Works on my machine. Could you check if it compiles and works for you as well?
flute/csrc/hadamard.cpp
Outdated
Is this file a copy of the one in the original "fast-hadamard-transform"? (It looks slightly different to me, and if it's not a copy, what's the difference?)
setup.py
Outdated
@@ -89,6 +89,11 @@ def get_extensions() -> List:
    sources = (
        list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) +
        list(glob.glob(os.path.join(extensions_dir, "*.cu"))))
minor, could we put them inside the include_dirs and sources definitions? (Looks slightly cleaner to me)
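One way this suggestion could look, as a rough sketch only: the added lines of the diff are not shown here, so the helper name collect_build_inputs and the extra_dir parameter below are hypothetical stand-ins for whatever the PR actually adds. The idea is simply that the extra files are listed where sources and include_dirs are first defined, rather than appended afterwards.

```python
import glob
import os
from typing import List, Tuple


def collect_build_inputs(extensions_dir: str, extra_dir: str) -> Tuple[List[str], List[str]]:
    """Build the source and include lists in one place.

    `extra_dir` is a hypothetical directory holding the additional
    kernel sources; it is not a path taken from the PR.
    """
    # All source files are gathered in a single definition.
    sources = (
        list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) +
        list(glob.glob(os.path.join(extensions_dir, "*.cu"))) +
        list(glob.glob(os.path.join(extra_dir, "*.cu"))))
    # Include directories are likewise declared once, up front.
    include_dirs = [extensions_dir, extra_dir]
    return sources, include_dirs
```

Both lists could then be passed straight to the extension constructor without any later mutation.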
flute/csrc/qgemm.cpp
Outdated
>
at::Tensor
qgemm_hadamard(const at::Tensor& input,
const at::Tensor& weight,
minor indentation inconsistency :)
flute/csrc/qgemm.cpp
Outdated
@@ -7,6 +7,9 @@
#include "cute/numeric/integral_constant.hpp"

at::Tensor fast_hadamard_transform(at::Tensor &x, float scale);
could we make the signature style similar to others (one line for output type, and one argument per line) --- this is the style used in CUTLASS
flute/csrc/qgemm.cpp
Outdated
@@ -369,6 +372,45 @@ qgemm_raw_simple(const at::Tensor& input,
}

at::Tensor apply_hadamard(const at::Tensor& input, const cute::int64_t hadamard_size) {
could we make the signature style similar to others (one line for output type, and one argument per line) --- this is the style used in CUTLASS
Upd: swapped Tri Dao's implementation for the HadaCore implementation because the former wouldn't work with
This PR adds a kernel that wraps together a Fast Hadamard Transform (built from Tri Dao's implementation) and FLUTE GEMM to support HIGGS.
Notably, submodule initialization will be needed to build it.
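For context on what the wrapped transform computes, here is a minimal pure-Python sketch of a fast Walsh-Hadamard transform. It is not the CUDA kernel from this PR and says nothing about how that kernel is implemented; the function name fwht is hypothetical, and the scale argument only mirrors the scale parameter in the fast_hadamard_transform declaration quoted in the review above.

```python
from typing import List, Sequence


def fwht(values: Sequence[float], scale: float = 1.0) -> List[float]:
    """Fast Walsh-Hadamard transform (Sylvester ordering), then scaling.

    Illustrative reference only; the input length must be a power of two.
    """
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    out = [float(v) for v in values]
    h = 1
    while h < n:
        # Butterfly stage: combine elements that are h positions apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    # Apply the caller-supplied scale to every output element.
    return [scale * v for v in out]
```

With scale set to 1/sqrt(n) the transform is orthonormal and is its own inverse, so applying it twice returns the original vector.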