From 9a768abdc9e65699490384907d19a52479979b17 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 25 Sep 2024 16:34:56 -0700 Subject: [PATCH] Fix up CMakeLists and reorganize some code locations Summary: This diff cleans up CMakeFiles and reorganizes some of the code directories. Archetecture-specific kernel code remains in experimental/kernels/cpu/aarch64. This is not changed. Code related to ops (parallel.h, memory.h, "high-level" op interface, examples/torch_custom_op, etc) is moved to experimental/ops. The example code that builds the custom ops for ATen/ExecuTorch (experimental/kernels/cpu/linear/examples/torch_custom_op) is moved out of examples and into a more descriptive directory: experimental/ops/linear/linear_a8wxdq_op. These are the kernels that will be used in torchchat initially. The "high-level" op interface is in ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h. A library for it is defined in the CMakeFiles.txt. It uses no templates and is not specific to the universal kernels and can be used with KleidiAI kernels. The ops in experimental/ops/linear/linear_a8wxdq_op use the high-level library and the universal kernels in experimental/kernels/cpu/aarch64 to define custom ops. Finally, there is a CMakeLists at the top-level directory experimental. The script build_torchao_ops.sh builds the torchao_ops for different platforms: ``` sh build_torchao_ops.sh ATEN ``` For ExecuTorch, you must define the environment variables EXECUTORCH_INCLUDE_DIRS and EXECUTORCH_LIBRARIES manually before calling: ``` EXECUTORCH_INCLUDE_DIRS=$HOME EXECUTORCH_LIBRARIES=$HOME/executorch/cmake-out/lib/libexecutorch_no_prim_ops.a sh build_torchao_ops.sh EXECUTORCH ``` Reviewed By: digantdesai Differential Revision: D62711903 --- torchao/experimental/CMakeLists.txt | 50 ++++++++++++++ .../{kernels/cpu => }/Utils.cmake | 0 ...uild_custom_op.sh => build_torchao_ops.sh} | 15 ++-- .../kernels/cpu/aarch64/CMakeLists.txt | 16 +++-- .../cpu/linear/benchmarks/CMakeLists.txt | 57 --------------- .../cpu/linear/examples/CMakeLists.txt | 38 ---------- .../examples/torch_custom_op/CMakeLists.txt | 58 ---------------- .../examples/torch_custom_op/run_custom_op.py | 69 ------------------- .../kernels/cpu/linear/tests/CMakeLists.txt | 41 ----------- .../cpu/linear/tests/build_and_run_tests.sh | 12 ---- .../experimental/ops/linear/CMakeLists.txt | 12 ++++ .../ops/linear/benchmarks/CMakeLists.txt | 40 +++++++++++ .../benchmarks/benchmark_linear_operator.cpp | 40 +++++++++-- .../benchmarks/build_and_run_benchmarks.sh | 6 +- ...it_activation_groupwise_lowbit_weight.cpp} | 57 +++------------ ..._8bit_activation_groupwise_lowbit_weight.h | 21 +----- .../ops/linear/examples/CMakeLists.txt | 42 +++++++++++ ...ationGroupwiseLowbitWeightLinearOperator.h | 20 +++--- .../linear/examples/build_and_run_examples.sh | 10 +-- .../examples/separate_function_wrappers.cpp | 36 ++++++++-- .../examples/stateful_class_wrapper.cpp | 32 ++++++++- .../linear/linear_a8wxdq_op/CMakeLists.txt | 45 ++++++++++++ .../linear_a8wxdq_op/linear_a8wxdq-impl.h} | 40 +++++++++-- .../linear_a8wxdq_op/linear_a8wxdq_aten.cpp} | 2 +- .../linear_a8wxdq_executorch}/w2s.cpp | 2 +- .../linear_a8wxdq_executorch}/w2sz.cpp | 2 +- .../linear_a8wxdq_executorch}/w3s.cpp | 2 +- .../linear_a8wxdq_executorch}/w3sz.cpp | 2 +- .../linear_a8wxdq_executorch}/w4s.cpp | 2 +- .../linear_a8wxdq_executorch}/w4sz.cpp | 2 +- .../linear_a8wxdq_executorch}/w5s.cpp | 2 +- .../linear_a8wxdq_executorch}/w5sz.cpp | 2 +- .../ops/linear/tests/CMakeLists.txt | 43 ++++++++++++ .../ops/linear/tests/build_and_run_tests.sh | 14 ++++ .../linear/tests/test_linear_operator.cpp | 51 ++++++++++---- .../experimental/{kernels/cpu => ops}/macro.h | 0 .../{kernels/cpu => ops}/memory.h | 0 .../{kernels/cpu => ops}/parallel-aten-impl.h | 0 .../cpu => ops}/parallel-openmp-impl.h | 0 .../cpu => ops}/parallel-pthreadpool-impl.h | 0 .../parallel-single_threaded-impl.h | 0 .../cpu => ops}/parallel-test_dummy-impl.h | 0 .../{kernels/cpu => ops}/parallel.h | 22 +++--- ...test_int8_dyn_act_intx_weight_quantizer.py | 47 ++++++++++++- 44 files changed, 521 insertions(+), 431 deletions(-) create mode 100644 torchao/experimental/CMakeLists.txt rename torchao/experimental/{kernels/cpu => }/Utils.cmake (100%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh => build_torchao_ops.sh} (51%) delete mode 100644 torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt delete mode 100644 torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt delete mode 100644 torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt delete mode 100644 torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py delete mode 100644 torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt delete mode 100644 torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh create mode 100644 torchao/experimental/ops/linear/CMakeLists.txt create mode 100644 torchao/experimental/ops/linear/benchmarks/CMakeLists.txt rename torchao/experimental/{kernels/cpu => ops}/linear/benchmarks/benchmark_linear_operator.cpp (77%) rename torchao/experimental/{kernels/cpu => ops}/linear/benchmarks/build_and_run_benchmarks.sh (70%) rename torchao/experimental/{kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h => ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp} (83%) rename torchao/experimental/{kernels/cpu => ops}/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h (81%) create mode 100644 torchao/experimental/ops/linear/examples/CMakeLists.txt rename torchao/experimental/{kernels/cpu => ops}/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h (91%) rename torchao/experimental/{kernels/cpu => ops}/linear/examples/build_and_run_examples.sh (67%) rename torchao/experimental/{kernels/cpu => ops}/linear/examples/separate_function_wrappers.cpp (80%) rename torchao/experimental/{kernels/cpu => ops}/linear/examples/stateful_class_wrapper.cpp (71%) create mode 100644 torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h => ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h} (87%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp => ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp} (98%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w2s.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w2sz.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w3s.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w3sz.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w4s.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w4sz.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w5s.cpp (90%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch => ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch}/w5sz.cpp (90%) create mode 100644 torchao/experimental/ops/linear/tests/CMakeLists.txt create mode 100644 torchao/experimental/ops/linear/tests/build_and_run_tests.sh rename torchao/experimental/{kernels/cpu => ops}/linear/tests/test_linear_operator.cpp (78%) rename torchao/experimental/{kernels/cpu => ops}/macro.h (100%) rename torchao/experimental/{kernels/cpu => ops}/memory.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel-aten-impl.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel-openmp-impl.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel-pthreadpool-impl.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel-single_threaded-impl.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel-test_dummy-impl.h (100%) rename torchao/experimental/{kernels/cpu => ops}/parallel.h (73%) rename torchao/experimental/{kernels/cpu/linear/examples/torch_custom_op => tests}/test_int8_dyn_act_intx_weight_quantizer.py (63%) diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt new file mode 100644 index 0000000000..198e9ebd44 --- /dev/null +++ b/torchao/experimental/CMakeLists.txt @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(torchao) + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_CXX_STANDARD 17) + +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release) +endif() + + +# Source root directory for torchao/experimental +if(NOT TORCHAO_ROOT) + set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) +endif() + +if(NOT TORCHAO_INCLUDE_DIRS) + set(TORCHAO_INCLUDE_DIRS ${TORCHAO_ROOT}/../..) +endif() + +if (NOT TORCHAO_PARALLEL_BACKEND) + if (TORCHAO_OP_TARGET STREQUAL "ATEN") + set(TORCHAO_PARALLEL_BACKEND "ATEN_OPENMP") + elseif(TORCHAO_OP_TARGET STREQUAL "EXECUTORCH") + set(TORCHAO_PARALLEL_BACKEND "PTHREADPOOL") + else() + message(TORCHAO_PARALLEL_BACKEND "TORCHAO_PARALLEL_BACKEND is not set. Please set it directly or set TORCHAO_OP_TARGET to get a default.") + endif() +endif() + +include(CMakePrintHelpers) + +add_compile_options("-Wall" "-Werror") + +include(CMakePrintHelpers) +message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") +include_directories(${TORCHAO_INCLUDE_DIRS}) + +if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + # Defines target torchao_kernels_aarch64 + add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64) + add_subdirectory(${TORCHAO_ROOT}/ops/linear) + add_subdirectory(${TORCHAO_ROOT}/ops/linear/linear_a8wxdq_op) +endif() diff --git a/torchao/experimental/kernels/cpu/Utils.cmake b/torchao/experimental/Utils.cmake similarity index 100% rename from torchao/experimental/kernels/cpu/Utils.cmake rename to torchao/experimental/Utils.cmake diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh b/torchao/experimental/build_torchao_ops.sh similarity index 51% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh rename to torchao/experimental/build_torchao_ops.sh index c657857fcc..de6d8e17d8 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/build_custom_op.sh +++ b/torchao/experimental/build_torchao_ops.sh @@ -5,15 +5,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../.. - export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" export CMAKE_OUT=/tmp/cmake-out/torchao -cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \ - -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ - -DPLATFORM="ATEN" \ - -S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \ +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \ + -DTORCHAO_OP_TARGET="$1" \ + -DEXECUTORCH_LIBRARIES=${EXECUTORCH_LIBRARIES} \ + -DEXECUTORCH_INCLUDE_DIRS=${EXECUTORCH_INCLUDE_DIRS} \ + -S . \ -B ${CMAKE_OUT} -cmake --build ${CMAKE_OUT} +cmake --build ${CMAKE_OUT} --target install --config Release diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index a13737d874..ec497a1871 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -4,10 +4,12 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -add_library( - kernel_aarch64 - ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp -) +if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + add_library( + torchao_kernels_aarch64 + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp + ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp + ) +endif() diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt deleted file mode 100644 index 61e5eeae27..0000000000 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/CMakeLists.txt +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -cmake_minimum_required(VERSION 3.19) -project(benchmarks) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -include(FetchContent) -FetchContent_Declare(googlebenchmark - GIT_REPOSITORY https://github.com/google/benchmark.git - GIT_TAG main) # need main for benchmark::benchmark - -set(BENCHMARK_ENABLE_TESTING OFF) -FetchContent_MakeAvailable( - googlebenchmark) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp -) - -add_executable(benchmark_linear_operator benchmark_linear_operator.cpp) -target_link_libraries( - benchmark_linear_operator - PRIVATE - benchmark::benchmark - dep -) - -option(TORCHAO_PARALLEL_OMP "" OFF) -option(TORCHAO_PARALLEL_SINGLE_THREADED "" ON) - -if (TORCHAO_PARALLEL_OMP) - message("OpenMP_ROOT: ${OpenMP_ROOT}") - add_definitions(-DTORCHAO_PARALLEL_OMP=1) - find_package(OpenMP REQUIRED) - if(OpenMP_CXX_FOUND) - target_link_libraries(benchmark_linear_operator PUBLIC OpenMP::OpenMP_CXX) - endif() -endif() - -if (TORCHAO_PARALLEL_SINGLE_THREADED) - add_definitions(-DTORCHAO_PARALLEL_SINGLE_THREADED=1) -endif() diff --git a/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt deleted file mode 100644 index 4489dc7c36..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/CMakeLists.txt +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -project(examples) - -cmake_minimum_required(VERSION 3.19) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) - -add_executable(separate_function_wrappers separate_function_wrappers.cpp) -target_link_libraries( - separate_function_wrappers - PRIVATE - kernel_aarch64 -) - -add_executable(stateful_class_wrapper stateful_class_wrapper.cpp) -target_link_libraries( - stateful_class_wrapper - PRIVATE - kernel_aarch64 -) - -include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake) - -target_link_torchao_parallel_backend(stateful_class_wrapper "openmp") -target_link_torchao_parallel_backend(separate_function_wrappers "openmp") diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt deleted file mode 100644 index 10e44a79a8..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/CMakeLists.txt +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -project(torch_custom_op) - -cmake_minimum_required(VERSION 3.19) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Release) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") -include_directories(${TORCHAO_INCLUDE_DIRS}) - -add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64) - -include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake) - -set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH") -string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER) - -if(PLATFORM_TO_UPPER STREQUAL "ATEN") -message(STATUS "Building with PLATFORM=ATEN") - -find_package(Torch REQUIRED) -add_library(lowbit_op_aten SHARED lowbit_op_aten.cpp) -target_link_libraries(lowbit_op_aten PRIVATE kernel_aarch64) -target_include_directories(lowbit_op_aten PRIVATE "${TORCH_INCLUDE_DIRS}") -target_link_libraries(lowbit_op_aten PRIVATE "${TORCH_LIBRARIES}") -target_compile_definitions(lowbit_op_aten PRIVATE USE_ATEN=1) -target_link_torchao_parallel_backend(lowbit_op_aten "ATEN_OPENMP") - -elseif(PLATFORM_TO_UPPER STREQUAL "EXECUTORCH") -message(STATUS "Building with PLATFORM=EXECUTORCH") - -add_library(lowbit_op_executorch SHARED - lowbit_op_executorch/w2s.cpp - lowbit_op_executorch/w2sz.cpp - lowbit_op_executorch/w3s.cpp - lowbit_op_executorch/w3sz.cpp - lowbit_op_executorch/w4s.cpp - lowbit_op_executorch/w4sz.cpp - lowbit_op_executorch/w5s.cpp - lowbit_op_executorch/w5sz.cpp -) -target_include_directories(lowbit_op_executorch PRIVATE ${EXECUTORCH_INCLUDE_DIRS}) -target_compile_definitions(lowbit_op_executorch PRIVATE USE_EXECUTORCH=1) -target_link_torchao_parallel_backend(lowbit_op_executorch "SINGLE_THREADED") -target_link_libraries(lowbit_op_executorch PRIVATE ${EXECUTORCH_LIBRARIES}) -target_link_libraries(lowbit_op_executorch PRIVATE kernel_aarch64) - -else() -message(FATAL_ERROR "Unknown PLATFORM: ${PLATFORM}. Please choose one of: ATEN, EXECUTORCH.") -endif() diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py b/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py deleted file mode 100644 index e3d96df63c..0000000000 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/run_custom_op.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import copy -import glob -import os - -import sys - -import torch - -sys.path.insert( - 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) -) -from quant_api import Int8DynActIntxWeightQuantizer - -libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*") -libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) -torch.ops.load_library(libs[0]) - -group_size = 256 -m = 1 -n = 4096 -k = 4096 -nbit = 4 -has_weight_zeros = False -n_layers = 5 - -print("Creating random model") -layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)] -model = torch.nn.Sequential(*layers) -model = model.eval() - -print("Quantizing random model") -quantized_model = copy.deepcopy(model) -quantizer = Int8DynActIntxWeightQuantizer( - device="cpu", - precision=torch.float32, - bitwidth=nbit, - groupsize=group_size, - has_weight_zeros=has_weight_zeros, -) -quantized_model = quantizer.quantize(quantized_model) -quantized_model = quantized_model.eval() - -print("Creating random activations") -activations = torch.randn(m, k, dtype=torch.float32) - -print("Exporting quantized model") -exported = torch.export.export(quantized_model, (activations,)) - -print("Using torch.compile on quantized model") -quantized_model_compiled = torch.compile(quantized_model) -with torch.no_grad(): - quantized_model_compiled(activations) - -print("Compiling quantized model with AOTI") -torch._export.aot_compile( - quantized_model, - (activations,), - options={"aot_inductor.output_path": "/tmp/torch_custom_op_example_model.so"}, -) - -print("Running AOTI") -fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu") -fn(activations) diff --git a/torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt deleted file mode 100644 index 3a415d8edd..0000000000 --- a/torchao/experimental/kernels/cpu/linear/tests/CMakeLists.txt +++ /dev/null @@ -1,41 +0,0 @@ -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -cmake_minimum_required(VERSION 3.19) -project(tests) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_BUILD_TYPE Debug) - -include(FetchContent) -FetchContent_Declare( - googletest - URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip -) -FetchContent_MakeAvailable(googletest) - -add_compile_options("-Wall" "-Werror") - -include(CMakePrintHelpers) -message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}") -include_directories(${TORCHAO_LIBRARIES}) - -add_library( - dep - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp - ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp -) - -enable_testing() - -add_definitions(-DTORCHAO_PARALLEL_TEST_DUMMY=1) -add_executable(test_linear_operator test_linear_operator.cpp) -target_link_libraries( - test_linear_operator - PRIVATE - GTest::gtest_main - dep -) - -include(GoogleTest) -gtest_discover_tests(test_linear_operator) diff --git a/torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh deleted file mode 100644 index ad9a855084..0000000000 --- a/torchao/experimental/kernels/cpu/linear/tests/build_and_run_tests.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/bash -# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. - -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/tests -B ${CMAKE_OUT} - -cmake --build ${CMAKE_OUT} - -# Run -${CMAKE_OUT}/test_linear_operator diff --git a/torchao/experimental/ops/linear/CMakeLists.txt b/torchao/experimental/ops/linear/CMakeLists.txt new file mode 100644 index 0000000000..2f7b91bbf9 --- /dev/null +++ b/torchao/experimental/ops/linear/CMakeLists.txt @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +include(${TORCHAO_ROOT}/Utils.cmake) + +add_library(torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} STATIC channelwise_8bit_activation_groupwise_lowbit_weight.cpp) +target_link_torchao_parallel_backend(torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} "${TORCHAO_PARALLEL_BACKEND}") diff --git a/torchao/experimental/ops/linear/benchmarks/CMakeLists.txt b/torchao/experimental/ops/linear/benchmarks/CMakeLists.txt new file mode 100644 index 0000000000..70d6bf2cba --- /dev/null +++ b/torchao/experimental/ops/linear/benchmarks/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) +project(benchmarks) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Release) +add_compile_options("-Wall" "-Werror") + +set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) + +include(FetchContent) +FetchContent_Declare(googlebenchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG main) # need main for benchmark::benchmark + +set(BENCHMARK_ENABLE_TESTING OFF) +FetchContent_MakeAvailable( + googlebenchmark) + +include_directories(${TORCHAO_INCLUDE_DIRS}) + +set(TORCHAO_PARALLEL_BACKEND "OPENMP") +add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) +add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) + +add_executable(benchmark_linear_operator benchmark_linear_operator.cpp) +target_link_libraries( + benchmark_linear_operator + PRIVATE + benchmark::benchmark + torchao_kernels_aarch64 + torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} +) +target_link_torchao_parallel_backend(benchmark_linear_operator "${TORCHAO_PARALLEL_BACKEND}") diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp b/torchao/experimental/ops/linear/benchmarks/benchmark_linear_operator.cpp similarity index 77% rename from torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp rename to torchao/experimental/ops/linear/benchmarks/benchmark_linear_operator.cpp index ad6563eabe..8d7cd4a908 100644 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/benchmark_linear_operator.cpp +++ b/torchao/experimental/ops/linear/benchmarks/benchmark_linear_operator.cpp @@ -5,11 +5,40 @@ // LICENSE file in the root directory of this source tree. #include +#include #include -#include -#include +#include +#include +#include #include +using namespace torchao::ops::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + template static void channelwise_8bit_activation_groupwise_lowbit_weight( benchmark::State& state) { @@ -24,9 +53,6 @@ static void channelwise_8bit_activation_groupwise_lowbit_weight( int num_test_cases = state.range(5); // Initialize config and tiling params - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; - auto ukernel_config = get_ukernel_config(); auto pack_weight_data_tiling_params = @@ -66,7 +92,7 @@ static void channelwise_8bit_activation_groupwise_lowbit_weight( std::vector> packed_weight_data; for (int i = 0; i < test_cases.size(); i++) { - packed_weight_data.emplace_back(torchao::make_aligned_byte_array_unique_ptr( + packed_weight_data.emplace_back(torchao::make_aligned_byte_ptr( packed_weight_data_alignment, packed_weight_data_size)); pack_weight_data_operator( ukernel_config, @@ -91,7 +117,7 @@ static void channelwise_8bit_activation_groupwise_lowbit_weight( size_t activation_data_buffer_alignment = get_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = torchao::make_aligned_byte_array_unique_ptr( + auto activation_data_buffer = torchao::make_aligned_byte_ptr( activation_data_buffer_alignment, activation_data_buffer_size); auto output = std::vector(m * n); diff --git a/torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh b/torchao/experimental/ops/linear/benchmarks/build_and_run_benchmarks.sh similarity index 70% rename from torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh rename to torchao/experimental/ops/linear/benchmarks/build_and_run_benchmarks.sh index 18da0e992d..ed80d34e2f 100644 --- a/torchao/experimental/kernels/cpu/linear/benchmarks/build_and_run_benchmarks.sh +++ b/torchao/experimental/ops/linear/benchmarks/build_and_run_benchmarks.sh @@ -7,11 +7,9 @@ # Call script with sh build_and_run_benchmarks.sh {BENCHAMRK} -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. -export CMAKE_OUT=/tmp/cmake-out/torch_ao/benchmarks +export CMAKE_OUT=/tmp/cmake-out/torchao/benchmarks cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/benchmarks \ + -S . \ -B ${CMAKE_OUT} \ -DOpenMP_ROOT=$(brew --prefix libomp) \ -DTORCHAO_PARALLEL_OMP=ON diff --git a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp similarity index 83% rename from torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h rename to torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp index 37ad74b0f0..ae611d3ccc 100644 --- a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight-impl.h +++ b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.cpp @@ -4,18 +4,18 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. -#pragma once #include -#include -#include +#include +#include +#include #include #include #include -namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { -inline PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( +PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( const UKernelConfig& ukernel_config, int n, int target_panels_per_thread) { @@ -40,7 +40,7 @@ inline PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( return tiling_params; } -inline void pack_weight_data_operator( +void pack_weight_data_operator( const UKernelConfig& ukernel_config, const PackWeightDataTilingParams& tiling_params, // Outputs @@ -81,7 +81,7 @@ inline void pack_weight_data_operator( } // This default mimics XNNPACK behavior if target_tiles_per_thread = 5 -inline LinearTilingParams get_default_linear_tiling_params( +LinearTilingParams get_default_linear_tiling_params( const UKernelConfig& ukernel_config, int m, int n, @@ -118,8 +118,7 @@ inline LinearTilingParams get_default_linear_tiling_params( namespace internal { -inline int -get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( +inline int get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( const UKernelConfig& ukernel_config, const LinearTilingParams& tiling_params, int m, @@ -273,7 +272,7 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( } } // namespace internal -inline void linear_operator( +void linear_operator( const UKernelConfig& ukernel_config, const LinearTilingParams& tiling_params, LinearTileSchedulingPolicy scheduling_policy, @@ -333,7 +332,7 @@ inline void linear_operator( } } -inline int get_activation_data_buffer_size( +int get_activation_data_buffer_size( const UKernelConfig& ukernel_config, const LinearTilingParams& tiling_params, LinearTileSchedulingPolicy scheduling_policy, @@ -355,38 +354,4 @@ inline int get_activation_data_buffer_size( } } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - -// TODO: may move to different fil or namespace. This method is not part of the -// high-level interface, but specific to the universal kernels we wrote in -// torchao -#include -namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight { -template - -inline UKernelConfig get_ukernel_config() { - UKernelConfig config; - - namespace ukernel = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - - return config; -} -} // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - // torchao::kernels::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight diff --git a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h similarity index 81% rename from torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h rename to torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h index 5d8f11b821..c92c94acfb 100644 --- a/torchao/experimental/kernels/cpu/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h +++ b/torchao/experimental/ops/linear/channelwise_8bit_activation_groupwise_lowbit_weight.h @@ -7,8 +7,7 @@ #pragma once #include -// TODO: maybe move to operator directory -namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { struct UKernelConfig { @@ -147,20 +146,4 @@ void linear_operator( float clamp_max); } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - -// TODO: may move to different file or namespace -// It is not part of the high-level interface, but specific to the universal -// kernels in torchao. -// Kleidi will need to implement their own get_ukernel_config -// In future, we may build a high-level get_ukernel_config with CPU-runtime -// selection -namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight { -template -UKernelConfig get_ukernel_config(); - -} // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight - -#include + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight diff --git a/torchao/experimental/ops/linear/examples/CMakeLists.txt b/torchao/experimental/ops/linear/examples/CMakeLists.txt new file mode 100644 index 0000000000..2b69adb3d8 --- /dev/null +++ b/torchao/experimental/ops/linear/examples/CMakeLists.txt @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +project(examples) + +cmake_minimum_required(VERSION 3.19) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Release) + +include(CMakePrintHelpers) + +set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) + +include_directories(${TORCHAO_INCLUDE_DIRS}) + +set(TORCHAO_PARALLEL_BACKEND "OPENMP") +add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) +add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) + +include(${TORCHAO_ROOT}/Utils.cmake) + +add_executable(separate_function_wrappers separate_function_wrappers.cpp) +target_link_libraries( + separate_function_wrappers + PRIVATE + torchao_kernels_aarch64 + torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} +) +target_link_torchao_parallel_backend(separate_function_wrappers "${TORCHAO_PARALLEL_BACKEND}") + +add_executable(stateful_class_wrapper stateful_class_wrapper.cpp) +target_link_libraries( + stateful_class_wrapper + PRIVATE + torchao_kernels_aarch64 + torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} +) +target_link_torchao_parallel_backend(stateful_class_wrapper "${TORCHAO_PARALLEL_BACKEND}") diff --git a/torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h b/torchao/experimental/ops/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h similarity index 91% rename from torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h rename to torchao/experimental/ops/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h index 575093f21b..a7755dadf4 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h +++ b/torchao/experimental/ops/linear/examples/Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator.h @@ -5,26 +5,22 @@ // LICENSE file in the root directory of this source tree. #pragma once -#include -#include -#include +#include +#include +#include #include #include -namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator { private: - torchao::aligned_byte_ptr packed_weight_data_{ - nullptr, - nullptr}; + torchao::aligned_byte_ptr packed_weight_data_{nullptr, nullptr}; int packed_weight_data_size_{0}; int packed_weight_data_alignment_{0}; - torchao::aligned_byte_ptr activation_data_buffer_{ - nullptr, - nullptr}; + torchao::aligned_byte_ptr activation_data_buffer_{nullptr, nullptr}; int m_{0}; int n_{0}; @@ -114,7 +110,7 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator { get_packed_weight_data_size(ukernel_config_, n_, k_, group_size_); auto packed_weight_data_alignment = get_packed_weight_data_alignment(ukernel_config_); - + packed_weight_data_size_ = packed_weight_data_size; packed_weight_data_alignment_ = packed_weight_data_alignment; packed_weight_data_ = torchao::make_aligned_byte_ptr( @@ -199,4 +195,4 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator { } }; } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight diff --git a/torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh b/torchao/experimental/ops/linear/examples/build_and_run_examples.sh similarity index 67% rename from torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh rename to torchao/experimental/ops/linear/examples/build_and_run_examples.sh index 9c244e54cc..01185fdd3f 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/build_and_run_examples.sh +++ b/torchao/experimental/ops/linear/examples/build_and_run_examples.sh @@ -5,15 +5,11 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../.. - export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}" -export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples -cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ - -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples \ +export CMAKE_OUT=/tmp/cmake-out/torchao/examples +cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \ + -S . \ -B ${CMAKE_OUT} \ -DOpenMP_ROOT=$(brew --prefix libomp) cmake --build ${CMAKE_OUT} diff --git a/torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp b/torchao/experimental/ops/linear/examples/separate_function_wrappers.cpp similarity index 80% rename from torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp rename to torchao/experimental/ops/linear/examples/separate_function_wrappers.cpp index ba3e5b29b3..144fe5c08d 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/separate_function_wrappers.cpp +++ b/torchao/experimental/ops/linear/examples/separate_function_wrappers.cpp @@ -4,9 +4,11 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. +#include #include -#include -#include +#include +#include +#include #include // This file contains an example of wrapping the torchao weight packing and // linear operators into two operators: one for weight packing and another @@ -20,9 +22,33 @@ // one stateful class, but not all surfaces support this (see // examples/stateful_class_wrapper.cpp for an example of this). -namespace torchao::operators::cpu::linear:: +namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight { +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + torchao::aligned_byte_ptr pack_weight_data_operator( UKernelConfig ukernel_config, int n, @@ -115,10 +141,10 @@ void linear_operator( } } // namespace - // torchao::operators::cpu::linear::channelwise_8bit_activation_groupwise_lowbit_weight + // torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight int main() { - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; torchao::set_num_threads(8); diff --git a/torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp b/torchao/experimental/ops/linear/examples/stateful_class_wrapper.cpp similarity index 71% rename from torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp rename to torchao/experimental/ops/linear/examples/stateful_class_wrapper.cpp index 5fb24c683d..c1cd2d110b 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/stateful_class_wrapper.cpp +++ b/torchao/experimental/ops/linear/examples/stateful_class_wrapper.cpp @@ -4,9 +4,10 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. +#include #include -#include -#include +#include +#include #include #include @@ -21,9 +22,33 @@ // examples/separate_function_wrappers.cpp for an example of how to split the // operations into two steps. -using namespace torchao::operators::cpu::linear:: +using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + int main() { int m = 13; int n = 4096 + 1; @@ -54,6 +79,7 @@ int main() { std::cout << "Initializing linear_operator." << std::endl; auto ukernel_config = get_ukernel_config(); + auto linear_operator = Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator( ukernel_config, diff --git a/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt b/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt new file mode 100644 index 0000000000..f69d884cd8 --- /dev/null +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +include(${TORCHAO_ROOT}/Utils.cmake) + +if(TORCHAO_OP_TARGET STREQUAL "ATEN") + message(STATUS "Building with TORCHAO_OP_TARGET=ATEN") + find_package(Torch REQUIRED) + add_library(linear_a8wxdq_${TORCHAO_OP_TARGET} SHARED linear_a8wxdq_aten.cpp) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_kernels_aarch64) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) + target_include_directories(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE "${TORCH_INCLUDE_DIRS}") + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE "${TORCH_LIBRARIES}") + target_compile_definitions(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE USE_ATEN=1) +elseif(TORCHAO_OP_TARGET STREQUAL "EXECUTORCH") + message(STATUS "Building with TORCHAO_OP_TARGET=EXECUTORCH") + add_library(linear_a8wxdq_${TORCHAO_OP_TARGET} SHARED + linear_a8wxdq_executorch/w2s.cpp + linear_a8wxdq_executorch/w2sz.cpp + linear_a8wxdq_executorch/w3s.cpp + linear_a8wxdq_executorch/w3sz.cpp + linear_a8wxdq_executorch/w4s.cpp + linear_a8wxdq_executorch/w4sz.cpp + linear_a8wxdq_executorch/w5s.cpp + linear_a8wxdq_executorch/w5sz.cpp + ) + target_include_directories(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE ${EXECUTORCH_INCLUDE_DIRS}) + target_compile_definitions(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE USE_EXECUTORCH=1) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE ${EXECUTORCH_LIBRARIES}) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_kernels_aarch64) + target_link_libraries(linear_a8wxdq_${TORCHAO_OP_TARGET} PRIVATE torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) +else() + message(FATAL_ERROR "Unknown TORCHAO_OP_TARGET: ${TORCHAO_OP_TARGET}. Please choose one of: ATEN, EXECUTORCH.") +endif() + + +install( + TARGETS linear_a8wxdq_${TORCHAO_OP_TARGET} + DESTINATION lib +) diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h similarity index 87% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h index 01b1836981..eee51eafc6 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op-impl.h +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq-impl.h @@ -5,7 +5,8 @@ // LICENSE file in the root directory of this source tree. #pragma once -#include +#include +#include #include #include @@ -28,6 +29,35 @@ using RuntimeContext = torch::executor::KernelRuntimeContext; #error "Must define either USE_ATEN or USE_EXECUTORCH" #endif +namespace { + +template +inline torchao::ops::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight::UKernelConfig + get_ukernel_config() { + torchao::ops::linear::channelwise_8bit_activation_groupwise_lowbit_weight:: + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + #ifdef USE_ATEN template Tensor pack_weights_cpu( @@ -69,7 +99,7 @@ Tensor pack_weights_cpu( weight_zeros_ptr = weight_zeros.value().const_data_ptr(); } - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< @@ -137,7 +167,7 @@ Tensor pack_weights_meta( int n = weight_qvals.size(0); int k = weight_qvals.size(1); - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< @@ -221,7 +251,7 @@ Tensor linear_out_cpu( CHECK_MSG(out.size(1) == n, "out shape is incorrect"); #endif // USE_EXECUTORCH - using namespace torchao::operators::cpu::linear:: + using namespace torchao::ops::linear:: channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< @@ -311,3 +341,5 @@ Tensor linear_meta( return torch::empty({m, n}).to("meta"); } #endif // USE_ATEN + +} // namespace diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp similarity index 98% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp index 626b3e769f..b1d464e5b5 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_aten.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_aten.cpp @@ -4,7 +4,7 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. -#include +#include #define DEFINE_OP(weight_nbit) \ m.def( \ diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2s.cpp index 592a0190a9..c6ef089995 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2sz.cpp index d2683b36ce..e569e05812 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w2sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w2sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3s.cpp index d59db3e1c7..9f236bd7b3 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3sz.cpp index 7458311b91..24a381fdcc 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w3sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w3sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4s.cpp index 75143050fa..67263d209d 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4sz.cpp index 714192a19b..530ff44370 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w4sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w4sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5s.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5s.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5s.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5s.cpp index 08c2d42ee8..de04a09f6a 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5s.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5s.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5sz.cpp b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5sz.cpp similarity index 90% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5sz.cpp rename to torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5sz.cpp index c1e3e953d3..91c5a16312 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/lowbit_op_executorch/w5sz.cpp +++ b/torchao/experimental/ops/linear/linear_a8wxdq_op/linear_a8wxdq_executorch/w5sz.cpp @@ -8,7 +8,7 @@ // EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new // file is needed for each variant -#include +#include namespace { Tensor _op_out( diff --git a/torchao/experimental/ops/linear/tests/CMakeLists.txt b/torchao/experimental/ops/linear/tests/CMakeLists.txt new file mode 100644 index 0000000000..866d832ccd --- /dev/null +++ b/torchao/experimental/ops/linear/tests/CMakeLists.txt @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) +project(tests) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_BUILD_TYPE Debug) +add_compile_options("-Wall" "-Werror") + +set(TORCHAO_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..) + +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip +) +FetchContent_MakeAvailable(googletest) +enable_testing() + +include_directories(${TORCHAO_INCLUDE_DIRS}) + +set(TORCHAO_PARALLEL_BACKEND "TEST_DUMMY") +add_subdirectory(${TORCHAO_ROOT}/ops/linear ${CMAKE_CURRENT_BINARY_DIR}/torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND}) +add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) + +include(${TORCHAO_ROOT}/Utils.cmake) +add_executable(test_linear_operator test_linear_operator.cpp) +target_link_libraries( + test_linear_operator + PRIVATE + GTest::gtest_main + torchao_kernels_aarch64 + torchao_ops_linear_${TORCHAO_PARALLEL_BACKEND} +) +target_link_torchao_parallel_backend(test_linear_operator "${TORCHAO_PARALLEL_BACKEND}") + +include(GoogleTest) +gtest_discover_tests(test_linear_operator) diff --git a/torchao/experimental/ops/linear/tests/build_and_run_tests.sh b/torchao/experimental/ops/linear/tests/build_and_run_tests.sh new file mode 100644 index 0000000000..3fbe78c172 --- /dev/null +++ b/torchao/experimental/ops/linear/tests/build_and_run_tests.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +export CMAKE_OUT=/tmp/cmake-out/torchao/tests +cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S . -B ${CMAKE_OUT} + +cmake --build ${CMAKE_OUT} + +# Run +${CMAKE_OUT}/test_linear_operator diff --git a/torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp b/torchao/experimental/ops/linear/tests/test_linear_operator.cpp similarity index 78% rename from torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp rename to torchao/experimental/ops/linear/tests/test_linear_operator.cpp index 5408e426bf..6d563111cc 100644 --- a/torchao/experimental/kernels/cpu/linear/tests/test_linear_operator.cpp +++ b/torchao/experimental/ops/linear/tests/test_linear_operator.cpp @@ -1,22 +1,52 @@ -// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. #include // TODO: move test_utils.h out of aarch64 +#include #include -#include -#include -#include +#include +#include +#include const float kTol = 1.0e-5; +using namespace torchao::ops::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight; + +template +UKernelConfig get_ukernel_config() { + UKernelConfig config; + + namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + + return config; +} + template void test_channelwise_8bit_activation_groupwise_lowbit_weight( int m, int n, int k, int group_size) { - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config(); @@ -47,7 +77,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight( get_packed_weight_data_size(ukernel_config, n, k, group_size); auto packed_weight_data_alignment = get_packed_weight_data_alignment(ukernel_config); - auto packed_weight_data = torchao::make_aligned_byte_array_unique_ptr( + auto packed_weight_data = torchao::make_aligned_byte_ptr( packed_weight_data_alignment, packed_weight_data_size); pack_weight_data_operator( @@ -74,7 +104,7 @@ void test_channelwise_8bit_activation_groupwise_lowbit_weight( group_size); auto activation_data_buffer_alignment = get_activation_data_buffer_alignment(ukernel_config); - auto activation_data_buffer = torchao::make_aligned_byte_array_unique_ptr( + auto activation_data_buffer = torchao::make_aligned_byte_ptr( activation_data_buffer_alignment, activation_data_buffer_size); // Run linear @@ -153,9 +183,6 @@ TEST( int n = 1; int k = 16 + 1; int group_size = 16; - - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< 3 /*weight_nbit*/, true /*has_weight_zeros*/, @@ -187,8 +214,6 @@ TEST( int k = 20; int group_size = 10; - using namespace torchao::operators::cpu::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight; auto ukernel_config = get_ukernel_config< 3 /*weight_nbit*/, true /*has_weight_zeros*/, diff --git a/torchao/experimental/kernels/cpu/macro.h b/torchao/experimental/ops/macro.h similarity index 100% rename from torchao/experimental/kernels/cpu/macro.h rename to torchao/experimental/ops/macro.h diff --git a/torchao/experimental/kernels/cpu/memory.h b/torchao/experimental/ops/memory.h similarity index 100% rename from torchao/experimental/kernels/cpu/memory.h rename to torchao/experimental/ops/memory.h diff --git a/torchao/experimental/kernels/cpu/parallel-aten-impl.h b/torchao/experimental/ops/parallel-aten-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-aten-impl.h rename to torchao/experimental/ops/parallel-aten-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-openmp-impl.h b/torchao/experimental/ops/parallel-openmp-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-openmp-impl.h rename to torchao/experimental/ops/parallel-openmp-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-pthreadpool-impl.h b/torchao/experimental/ops/parallel-pthreadpool-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-pthreadpool-impl.h rename to torchao/experimental/ops/parallel-pthreadpool-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-single_threaded-impl.h b/torchao/experimental/ops/parallel-single_threaded-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-single_threaded-impl.h rename to torchao/experimental/ops/parallel-single_threaded-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel-test_dummy-impl.h b/torchao/experimental/ops/parallel-test_dummy-impl.h similarity index 100% rename from torchao/experimental/kernels/cpu/parallel-test_dummy-impl.h rename to torchao/experimental/ops/parallel-test_dummy-impl.h diff --git a/torchao/experimental/kernels/cpu/parallel.h b/torchao/experimental/ops/parallel.h similarity index 73% rename from torchao/experimental/kernels/cpu/parallel.h rename to torchao/experimental/ops/parallel.h index 0d12c3acf9..e3949b8551 100644 --- a/torchao/experimental/kernels/cpu/parallel.h +++ b/torchao/experimental/ops/parallel.h @@ -10,7 +10,7 @@ namespace torchao { // F has signature [&](int64_t idx) template -void parallel_1d(const int64_t begin, const int64_t end, const F& f); +void parallel_1d(const int64_t begin, const int64_t end, const F& f); void set_num_threads(int num_threads); @@ -18,16 +18,17 @@ int get_num_threads(); } // namespace torchao - #ifdef TORCHAO_PARALLEL_ATEN #pragma message("TORCHAO_PARALLEL_ATEN is set. Using ATen parallel backend.") #ifndef INTRA_OP_PARALLEL - #pragma message("INTRA_OP_PARALLEL is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") +#pragma message( \ + "INTRA_OP_PARALLEL is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") #endif #ifndef AT_PARALLEL_OPENMP - #pragma message("AT_PARALLEL_OPENMP is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") +#pragma message( \ + "AT_PARALLEL_OPENMP is not set; TORCHAO_PARALLEL_ATEN may be single-threaded.") #endif -#include +#include #else #ifdef TORCHAO_PARALLEL_EXECUTORCH @@ -40,24 +41,25 @@ int get_num_threads(); #ifdef TORCHAO_PARALLEL_PTHREADPOOL #pragma message( \ "TORCHAO_PARALLEL_PTHREADPOOL is set. Using pthreadpool parallel backend.") -#include +#include #else #ifdef TORCHAO_PARALLEL_OPENMP -#pragma message("TORCHAO_PARALLEL_OPENMP is set. Using OPENMP parallel backend.") -#include +#pragma message( \ + "TORCHAO_PARALLEL_OPENMP is set. Using OPENMP parallel backend.") +#include #else #if defined TORCHAO_PARALLEL_SINGLE_THREADED #pragma message( \ "TORCHAO_PARALLEL_SINGLE_THREADED is set. Using single-threaded parallel backend.") -#include +#include #else #if defined TORCHAO_PARALLEL_TEST_DUMMY #pragma message( \ "TORCHAO_PARALLEL_TEST_DUMMY is set. Using test dummy parallel backend.") -#include +#include #else #error \ diff --git a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py b/torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py similarity index 63% rename from torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py rename to torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py index 513088d2f0..d431d26939 100644 --- a/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py +++ b/torchao/experimental/tests/test_int8_dyn_act_intx_weight_quantizer.py @@ -11,18 +11,19 @@ import sys import unittest +import tempfile import torch sys.path.insert( - 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../..")) + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) ) from quant_api import ( _Int8DynActIntxWeightQuantizedLinearFallback, Int8DynActIntxWeightQuantizer, ) -libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*") +libs = glob.glob("/tmp/cmake-out/torchao/lib/liblinear_a8wxdq_ATEN.*") libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs)) if len(libs) == 0: print( @@ -73,7 +74,49 @@ def test_accuracy(self): # Assert at most 5% of entries are not close at a low tolerance self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) + + def test_export_compile_aoti(self): + group_size = 32 + m = 1 + n = 256 + k = 256 + nbit = 4 + has_weight_zeros = False + n_layers = 3 + layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)] + model = torch.nn.Sequential(*layers) + + activations = torch.randn(m, k, dtype=torch.float32) + + print("Quantizing model") + quantizer = Int8DynActIntxWeightQuantizer( + device="cpu", + precision=torch.float32, + bitwidth=nbit, + groupsize=group_size, + has_weight_zeros=has_weight_zeros, + ) + quantized_model = quantizer.quantize(model) + + print("Exporting quantized model") + exported = torch.export.export(quantized_model, (activations,)) + + print("Compiling quantized model") + quantized_model_compiled = torch.compile(quantized_model) + with torch.no_grad(): + quantized_model_compiled(activations) + + with tempfile.TemporaryDirectory() as tmpdirname: + print("Exporting quantized model with AOTI") + torch._export.aot_compile( + quantized_model, + (activations,), + options={"aot_inductor.output_path": f"{tmpdirname}/model.so"}, + ) + print("Running quantized model in AOTI") + fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu") + fn(activations) if __name__ == "__main__": unittest.main()