2024-11-15 nightly release (abbb5dc)
pytorchbot committed Nov 15, 2024
1 parent 52a9559 commit 2007686
Showing 3 changed files with 31 additions and 19 deletions.
@@ -526,6 +526,7 @@ enum SSDTensor {
 class {{ autograd_func }} :
     public torch::autograd::Function<{{ autograd_func }}> {
  public:
+  static constexpr bool is_traceable = true;
   static torch::autograd::variable_list forward(
       torch::autograd::AutogradContext* ctx,
       const Tensor& placeholder_autograd_tensor,
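Background on the one-line template change above (not part of the commit): is_traceable is a static flag that libtorch C++ custom autograd functions can declare to mark their backward node as traceable. Below is a minimal standalone sketch of a custom function carrying the flag the same way the generated {{ autograd_func }} class now does; the struct name and the trivial pass-through forward/backward bodies are invented for illustration.

#include <torch/torch.h>

// Illustrative pass-through autograd function. The commit only adds the
// single `is_traceable` declaration to the generated class; everything else
// here is scaffolding so the sketch is self-contained.
struct PassThroughFunction
    : public torch::autograd::Function<PassThroughFunction> {
  static constexpr bool is_traceable = true;

  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& input) {
    (void)ctx;  // nothing to save for a pass-through op
    return {input};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    (void)ctx;
    // The gradient of an identity mapping is the incoming gradient.
    return {grad_outputs[0]};
  }
};

// Usage sketch: auto out = PassThroughFunction::apply(tensor)[0];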
@@ -580,7 +580,7 @@ def __init__( # noqa C901
         cache_load_factor: float = 0.2,
         cache_sets: int = 0,
         cache_reserved_memory: float = 0.0,
-        cache_precision: SparseType = SparseType.FP32,
+        cache_precision: Optional[SparseType] = None,
         weights_precision: SparseType = SparseType.FP32,
         output_dtype: SparseType = SparseType.FP32,
         enforce_hbm: bool = False,
@@ -619,6 +619,7 @@ def __init__( # noqa C901
         uvm_host_mapped: bool = False,
     ) -> None:
         super(SplitTableBatchedEmbeddingBagsCodegen, self).__init__()
+
         self.uuid = str(uuid.uuid4())
         self.logging_table_name: str = self.get_table_name_for_logging(table_names)
         self.pooling_mode = pooling_mode
@@ -627,6 +628,9 @@ def __init__( # noqa C901
             os.environ.get("FBGEMM_TBE_BOUNDS_CHECK_MODE", bounds_check_mode.value)
         )
         self.weights_precision = weights_precision
+        cache_precision = (
+            weights_precision if cache_precision is None else cache_precision
+        )
         self.output_dtype: int = output_dtype.as_int()
         assert (
             not prefetch_pipeline or cache_algorithm == CacheAlgorithm.LRU
@@ -1175,20 +1179,13 @@ def __init__( # noqa C901
             ),
         )
 
-        if cache_precision == SparseType.FP32:
-            cache_embedding_dtype = torch.float32
-        elif cache_precision == SparseType.FP16:
-            cache_embedding_dtype = torch.float16
-        else:
-            raise AssertionError(f"cache_precision {cache_precision} not supported!")
-
         self._apply_cache_state(
             cache_state,
             cache_algorithm,
             cache_load_factor,
             cache_sets,
             cache_reserved_memory,
-            dtype=cache_embedding_dtype,
+            cache_precision,
         )
 
         self.log(f"Contents: {table_names}")
@@ -2643,7 +2640,7 @@ def _apply_cache_state(
         cache_load_factor: float,
         cache_sets: int,
         cache_reserved_memory: float,
-        dtype: torch.dtype,
+        cache_precision: SparseType,
     ) -> None:
         self.cache_algorithm = cache_algorithm
         self.timestep = 1
@@ -2663,6 +2660,17 @@ def _apply_cache_state(
 
         self._init_uvm_cache_stats()
 
+        if cache_precision == SparseType.FP32:
+            dtype = torch.float32
+        elif cache_precision == SparseType.FP16:
+            dtype = torch.float16
+        else:
+            dtype = torch.float32  # not relevant, but setting it to keep linter happy
+            if not self.use_cpu > 0:
+                raise AssertionError(
+                    f"cache_precision {cache_precision} not supported!"
+                )
+
         # NOTE: no cache for CPU mode!
         if cache_state.total_cache_hash_size == 0 or self.use_cpu:
             self.register_buffer(
@@ -2740,7 +2748,8 @@ def _apply_cache_state(
             f"{cache_algorithm}, {cache_sets} sets, "
             f"load_factor: {cache_load_factor : .3f}, "
             f"cache_size: {cache_size / 1024.0 / 1024.0 / 1024.0 : .2f}GB, "
-            f"cache_precision: {dtype}"
+            f"cache_precision: {dtype}, "
+            f"weights_precision: {self.weights_precision}"
         )
 
         self.total_cache_hash_size = cache_state.total_cache_hash_size

src/EmbeddingSpMDMAutovec.cc (10 additions, 8 deletions)
@@ -24,6 +24,8 @@
 
 namespace fbgemm {
 
+static constexpr size_t LOCAL_STORAGE_SIZE = 512;
+
 template <typename OutType>
 static inline void fill_output(
     OutType* out,
@@ -99,10 +101,10 @@ bool EmbeddingSpMDM8Bit_autovec(
   const int64_t scale_bias_offset = scale_bias_last ? block_size : 0;
   const int64_t input_offset = scale_bias_last ? 0 : scale_bias_size;
 
-  std::array<float, 256> local_storage;
+  std::array<float, LOCAL_STORAGE_SIZE> local_storage;
   std::unique_ptr<float[]> heap_storage;
   float* buf;
-  if (block_size <= 256) {
+  if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
     buf = local_storage.data();
   } else {
     heap_storage.reset(new float[block_size]);
@@ -363,10 +365,10 @@ bool EmbeddingSpMDMNBit_autovec(
   int64_t current = 0;
   const int64_t rounded_block_size = round_up(block_size, num_elem_per_byte);
 
-  std::array<float, 256> local_storage;
+  std::array<float, LOCAL_STORAGE_SIZE> local_storage;
   std::unique_ptr<float[]> heap_storage;
   float* buf;
-  if (rounded_block_size <= 256) {
+  if (static_cast<size_t>(rounded_block_size) <= LOCAL_STORAGE_SIZE) {
     buf = local_storage.data();
   } else {
     heap_storage.reset(new float[rounded_block_size]);
@@ -504,10 +506,10 @@ bool EmbeddingSpMDM_autovec(
     output_stride = block_size;
   }
 
-  std::array<float, 256> local_storage;
+  std::array<float, LOCAL_STORAGE_SIZE> local_storage;
   std::unique_ptr<float[]> heap_storage;
   float* buf;
-  if (block_size <= 256) {
+  if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
     buf = local_storage.data();
   } else {
     heap_storage.reset(new float[block_size]);
@@ -862,10 +864,10 @@ bool EmbeddingSpMDMFP8_autovec(
     output_stride = block_size;
   }
 
-  std::array<float, 256> local_storage;
+  std::array<float, LOCAL_STORAGE_SIZE> local_storage;
   std::unique_ptr<float[]> heap_storage;
   float* buf;
-  if (block_size <= 256) {
+  if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
     buf = local_storage.data();
   } else {
     heap_storage.reset(new float[block_size]);
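The four autovec hunks above all touch the same scratch-buffer idiom: a fixed-size stack array serves as the accumulation buffer when the (rounded) block size fits, and a heap allocation is used otherwise. The change replaces the hard-coded 256 with the named 512-element LOCAL_STORAGE_SIZE constant and makes the comparison explicitly unsigned. Below is a self-contained sketch of the idiom; it is illustrative only, the process_row helper is invented, and the else branch is completed with the buf = heap_storage.get() assignment that the truncated hunks cut off.

#include <array>
#include <cstdint>
#include <memory>

// Named constant mirroring the new LOCAL_STORAGE_SIZE in EmbeddingSpMDMAutovec.cc:
// rows of up to this many floats are accumulated on the stack.
static constexpr size_t LOCAL_STORAGE_SIZE = 512;

// Illustrative helper: pick a scratch buffer for one output row of
// `block_size` floats, preferring the stack when the row fits.
static void process_row(int64_t block_size) {
  std::array<float, LOCAL_STORAGE_SIZE> local_storage;
  std::unique_ptr<float[]> heap_storage;
  float* buf;
  if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
    buf = local_storage.data();  // common case: no heap allocation
  } else {
    heap_storage.reset(new float[block_size]);
    buf = heap_storage.get();  // wide rows fall back to the heap
  }
  // ... accumulate embedding rows into buf and write the pooled result out ...
  (void)buf;
}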
