diff --git a/src/search_inter.c b/src/search_inter.c index 7e659c3b9..0586dfc9e 100644 --- a/src/search_inter.c +++ b/src/search_inter.c @@ -1374,6 +1374,143 @@ static void search_pu_inter_ref(inter_search_info_t *info, } +/** + * \brief Search bipred modes for a PU. + */ +static void search_pu_inter_bipred(inter_search_info_t *info, + int depth, + lcu_t *lcu, cu_info_t *cur_cu, + double *inter_cost, + uint32_t *inter_bitcost) +{ + const image_list_t *const ref = info->state->frame->ref; + uint8_t (*ref_LX)[16] = info->state->frame->ref_LX; + const videoframe_t * const frame = info->state->tile->frame; + const int cu_width = LCU_WIDTH >> depth; + const int x = info->origin.x; + const int y = info->origin.y; + const int width = info->width; + const int height = info->height; + + static const uint8_t priorityList0[] = { 0, 1, 0, 2, 1, 2, 0, 3, 1, 3, 2, 3 }; + static const uint8_t priorityList1[] = { 1, 0, 2, 0, 2, 1, 3, 0, 3, 1, 3, 2 }; + const unsigned num_cand_pairs = + MIN(info->num_merge_cand * (info->num_merge_cand - 1), 12); + + inter_merge_cand_t *merge_cand = info->merge_cand; + + for (int32_t idx = 0; idx < num_cand_pairs; idx++) { + uint8_t i = priorityList0[idx]; + uint8_t j = priorityList1[idx]; + if (i >= info->num_merge_cand || j >= info->num_merge_cand) break; + + // Find one L0 and L1 candidate according to the priority list + if (!(merge_cand[i].dir & 0x1) || !(merge_cand[j].dir & 0x2)) continue; + + if (ref_LX[0][merge_cand[i].ref[0]] == ref_LX[1][merge_cand[j].ref[1]] && + merge_cand[i].mv[0][0] == merge_cand[j].mv[1][0] && + merge_cand[i].mv[0][1] == merge_cand[j].mv[1][1]) + { + continue; + } + + int16_t mv[2][2]; + mv[0][0] = merge_cand[i].mv[0][0]; + mv[0][1] = merge_cand[i].mv[0][1]; + mv[1][0] = merge_cand[j].mv[1][0]; + mv[1][1] = merge_cand[j].mv[1][1]; + + // Don't try merge candidates that don't satisfy mv constraints. + if (!fracmv_within_tile(info, mv[0][0], mv[0][1]) || + !fracmv_within_tile(info, mv[1][0], mv[1][1])) + { + continue; + } + + kvz_inter_recon_bipred(info->state, + ref->images[ref_LX[0][merge_cand[i].ref[0]]], + ref->images[ref_LX[1][merge_cand[j].ref[1]]], + x, y, + width, + height, + mv, + lcu); + + const kvz_pixel *rec = &lcu->rec.y[SUB_SCU(y) * LCU_WIDTH + SUB_SCU(x)]; + const kvz_pixel *src = &frame->source->y[x + y * frame->source->width]; + uint32_t cost = + kvz_satd_any_size(cu_width, cu_width, rec, LCU_WIDTH, src, frame->source->width); + + uint32_t bitcost[2] = { 0, 0 }; + + cost += info->mvd_cost_func(info->state, + merge_cand[i].mv[0][0], + merge_cand[i].mv[0][1], + 0, + info->mv_cand, + NULL, 0, 0, + &bitcost[0]); + cost += info->mvd_cost_func(info->state, + merge_cand[i].mv[1][0], + merge_cand[i].mv[1][1], + 0, + info->mv_cand, + NULL, 0, 0, + &bitcost[1]); + + const uint8_t mv_ref_coded[2] = { + merge_cand[i].ref[0], + merge_cand[j].ref[1] + }; + const int extra_bits = mv_ref_coded[0] + mv_ref_coded[1] + 2 /* mv dir cost */; + cost += info->state->lambda_sqrt * extra_bits + 0.5; + + if (cost < *inter_cost) { + cur_cu->inter.mv_dir = 3; + + cur_cu->inter.mv_ref[0] = merge_cand[i].ref[0]; + cur_cu->inter.mv_ref[1] = merge_cand[j].ref[1]; + + cur_cu->inter.mv[0][0] = merge_cand[i].mv[0][0]; + cur_cu->inter.mv[0][1] = merge_cand[i].mv[0][1]; + cur_cu->inter.mv[1][0] = merge_cand[j].mv[1][0]; + cur_cu->inter.mv[1][1] = merge_cand[j].mv[1][1]; + cur_cu->merged = 0; + + // Check every candidate to find a match + for (int merge_idx = 0; merge_idx < info->num_merge_cand; merge_idx++) { + if (merge_cand[merge_idx].mv[0][0] == cur_cu->inter.mv[0][0] && + merge_cand[merge_idx].mv[0][1] == cur_cu->inter.mv[0][1] && + merge_cand[merge_idx].mv[1][0] == cur_cu->inter.mv[1][0] && + merge_cand[merge_idx].mv[1][1] == cur_cu->inter.mv[1][1] && + merge_cand[merge_idx].ref[0] == cur_cu->inter.mv_ref[0] && + merge_cand[merge_idx].ref[1] == cur_cu->inter.mv_ref[1]) + { + cur_cu->merged = 1; + cur_cu->merge_idx = merge_idx; + break; + } + } + + // Each motion vector has its own candidate + for (int reflist = 0; reflist < 2; reflist++) { + kvz_inter_get_mv_cand(info->state, x, y, width, height, info->mv_cand, cur_cu, lcu, reflist); + int cu_mv_cand = select_mv_cand( + info->state, + info->mv_cand, + cur_cu->inter.mv[reflist][0], + cur_cu->inter.mv[reflist][1], + NULL); + CU_SET_MV_CAND(cur_cu, reflist, cu_mv_cand); + } + + *inter_cost = cost; + *inter_bitcost = bitcost[0] + bitcost[1] + extra_bits; + } + } +} + + /** * \brief Update PU to have best modes at this depth. * @@ -1455,139 +1592,7 @@ static void search_pu_inter(encoder_state_t * const state, && width + height >= 16; // 4x8 and 8x4 PBs are restricted to unipred if (can_use_bipred) { - lcu_t *templcu = MALLOC(lcu_t, 1); - unsigned cu_width = LCU_WIDTH >> depth; - static const uint8_t priorityList0[] = { 0, 1, 0, 2, 1, 2, 0, 3, 1, 3, 2, 3 }; - static const uint8_t priorityList1[] = { 1, 0, 2, 0, 2, 1, 3, 0, 3, 1, 3, 2 }; - const unsigned num_cand_pairs = - MIN(info.num_merge_cand * (info.num_merge_cand - 1), 12); - - inter_merge_cand_t *merge_cand = info.merge_cand; - - for (int32_t idx = 0; idx < num_cand_pairs; idx++) { - uint8_t i = priorityList0[idx]; - uint8_t j = priorityList1[idx]; - if (i >= info.num_merge_cand || j >= info.num_merge_cand) break; - - // Find one L0 and L1 candidate according to the priority list - if ((merge_cand[i].dir & 0x1) && (merge_cand[j].dir & 0x2)) { - if (state->frame->ref_LX[0][merge_cand[i].ref[0]] != - state->frame->ref_LX[1][merge_cand[j].ref[1]] || - - merge_cand[i].mv[0][0] != merge_cand[j].mv[1][0] || - merge_cand[i].mv[0][1] != merge_cand[j].mv[1][1]) - { - uint32_t bitcost[2]; - uint32_t cost = 0; - int16_t mv[2][2]; - kvz_pixel tmp_block[64 * 64]; - kvz_pixel tmp_pic[64 * 64]; - - mv[0][0] = merge_cand[i].mv[0][0]; - mv[0][1] = merge_cand[i].mv[0][1]; - mv[1][0] = merge_cand[j].mv[1][0]; - mv[1][1] = merge_cand[j].mv[1][1]; - - // Don't try merge candidates that don't satisfy mv constraints. - if (!fracmv_within_tile(&info, mv[0][0], mv[0][1]) || - !fracmv_within_tile(&info, mv[1][0], mv[1][1])) - { - continue; - } - - kvz_inter_recon_bipred(state, - state->frame->ref->images[ - state->frame->ref_LX[0][merge_cand[i].ref[0]] - ], - state->frame->ref->images[ - state->frame->ref_LX[1][merge_cand[j].ref[1]] - ], - x, y, - width, - height, - mv, - templcu); - - for (int ypos = 0; ypos < height; ++ypos) { - int dst_y = ypos * width; - for (int xpos = 0; xpos < width; ++xpos) { - tmp_block[dst_y + xpos] = templcu->rec.y[ - SUB_SCU(y + ypos) * LCU_WIDTH + SUB_SCU(x + xpos)]; - tmp_pic[dst_y + xpos] = frame->source->y[x + xpos + (y + ypos)*frame->source->width]; - } - } - - cost = kvz_satd_any_size(cu_width, cu_width, tmp_pic, cu_width, tmp_block, cu_width); - - cost += info.mvd_cost_func(state, - merge_cand[i].mv[0][0], - merge_cand[i].mv[0][1], - 0, - info.mv_cand, - NULL, 0, 0, - &bitcost[0]); - cost += info.mvd_cost_func(state, - merge_cand[i].mv[1][0], - merge_cand[i].mv[1][1], - 0, - info.mv_cand, - NULL, 0, 0, - &bitcost[1]); - - const uint8_t mv_ref_coded[2] = { - merge_cand[i].ref[0], - merge_cand[j].ref[1] - }; - const int extra_bits = mv_ref_coded[0] + mv_ref_coded[1] + 2 /* mv dir cost */; - cost += state->lambda_sqrt * extra_bits + 0.5; - - - if (cost < *inter_cost) { - cur_cu->inter.mv_dir = 3; - - cur_cu->inter.mv_ref[0] = merge_cand[i].ref[0]; - cur_cu->inter.mv_ref[1] = merge_cand[j].ref[1]; - - cur_cu->inter.mv[0][0] = merge_cand[i].mv[0][0]; - cur_cu->inter.mv[0][1] = merge_cand[i].mv[0][1]; - cur_cu->inter.mv[1][0] = merge_cand[j].mv[1][0]; - cur_cu->inter.mv[1][1] = merge_cand[j].mv[1][1]; - cur_cu->merged = 0; - - // Check every candidate to find a match - for (int merge_idx = 0; merge_idx < info.num_merge_cand; merge_idx++) { - if (merge_cand[merge_idx].mv[0][0] == cur_cu->inter.mv[0][0] && - merge_cand[merge_idx].mv[0][1] == cur_cu->inter.mv[0][1] && - merge_cand[merge_idx].mv[1][0] == cur_cu->inter.mv[1][0] && - merge_cand[merge_idx].mv[1][1] == cur_cu->inter.mv[1][1] && - merge_cand[merge_idx].ref[0] == cur_cu->inter.mv_ref[0] && - merge_cand[merge_idx].ref[1] == cur_cu->inter.mv_ref[1]) - { - cur_cu->merged = 1; - cur_cu->merge_idx = merge_idx; - break; - } - } - - // Each motion vector has its own candidate - for (int reflist = 0; reflist < 2; reflist++) { - kvz_inter_get_mv_cand(state, x, y, width, height, info.mv_cand, cur_cu, lcu, reflist); - int cu_mv_cand = select_mv_cand( - state, - info.mv_cand, - cur_cu->inter.mv[reflist][0], - cur_cu->inter.mv[reflist][1], - NULL); - CU_SET_MV_CAND(cur_cu, reflist, cu_mv_cand); - } - - *inter_cost = cost; - *inter_bitcost = bitcost[0] + bitcost[1] + extra_bits; - } - } - } - } - FREE_POINTER(templcu); + search_pu_inter_bipred(&info, depth, lcu, cur_cu, inter_cost, inter_bitcost); } if (*inter_cost < INT_MAX && cur_cu->inter.mv_dir == 1) {