Skip to content

Commit

Permalink
Fix bug where gradient checkpointing will generate extra parameter_in…
Browse files Browse the repository at this point in the history
…dices.
  • Loading branch information
liuliu committed Nov 15, 2024
1 parent dbdf28d commit 4ecbf13
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
3 changes: 2 additions & 1 deletion lib/nnc/_ccv_cnnp_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,9 @@ static inline void ccv_cnnp_model_add_to_parameter_indices(ccv_cnnp_model_t* con
}

typedef struct {
ccv_cnnp_model_sequence_t* sequence;
uint8_t add_parameter_indices;
char prefix;
ccv_cnnp_model_sequence_t* sequence;
ccv_array_t* symbols;
ccv_array_t* ids;
ccv_array_t* trainables;
Expand Down
12 changes: 7 additions & 5 deletions lib/nnc/ccv_cnnp_model.c
Original file line number Diff line number Diff line change
Expand Up @@ -133,22 +133,22 @@ void ccv_cnnp_model_add_to_array(void* const context, const ccv_nnc_tensor_symbo
ccv_cnnp_model_add_to_array_context_t* const add_to_array_context = (ccv_cnnp_model_add_to_array_context_t*)context;
ccv_cnnp_model_t* const model = add_to_array_context->sequence->model;
int i;
if (!model->parameter_indices)
if (add_to_array_context->add_parameter_indices && !model->parameter_indices)
model->parameter_indices = ccv_array_new(sizeof(int), 0, 0);
for (i = 0; i < add_to_array_context->symbols->rnum; i++)
{
const ccv_nnc_tensor_symbol_t other_symbol = *(ccv_nnc_tensor_symbol_t*)ccv_array_get(add_to_array_context->symbols, i);
if (other_symbol.d == symbol.d && other_symbol.graph == symbol.graph)
{
// Only add to parameter_indices if it is trainable.
if (add_to_array_context->prefix == 't')
if (add_to_array_context->add_parameter_indices)
ccv_array_add_unique_int(model->parameter_indices, i);
// Found it, return, don't add it.
return;
}
}
// Only add to parameter_indices if it is trainable.
if (add_to_array_context->prefix == 't')
if (add_to_array_context->add_parameter_indices)
ccv_array_push(model->parameter_indices, &add_to_array_context->symbols->rnum);
// This is a new one, no need to add_unique_int, it is unique.
ccv_array_push(add_to_array_context->symbols, &symbol);
Expand Down Expand Up @@ -202,17 +202,19 @@ static void _ccv_cnnp_model_compile(ccv_cnnp_model_t* const model, const ccv_nnc
.bank = kh_init(ccv_cnnp_model_name_bank)
};
ccv_cnnp_model_add_to_array_context_t add_to_parameter_context = {
.sequence = &model_sequence,
.add_parameter_indices = 1,
.prefix = 't',
.sequence = &model_sequence,
.symbols = parameters,
.ids = parameter_ids,
.trainables = parameter_trainables,
};
ccv_array_t* const internals = ccv_array_new(sizeof(ccv_nnc_tensor_symbol_t), 0, 0);
ccv_array_t* const internal_ids = ccv_array_new(sizeof(char*), 0, 0);
ccv_cnnp_model_add_to_array_context_t add_to_output_context = {
.sequence = &model_sequence,
.add_parameter_indices = 0,
.prefix = 'r',
.sequence = &model_sequence,
.symbols = internals,
.ids = internal_ids,
.trainables = 0,
Expand Down
6 changes: 4 additions & 2 deletions lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,17 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
.bank = kh_init(ccv_cnnp_model_name_bank)
};
ccv_cnnp_model_add_to_array_context_t add_to_parameter_context = {
.sequence = &model_sequence,
.add_parameter_indices = 0,
.prefix = 't',
.sequence = &model_sequence,
.symbols = parameters,
.ids = parameter_ids,
.trainables = parameter_trainables,
};
ccv_cnnp_model_add_to_array_context_t add_to_output_context = {
.sequence = &model_sequence,
.add_parameter_indices = 0,
.prefix = 'r',
.sequence = &model_sequence,
.symbols = internals,
.ids = internal_ids,
.trainables = 0,
Expand Down

0 comments on commit 4ecbf13

Please sign in to comment.