Skip to content

[RFC] Support DistillT5 #652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 41 additions & 14 deletions t5.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,25 @@ struct T5Block : public GGMLBlock {
}
};

struct T5Projection : public UnaryBlock {
public:
T5Projection(int64_t model_dim, int64_t projection_dim) {
blocks["0"] = std::shared_ptr<GGMLBlock>(new Linear(model_dim, projection_dim, false));
blocks["3"] = std::shared_ptr<GGMLBlock>(new Linear(projection_dim, projection_dim, false));
}

struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N, n_token, model_dim]
auto wi = std::dynamic_pointer_cast<Linear>(blocks["0"]);
auto wo = std::dynamic_pointer_cast<Linear>(blocks["3"]);

x = wi->forward(ctx, x);
x = ggml_relu_inplace(ctx, x);
x = wo->forward(ctx, x);
return x;
}
};

struct T5Stack : public GGMLBlock {
int64_t num_layers;

Expand Down Expand Up @@ -682,6 +701,7 @@ struct T5Stack : public GGMLBlock {
auto final_layer_norm = std::dynamic_pointer_cast<T5LayerNorm>(blocks["final_layer_norm"]);

x = final_layer_norm->forward(ctx, x);

return x;
}
};
Expand All @@ -692,9 +712,11 @@ struct T5 : public GGMLBlock {
int64_t model_dim,
int64_t ff_dim,
int64_t num_heads,
int64_t vocab_size) {
int64_t vocab_size,
int64_t projection_dim) {
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new T5Stack(num_layers, model_dim, model_dim, ff_dim, num_heads));
blocks["shared"] = std::shared_ptr<GGMLBlock>(new Embedding(vocab_size, model_dim));
blocks["final_projection"] = std::shared_ptr<GGMLBlock>(new T5Projection(model_dim, projection_dim));
}

struct ggml_tensor* forward(struct ggml_context* ctx,
Expand All @@ -709,6 +731,9 @@ struct T5 : public GGMLBlock {

auto x = shared->forward(ctx, input_ids);
x = encoder->forward(ctx, x, past_bias, attention_mask, relative_position_bucket);

auto final_projection = std::dynamic_pointer_cast<T5Projection>(blocks["final_projection"]);
x = final_projection->forward(ctx, x);
return x;
}
};
Expand All @@ -720,12 +745,13 @@ struct T5Runner : public GGMLRunner {
T5Runner(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
const std::string prefix,
int64_t num_layers = 24,
int64_t model_dim = 4096,
int64_t ff_dim = 10240,
int64_t num_heads = 64,
int64_t vocab_size = 32128)
: GGMLRunner(backend), model(num_layers, model_dim, ff_dim, num_heads, vocab_size) {
int64_t num_layers = 12,
int64_t model_dim = 768,
int64_t ff_dim = 2048,
int64_t num_heads = 12,
int64_t vocab_size = 32128,
int64_t projection_dim = 4096)
: GGMLRunner(backend), model(num_layers, model_dim, ff_dim, num_heads, vocab_size, projection_dim) {
model.init(params_ctx, tensor_types, prefix);
}

Expand Down Expand Up @@ -861,12 +887,13 @@ struct T5Embedder {
T5Embedder(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
const std::string prefix = "",
int64_t num_layers = 24,
int64_t model_dim = 4096,
int64_t ff_dim = 10240,
int64_t num_heads = 64,
int64_t vocab_size = 32128)
: model(backend, tensor_types, prefix, num_layers, model_dim, ff_dim, num_heads, vocab_size) {
int64_t num_layers = 12,
int64_t model_dim = 768,
int64_t ff_dim = 2048,
int64_t num_heads = 12,
int64_t vocab_size = 32128,
int64_t projection_dim = 4096)
: model(backend, tensor_types, prefix, num_layers, model_dim, ff_dim, num_heads, vocab_size, projection_dim) {
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
Expand Down Expand Up @@ -983,4 +1010,4 @@ struct T5Embedder {
}
};

#endif // __T5_HPP__
#endif // __T5_HPP__
Loading