Skip to content

Commit

Permalink
Patch D66310520 to make it build in OSS (#3409)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3409

X-link: facebookresearch/FBGEMM#497

- Patch D66310520 to make the code build in OSS

Differential Revision: D66399304
  • Loading branch information
q10 authored and facebook-github-bot committed Nov 25, 2024
1 parent ee1424f commit 80384ca
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 3 deletions.
23 changes: 22 additions & 1 deletion fbgemm_gpu/FbgemmGpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -295,19 +295,30 @@ foreach(optimizer ${SSD_OPTIMIZERS})
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_cta.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_warp.cu")
endforeach()

foreach(wdesc weighted unweighted)
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_cuda.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_kernel_cta.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_vbe_kernel_warp.cu")
endforeach()

endforeach()

list(APPEND gen_defused_optim_py_files
${CMAKE_BINARY_DIR}/optimizer_args.py)


################################################################################
# FBGEMM_GPU Generated HIP-Specific Sources
################################################################################

set(gen_hip_kernel_source_files)
foreach(wdesc weighted unweighted unweighted_nobag)
list(APPEND gen_hip_kernel_source_files
"gen_embedding_backward_split_${wdesc}_device_kernel_hip.hip")
endforeach()


################################################################################
# FBGEMM_GPU Generated Sources
################################################################################
Expand Down Expand Up @@ -563,6 +574,9 @@ set(fbgemm_gpu_sources_gpu_gen
${gen_gpu_host_source_files}
${gen_defused_optim_source_files})

set(fbgemm_gpu_sources_hip_gen
${gen_hip_kernel_source_files})

if(USE_ROCM)
prepend_filepaths(
PREFIX ${CMAKE_BINARY_DIR}
Expand All @@ -573,6 +587,11 @@ if(USE_ROCM)
PREFIX ${CMAKE_BINARY_DIR}
INPUT ${fbgemm_gpu_sources_gpu_gen}
OUTPUT fbgemm_gpu_sources_gpu_gen)

prepend_filepaths(
PREFIX ${CMAKE_BINARY_DIR}
INPUT ${fbgemm_gpu_sources_hip_gen}
OUTPUT fbgemm_gpu_sources_hip_gen)
endif()


Expand All @@ -591,6 +610,8 @@ gpu_cpp_library(
GPU_SRCS
${fbgemm_gpu_sources_gpu_static}
${fbgemm_gpu_sources_gpu_gen}
HIP_SPECIFIC_SRCS
${fbgemm_gpu_sources_hip_gen}
OTHER_SRCS
${asmjit_sources}
${fbgemm_sources}
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/genscript/generate_backward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def generate() -> None:
BackwardSplitGenerator.generate_backward_split(
ssd_tensors=ssd_tensors, **optimizer
)
BackwardSplitGenerator.generate_rocm_backward_split(**optimizer)
BackwardSplitGenerator.generate_rocm_backward_split()

# Generate common device kernels for backwards
BackwardSplitGenerator.generate_backward_device()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def __init__( # noqa C901
assert (
not mixed_D
), "OptimType.NONE does not support mixed embedding dimension"
self.mixed_D = mixed_D
self.mixed_D: bool = mixed_D
if device is None:
self.current_device: torch.device = (
torch.device("cpu")
Expand Down Expand Up @@ -3442,6 +3442,15 @@ def __init__(
torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32),
)
assert self.D_offsets.numel() == T + 1

mixed_D = False
D = dims[0]
for d in dims:
if d != D:
mixed_D = True
break
self.mixed_D: bool = mixed_D

# Required for VBE
self.register_buffer(
"feature_dims",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
*
******************************************************************************/
#pragma once
#include <c10/util/Half.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

Expand Down

0 comments on commit 80384ca

Please sign in to comment.