Commit 9568209

Use weight cache for quantized tensor scale data
Summary:

When the XNNPACK weight cache is enabled and a model with qb4- or qc8-quantized linear weights is run, an assertion fires that is intended to verify that all data lives in the weight cache. This can be reproduced by running the XNNPACK backend linear op tests with the weight cache enabled.

The root cause appears to be that tensor scale data was bypassing the weight cache, likely an oversight in the initial implementation. This is not a correctness issue, but it trips the assertion above and uses marginally more memory than necessary.

This PR updates the XNNPACK compileModel call to load scale data through the weight cache instead of placing it in the unpacked_buffers list. With this change, the linear op tests pass with the weight cache enabled.

Differential Revision: D82862629
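For context, below is a minimal, self-contained sketch of the routing change: unpacked data is fetched through the cache, which then owns it, rather than being parked in a separate buffer list. WeightsCacheStub and the entry name are simplified stand-ins for illustration, not the real ExecuTorch XNNWeightsCache API; the actual call is weights_cache->load_unpacked_data(data_name), as shown in the diff further down.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Simplified stand-in for the weight cache: it owns every unpacked
// blob, so a later packing pass can assert that any pointer it sees
// is cache-backed (the assertion the commit message refers to).
struct WeightsCacheStub {
  std::map<std::string, std::vector<uint8_t>> entries;

  const uint8_t* load_unpacked_data(const std::string& name) {
    auto it = entries.find(name);
    return it == entries.end() ? nullptr : it->second.data();
  }
};

int main() {
  WeightsCacheStub cache;
  cache.entries["linear.weight.scales"] = {1, 2, 3, 4};

  // Before this commit (conceptually): scale data came from the
  // named-data map and was pushed onto a separate unpacked_buffers
  // list, bypassing the cache. After: the lookup goes through the
  // cache, so the cache's ownership invariant holds for scales too.
  const uint8_t* scale = cache.load_unpacked_data("linear.weight.scales");
  if (scale == nullptr) {
    std::cerr << "Failed to load scales from cache\n";
    return 1;
  }
  std::cout << "first scale byte: " << static_cast<int>(scale[0]) << "\n";
  return 0;
}

The design point is the ownership invariant: once every unpacked blob is loaded through the cache, the all-data-resident assertion can hold, and no extra staging copy is kept alive, which accounts for the marginal memory saving mentioned above.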
1 parent 07d1092

1 file changed: backends/xnnpack/runtime/XNNCompiler.cpp (+20 −0)
@@ -440,6 +440,15 @@ Error defineTensor(
         qparams->scale_buffer_idx());
     const std::string& data_name =
         scale_buffer_offset->named_key()->str();
+#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
+    auto load_result = weights_cache->load_unpacked_data(data_name);
+    ET_CHECK_OR_RETURN_ERROR(
+        load_result.ok(),
+        Internal,
+        "Failed to load block scales from cache: %u.",
+        load_result.error());
+    scale = reinterpret_cast<const float*>(load_result.get());
+#else // ENABLE_XNNPACK_WEIGHTS_CACHE disabled
     Result<FreeableBuffer> scale_buffer =
         named_data_map->get_data(data_name.c_str());
     ET_CHECK_OR_RETURN_ERROR(
@@ -450,6 +459,7 @@ Error defineTensor(
         static_cast<uint32_t>(scale_buffer.error()));
     scale = reinterpret_cast<const float*>(scale_buffer.get().data());
     freeable_buffers.push_back(std::move(scale_buffer.get()));
+#endif
   }
   status = xnn_define_channelwise_quantized_tensor_value_v2(
       /*subgraph=*/subgraph_ptr,
@@ -488,6 +498,15 @@ Error defineTensor(
         qparams->scale_buffer_idx());
     const std::string& data_name =
         scale_buffer_offset->named_key()->str();
+#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
+    auto load_result = weights_cache->load_unpacked_data(data_name);
+    ET_CHECK_OR_RETURN_ERROR(
+        load_result.ok(),
+        Internal,
+        "Failed to load tensor scales from cache: %u.",
+        load_result.error());
+    scale_data = reinterpret_cast<const uint16_t*>(load_result.get());
+#else // ENABLE_XNNPACK_WEIGHTS_CACHE disabled
     Result<FreeableBuffer> scale_buffer =
         named_data_map->get_data(data_name.c_str());
     ET_CHECK_OR_RETURN_ERROR(
@@ -499,6 +518,7 @@ Error defineTensor(
     scale_data =
         reinterpret_cast<const uint16_t*>(scale_buffer.get().data());
     freeable_buffers.push_back(std::move(scale_buffer.get()));
+#endif
     scale_numel = qparams->num_scales();
   } else {
     // Read fp32 scales, convert to bf16.
