
Integrating fast_hadamard_transform on C++ level #17

Merged: 13 commits into HanGuo97:main on Dec 12, 2024

Conversation

@BlackSamorez (Contributor) commented on Dec 10, 2024:

This PR adds a kernel that wraps together a Fast Hadamard Transform (built from Tri Dao's implementation) and FLUTE GEMM to support HIGGS.

Note that submodule initialization is required to build it.
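(If the repository uses the standard git submodule workflow, this typically means running `git submodule update --init --recursive` before building; the exact steps may differ depending on how the submodule is configured.)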

@BlackSamorez BlackSamorez changed the title [WIP] Integrating fast_hadamard_transform on C++ level Integrating fast_hadamard_transform on C++ level Dec 10, 2024
@BlackSamorez (Contributor, Author) commented:
@HanGuo97 Hi! Works on my machine. Could you check if it compiles and works for you as well?

flute/__init__.py

@HanGuo97 (Owner) commented:
Is this file a copy of the one in the original "fast-hadamard-transform"? (It looks slightly different to me; if it is not a copy, what's the difference?)

setup.py Outdated
@@ -89,6 +89,11 @@ def get_extensions() -> List:
sources = (
list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) +
list(glob.glob(os.path.join(extensions_dir, "*.cu"))))

@HanGuo97 (Owner) commented:
Minor: could we put them inside the include_dirs and sources definitions? (Looks slightly cleaner to me.)

>
at::Tensor
qgemm_hadamard(const at::Tensor& input,
const at::Tensor& weight,
@HanGuo97 (Owner) commented:
minor indentation inconsistency :)

@@ -7,6 +7,9 @@
#include "cute/numeric/integral_constant.hpp"


at::Tensor fast_hadamard_transform(at::Tensor &x, float scale);
@HanGuo97 (Owner) commented:
Could we make the signature style similar to the others (one line for the output type, and one argument per line)? This is the style used in CUTLASS.
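For illustration, the fast_hadamard_transform declaration from the diff above, rewritten in that style (return type on its own line, one argument per line), might look roughly like:

at::Tensor
fast_hadamard_transform(at::Tensor& x,
                        float scale);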

@@ -369,6 +372,45 @@ qgemm_raw_simple(const at::Tensor& input,
}


at::Tensor apply_hadamard(const at::Tensor& input, const cute::int64_t hadamard_size) {
@HanGuo97 (Owner) commented:
Could we make the signature style similar to the others (one line for the output type, and one argument per line)? This is the style used in CUTLASS.
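Similarly, the apply_hadamard definition above could be laid out as, for example:

at::Tensor
apply_hadamard(const at::Tensor& input,
               const cute::int64_t hadamard_size) {
    // ... function body unchanged ...
}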

@BlackSamorez (Contributor, Author) commented:
Upd: swapped Tri Dao's implementation for the HadaCore implementation because the former wouldn't work with torch.compile for some batch sizes. It's ~20% slower, though. TODO: choose the better option later.

@HanGuo97 merged commit 284771d into HanGuo97:main on Dec 12, 2024