Skip to content

Commit

Permalink
Update embedding_split_host_pt2_autograd_template.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
spcyppt authored Jan 29, 2025
1 parent 0db643a commit c0efcc0
Showing 1 changed file with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ enum SSDTensor {
{%- endif %}
hash_size_cumsum,
indices,
offsets,
offsets_,
{%- if not nobag %}
pooling_mode,
indice_weights_value,
Expand Down Expand Up @@ -706,6 +706,13 @@ class {{ autograd_func }} :
info_B_num_bits,
/*total_B=*/offsets.sym_size(0) - 1
);
Tensor offsets_;
if (weights_host.numel()){
offsets_ = reshape_offsets(offsets_, B_offsets, max_B, T);
}
else {
offsets_ = offsets;
}
{%- endif %} // vbe

{%- if is_gwd %}
Expand All @@ -728,7 +735,7 @@ class {{ autograd_func }} :
{%- endif %}
hash_size_cumsum,
indices,
offsets,
offsets_,
{%- if not nobag %}
indice_weights_value,
feature_requires_grad.value_or(Tensor()),
Expand Down

0 comments on commit c0efcc0

Please sign in to comment.