Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions rope.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,23 @@ namespace Rope {
int axes_dim_num,
int index = 0,
int h_offset = 0,
int w_offset = 0) {
int w_offset = 0,
bool scale_rope = false) {
int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size;

std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));

std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
int h_start = h_offset;
int w_start = w_offset;

if (scale_rope) {
h_start -= h_len / 2;
w_start -= w_len / 2;
}

std::vector<float> row_ids = linspace<float>(h_start, h_start + h_len - 1, h_len);
std::vector<float> col_ids = linspace<float>(w_start, w_start + w_len - 1, w_len);

for (int i = 0; i < h_len; ++i) {
for (int j = 0; j < w_len; ++j) {
Expand Down Expand Up @@ -171,7 +180,8 @@ namespace Rope {
int axes_dim_num,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
float ref_index_scale) {
float ref_index_scale,
bool scale_rope) {
std::vector<std::vector<float>> ids;
uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0;
Expand All @@ -185,6 +195,7 @@ namespace Rope {
} else {
h_offset = curr_h_offset;
}
scale_rope = false;
}

auto ref_ids = gen_flux_img_ids(ref->ne[1],
Expand All @@ -194,7 +205,8 @@ namespace Rope {
axes_dim_num,
static_cast<int>(index * ref_index_scale),
h_offset,
w_offset);
w_offset,
scale_rope);
ids = concat_ids(ids, ref_ids, bs);

if (increase_ref_index) {
Expand Down Expand Up @@ -222,7 +234,7 @@ namespace Rope {

auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
Expand Down Expand Up @@ -271,10 +283,10 @@ namespace Rope {
}
}
int axes_dim_num = 3;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, 0, 0, 0, true);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f, true);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
Expand Down
Loading