memory stability fix, added beta schedule helper function
balisujohn committed May 11, 2024
1 parent 79230d4 commit 15dab93
Showing 1 changed file with 68 additions and 5 deletions.
main.cpp (73 changes: 68 additions & 5 deletions)
@@ -36,6 +36,7 @@ int32_t NUM_RETURN_SEQUENCES = 4; //hardcoding this for now, analogous to "num_r

std::mt19937 generator(245645656);
std::uniform_real_distribution<float> distribution(0.0, 1.0);
std::normal_distribution<double> normal_distribution(0.0,1.0);

void localAssert(bool condition)
{
@@ -541,7 +542,7 @@ bool autoregressive_model_load(const std::string & fname, autoregressive_model &
model.language_model_head_linear_bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8194);


model.layers.resize(1);
model.layers.resize(30);
for (int i= 0; i < 30; i ++)
{
auto & layer = model.layers[i];
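Note: this resize(1) → resize(30) change is the memory stability fix named in the commit message. The loop fills model.layers[i] for i in [0, 30), so a vector sized for a single element left 29 of those accesses past the end of its allocation, which is undefined behavior that can corrupt the heap or crash intermittently. The same class of fix recurs below, where latent_conditioner_attention_blocks grows from resize(1) to resize(4) to match its loop. A minimal standalone sketch of the failure mode (names hypothetical):

#include <vector>

struct Layer { float weight; };

int main() {
    std::vector<Layer> layers;
    layers.resize(1);              // storage for exactly one element
    for (int i = 0; i < 30; i++) {
        layers[i].weight = 0.0f;   // i >= 1 writes out of bounds: undefined behavior
    }
    // Sizing the vector to the loop bound up front, layers.resize(30), makes every index valid.
    return 0;
}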
@@ -864,7 +865,7 @@ bool diffusion_model_load(const std::string & fname, diffusion_model & model)



model.latent_conditioner_attention_blocks.resize(1);
model.latent_conditioner_attention_blocks.resize(4);
for (int i = 1; i < 5; i ++)
{

@@ -2343,7 +2344,7 @@ struct ggml_cgraph * diffusion_graph(

//float scale_factor = output_sequence_length / latent_length;

//cur = ggml_upscale(ctx0, cur, scale_factor);
cur = ggml_upscale_to_shape(ctx0, cur, output_sequence_length, 1024, 1,1);


ggml_set_name(cur, "output");
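Note: ggml_upscale_to_shape looks like a helper from this project's ggml fork rather than upstream ggml; unlike the commented-out ratio-based ggml_upscale call, it pins the output to an exact shape (output_sequence_length, 1024, 1, 1), which avoids rounding drift when output_sequence_length is not an integer multiple of the latent length. A rough standalone sketch of the idea, assuming nearest-neighbor interpolation (function name hypothetical):

#include <vector>

// Stretch a sequence to an exact target length with nearest-neighbor lookups,
// so no fractional scale factor ever has to be rounded.
std::vector<float> upscale_to_length(const std::vector<float>& src, int target_len) {
    std::vector<float> dst(target_len);
    for (int i = 0; i < target_len; i++) {
        size_t j = (size_t)i * src.size() / target_len;  // nearest source index
        dst[i] = src[j];
    }
    return dst;
}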
@@ -2591,6 +2592,17 @@ void top_p_inplace(std::vector<float > & src){
}


std::vector<float> sample_diffusion_noise(int length)
{
    std::vector<float> noise(length);
    for (int i = 0; i < length; i++)
    {
        noise[i] = normal_distribution(generator);
    }
    return noise;
}
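sample_diffusion_noise draws independent samples from N(0, 1) through the file-scope normal_distribution and the fixed-seed generator, so a given length always yields the same noise across runs. A quick sanity check, sketched assuming only the function above:

#include <iostream>
#include <numeric>

// For a large draw the sample mean should land near 0 and the sample
// variance near 1, since every entry comes from a standard normal.
void check_noise_stats() {
    std::vector<float> n = sample_diffusion_noise(1 << 20);
    double mean = std::accumulate(n.begin(), n.end(), 0.0) / n.size();
    double var = 0.0;
    for (float x : n) var += (x - mean) * (x - mean);
    var /= n.size();
    std::cout << "mean=" << mean << " var=" << var << std::endl;
}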


int multinomial( std::vector<float> probs) // worth changing to a binary search at some point, but for now done the simple way
{

@@ -3323,6 +3335,21 @@ std::pair<std::vector<std::vector<float>>, std::vector<std::vector<int>>> autore
}


std::vector<double> get_beta_schedule(int num_diffusion_timesteps) {
    double scale = 1000.0 / num_diffusion_timesteps;
    double beta_start = scale * 0.0001;
    double beta_end = scale * 0.02;

    std::vector<double> betas;
    for (int i = 0; i < num_diffusion_timesteps; ++i) {
        // linear interpolation between the endpoints, kept entirely in double precision
        betas.push_back(beta_start + i * (beta_end - beta_start) / (num_diffusion_timesteps - 1));
    }

    return betas;
}
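This is the linear beta schedule from the DDPM line of work: 0.0001 and 0.02 are the canonical endpoints for 1000 steps, and the 1000/T scale keeps the schedule roughly equivalent when run with a different timestep count. Worked through for the get_beta_schedule(4000) call in main below:

scale      = 1000 / 4000 = 0.25
beta_start = 0.25 * 0.0001 = 2.5e-5
beta_end   = 0.25 * 0.02   = 0.005
beta_i     = beta_start + i * (beta_end - beta_start) / (T - 1)

so beta_0 = 2.5e-5, beta_3999 = 0.005, and consecutive betas differ by about 1.24406e-6.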




/*
@@ -3357,7 +3384,7 @@ std::vector<float> load_f32_vector(const std::string& filename, size_t nBytes) {
}

//thanks gpt3.5 !
void save_f32_vector(const std::string& filename, const std::vector<std::vector<float>>& vectors) {
void save_f32_vectors(const std::string& filename, const std::vector<std::vector<float>>& vectors) {
std::ofstream file(filename, std::ios::binary);
if (!file.is_open()) {
std::cerr << "Error: Unable to open file " << filename << " for writing." << std::endl;
@@ -3376,6 +3403,23 @@ void save_f32_vector(const std::string& filename, const std::vector<std::vector<
}


void save_f32_vector(const std::string& filename, const std::vector<float>& vector) {
    std::ofstream file(filename, std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Error: Unable to open file " << filename << " for writing." << std::endl;
        return;
    }

    // Write the raw float data; no size header is stored.
    file.write(reinterpret_cast<const char*>(vector.data()), vector.size() * sizeof(float));

    file.close();
}
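Note the split between the two writers: the renamed save_f32_vectors handles a list of vectors, while this new single-vector save_f32_vector writes the raw floats and nothing else, so the matching load_f32_vector has to be told how many bytes to read back. A hypothetical round trip using the helpers above (file path illustrative):

std::vector<float> noise = sample_diffusion_noise(1024);
save_f32_vector("./logs/noise_roundtrip.bin", noise);
// No size header in the file, so the byte count is supplied by the caller.
std::vector<float> back = load_f32_vector("./logs/noise_roundtrip.bin", noise.size() * sizeof(float));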





@@ -3446,7 +3490,7 @@ void test_autoregressive(){



//save_f32_vector("../assets/target_trimmed_latents.bin", trimmed_latents);
//save_f32_vectors("../assets/target_trimmed_latents.bin", trimmed_latents);
std::vector<float> target_trimmed_latents = load_f32_vector("../assets/target_trimmed_latents.bin" , trimmed_latents_size * sizeof(float)); // 4 is the number of bytes in a float.

std::vector<std::vector<int>> target_sequences ={{8, 7406, 6450, 1601, 2061, 4389, 4954, 134, 1554, 372, 3666, 1580, 20, 83, 45, 8, 248, 8012, 2483, 7396, 37, 7784, 3008, 1126, 283, 1609, 2376, 2061, 4992, 3330, 1350, 469, 1022, 7005, 8193, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 45, 45, 248},
@@ -3531,8 +3575,10 @@ int main(int argc, char ** argv) {
}
}


std::cout << "reached" << std::endl;


//ggml_backend_t temp_backend = ggml_backend_cuda_init();

ggml_gallocr_t diffusion_allocr = NULL;
@@ -3609,7 +3655,24 @@



std::vector<float> noise = sample_diffusion_noise(100 * output_sequence_length);
save_f32_vector("./logs/diffusion_noise.bin", noise);

std::vector<double> beta_schedule = get_beta_schedule(4000);

// Print the first three entries
std::cout << "First three betas: ";
for (int i = 0; i < 3; ++i) {
std::cout << beta_schedule[i] << " ";
}
std::cout << std::endl;

// Print the last three entries
std::cout << "Last three betas: ";
for (size_t i = beta_schedule.size() - 3; i < beta_schedule.size(); ++i) {
std::cout << beta_schedule[i] << " ";
}
std::cout << std::endl;
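Given the schedule worked out above (a step of roughly 1.24406e-6 between consecutive betas), these printouts should read approximately:

First three betas: 2.5e-05 2.62441e-05 2.74881e-05
Last three betas: 0.00499751 0.00499876 0.005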


return 0;
