numbers matching through logits processor, decided not to use gather cuda op, instead used C++ function

balisujohn committed Feb 7, 2024
1 parent 1282153 commit c2d5c72
Showing 1 changed file with 173 additions and 8 deletions.
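In short: rather than adding a gather CUDA op to ggml, this change copies the next-token logits back to the host and reproduces the logits processor there with plain C++ gather, apply_penalty, and scatter helpers. The penalty rule (multiply negative scores by the penalty, divide positive ones) matches the usual transformers-style repetition penalty, which appears to be the reference the numbers are being checked against.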
181 changes: 173 additions & 8 deletions examples/tortoise/main.cpp
@@ -493,6 +493,7 @@ bool autoregressive_model_load(const std::string & fname, autoregressive_model &
struct ggml_cgraph * autoregressive_graph(
const autoregressive_model & model,
struct ggml_allocr * allocr,
const std::vector<int> mel_transformer_inputs_vector,
const std::vector<gpt_vocab::id> & tokens){

const int token_count = tokens.size();
@@ -543,6 +544,7 @@ struct ggml_cgraph * autoregressive_graph(

struct ggml_tensor * reshaped_embedding = ggml_reshape_4d(ctx0, embedding, 1,1,token_count,1024);

/*
struct ggml_tensor * fake_inputs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, token_count+2);
ggml_allocr_alloc(allocr, fake_inputs);
if (!ggml_allocr_is_measure(allocr)) {
@@ -556,13 +558,24 @@
}
int32_t truncation_index = token_count + 2;

*/


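// mel_transformer_inputs holds 4 batch copies of the (token_count + 2) input
// ids; it is now filled from mel_transformer_inputs_vector built in main(),
// replacing the earlier ggml_repeat of fake_inputs (commented out above).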
struct ggml_tensor * mel_transformer_inputs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32,4*( token_count+2));
ggml_allocr_alloc(allocr, mel_transformer_inputs);

mel_transformer_inputs = ggml_repeat(ctx0, fake_inputs, mel_transformer_inputs);
if (!ggml_allocr_is_measure(allocr)) {
for (int i = 0; i < 4*( token_count+2); ++i) {
int v = mel_transformer_inputs_vector[i];
ggml_backend_tensor_set(mel_transformer_inputs, &v, i*sizeof(int32_t), sizeof(v));

}

}



//mel_transformer_inputs = ggml_repeat(ctx0, fake_inputs, mel_transformer_inputs);

mel_transformer_inputs = ggml_reshape_2d(ctx0, mel_transformer_inputs, 4, (token_count + 2));

@@ -829,14 +842,19 @@ struct ggml_cgraph * autoregressive_graph(

next_token_logits = ggml_reshape_4d(ctx0, next_token_logits, 8194, 4, 1,1);

mel_transformer_inputs = ggml_reshape_4d(ctx0, mel_transformer_inputs, 18, 4, 1, 1);
//mel_transformer_inputs = ggml_reshape_4d(ctx0, mel_transformer_inputs, 18, 4, 1, 1);

ggml_tensor * score = ggml_gather(ctx0, next_token_logits, mel_transformer_inputs, 1);
//ggml_tensor * score = ggml_gather(ctx0, next_token_logits, mel_transformer_inputs, 1);


std::cout << "didn't reach here" << std::endl;

ggml_build_forward_expand(gf, score);
ggml_build_forward_expand(gf, next_token_logits);

//embd_w.resize(n_vocab);
// memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);



std::cout << "reached end graph build" << std::endl;

@@ -848,6 +866,99 @@ struct ggml_cgraph * autoregressive_graph(
}


template <typename T>
void printVector(const std::vector<T> & vector, int n, std::string name) {
std::cout << name << ":\n";

// Print first n elements
for (size_t i = 0; i < (size_t) n && i < vector.size(); i++) {
std::cout << vector[i] << " ";
}

std::cout << "\n";

// Print last n elements (start index clamped so n > size cannot underflow)
for (size_t i = vector.size() > (size_t) n ? vector.size() - n : 0; i < vector.size(); i++) {
std::cout << vector[i] << " ";
}

std::cout << std::endl;
}

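// Repetition-penalty logits processor: negative scores are multiplied by the
// penalty and positive ones divided by it, matching the transformers-style
// repetition penalty this commit is checking numbers against.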
std::vector<float> apply_penalty(const std::vector<float> & score, float penalty) {
std::vector<float> result(score.size());
for (size_t i = 0; i < score.size(); ++i) {
result[i] = (score[i] < 0) ? score[i] * penalty : score[i] / penalty;
}
return result;
}


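// CPU replacement for a gather op along the vocab dimension: for each entry
// of input_ids (laid out row-major as 4 x sequence_length), pick the logit
// src[row * vocab_size + input_ids[i]] from the 4 x vocab_size logits.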
std::vector<float> gather(const std::vector<float> & src, const std::vector<int> & input_ids)
{
const int BATCH_SIZE = 4; // hardcoding for now
const int sequence_length = input_ids.size() / BATCH_SIZE;
const int vocab_size = src.size() / BATCH_SIZE; // this is 8194, hardcoding for now

std::vector<float> result(input_ids.size());

for (size_t i = 0; i < input_ids.size(); i++)
{
const int rowIndex = i / sequence_length;
const int colIndex = input_ids[i];

result[i] = src[rowIndex * vocab_size + colIndex];
}
std::cout << "gather result" << std::endl;
return result;
}

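// Inverse of gather above: start from a copy of src1 (the full logits) and
// write each src2[i] back into result[row * vocab_size + input_ids[i]], so
// the penalized scores replace the originals.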
std::vector<float> scatter(const std::vector<float> & src1, const std::vector<float> & src2, const std::vector<int> & input_ids)
{
std::vector<float> result(src1.begin(), src1.end());

const int BATCH_SIZE = 4; // hardcoding for now
const int sequence_length = input_ids.size() / BATCH_SIZE;
const int vocab_size = src1.size() / BATCH_SIZE; // this is 8194, hardcoding for now

for (size_t i = 0; i < input_ids.size(); i++)
{
const int rowIndex = i / sequence_length;
const int colIndex = input_ids[i];

result[rowIndex * vocab_size + colIndex] = src2[i];
}
printVector(result, 3, "scatter_result");
return result;
}

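// Scales every logit in place. For temperature sampling the caller would pass
// 1/temperature (an assumption; this helper has no call site in this commit yet).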
void temp_inplace(std::vector<float> &src, float temp)
{
for(int i = 0; i < src.size(); i++)
{
src[i] *= temp;
}
}





int main(int argc, char ** argv) {

std::cout << "hello world" << std::endl;
@@ -898,6 +1009,22 @@ int main(int argc, char ** argv) {
std::cout << "completed" << std::endl;


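// Build 4 batch copies of the (tokens.size() + 2) mel transformer inputs:
// every slot is a placeholder token 1 except the last of each copy, which is
// 8192 (presumably a special mel token id; the constants are unnamed here).
// The assert pins the test prompt to 16 tokens.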
std::vector<int> mel_transformer_inputs_vector = std::vector<int>();
mel_transformer_inputs_vector.resize((tokens.size() + 2) * 4);
assert(tokens.size() == 16);

for (int i = 0; i < mel_transformer_inputs_vector.size(); i ++)
{
if (i % (tokens.size()+2) == tokens.size()+2-1){
mel_transformer_inputs_vector[i] = 8192;
}
else{
mel_transformer_inputs_vector[i] = 1;
}
}



ggml_backend_buffer_t buf_compute;

struct ggml_allocr * allocr = NULL;
@@ -913,7 +1040,7 @@
//int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);
//int n_past = model.hparams.n_ctx - n_tokens;
ggml_allocr_reset(allocr);
struct ggml_cgraph * gf = autoregressive_graph(model, allocr, tokens);
struct ggml_cgraph * gf = autoregressive_graph(model, allocr, mel_transformer_inputs_vector, tokens);
ggml_graph_print(gf);

std::cout << "graph created" << std::endl;
@@ -924,14 +1051,52 @@ int main(int argc, char ** argv) {
ggml_allocr_reset(allocr);
buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
allocr = ggml_allocr_new_from_buffer(buf_compute);
gf = autoregressive_graph(model, allocr, tokens);
gf = autoregressive_graph(model, allocr, mel_transformer_inputs_vector, tokens);
ggml_allocr_alloc_graph(allocr, gf);
std::cout << "reached computing time" << std::endl;
ggml_backend_graph_compute(model.backend, gf);
ggml_graph_print(gf);




std::cout << "---------------------------------------------------" << std::endl;
ggml_tensor * next_token_logits = gf->nodes[gf->n_nodes-1];

std::cout << "NAME:" << std::endl;
std::cout << next_token_logits->name << std::endl;
std::cout << "TYPE" << std::endl;
std::cout << next_token_logits->type << std::endl;
std::cout << "SHAPE:" << std::endl;
std::cout << next_token_logits->ne[0]<< std::endl;
std::cout << next_token_logits->ne[1]<< std::endl;
std::cout << next_token_logits->ne[2]<< std::endl;
std::cout << next_token_logits->ne[3]<< std::endl;
std::cout << "DATA:" << std::endl;

int elements = next_token_logits->ne[0] * next_token_logits->ne[1] * next_token_logits->ne[2] * next_token_logits->ne[3];


std::vector<float> next_token_logits_vector(elements);
ggml_backend_tensor_get(next_token_logits, next_token_logits_vector.data(), 0, sizeof(float) * elements);
for (int c = 0; c < elements ; c++)
{

if (c < 3 || c > elements-4 || c == 1024*18-1 || c == 1024*18-2 || c == 1024*18 || c == 1024*18+2 || c == 17)
{

std::cout << (next_token_logits_vector.data()[c])<< std::endl;
}
}

std::cout << "reaced end" << std::endl;

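// Host-side logits processing, replacing the commented-out gather op path:
// gather the scores at the input ids, apply the repetition penalty, then
// scatter the penalized scores back into the full logits.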
std::vector<float> gather_result = gather(next_token_logits_vector, mel_transformer_inputs_vector);
gather_result = apply_penalty(gather_result, 2.0);
std::cout << "BEGIN" << std::endl;
std::vector<float> transformed_mel_transformer_inputs_vector = scatter(next_token_logits_vector, gather_result, mel_transformer_inputs_vector);

/*
for (int i =0; i < gf->n_nodes; i ++)
{
std::cout << "---------------------------------------------------" << std::endl;
@@ -1013,7 +1178,7 @@ int main(int argc, char ** argv) {
}

*/
// ggml_graph_print (gf);


