aboutsummaryrefslogtreecommitdiff
path: root/av1/qmode_rc/ratectrl_qmode.cc
diff options
context:
space:
mode:
Diffstat (limited to 'av1/qmode_rc/ratectrl_qmode.cc')
-rw-r--r--av1/qmode_rc/ratectrl_qmode.cc1552
1 files changed, 1552 insertions, 0 deletions
diff --git a/av1/qmode_rc/ratectrl_qmode.cc b/av1/qmode_rc/ratectrl_qmode.cc
new file mode 100644
index 000000000..0a2892d89
--- /dev/null
+++ b/av1/qmode_rc/ratectrl_qmode.cc
@@ -0,0 +1,1552 @@
+/*
+ * Copyright (c) 2022, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+#include "av1/qmode_rc/ratectrl_qmode.h"
+
+#include <algorithm>
+#include <cassert>
+#include <climits>
+#include <functional>
+#include <numeric>
+#include <sstream>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "aom/aom_codec.h"
+#include "av1/encoder/pass2_strategy.h"
+#include "av1/encoder/tpl_model.h"
+
+namespace aom {
+
+// This is used before division to ensure that the divisor isn't zero or
+// too close to zero. Clamps the magnitude away from zero while preserving
+// the sign of the input.
+static double ModifyDivisor(double divisor) {
+  const double kEpsilon = 0.0000001;
+  return (divisor < 0 ? std::min(divisor, -kEpsilon)
+                      : std::max(divisor, kEpsilon));
+}
+
+// Returns a sentinel GopFrame marked invalid (negative indices) so callers
+// can detect a failed lookup or construction.
+GopFrame GopFrameInvalid() {
+  GopFrame gop_frame = {};
+  gop_frame.is_valid = false;
+  gop_frame.coding_idx = -1;
+  gop_frame.order_idx = -1;
+  return gop_frame;
+}
+
+// Fills in the frame-role flags (key / ARF / show / golden) and the reference
+// mode of *gop_frame according to gop_frame_type. For kIntermediateArf, the
+// golden flag depends on gop_frame->layer_depth, so the caller must set
+// layer_depth before calling this.
+void SetGopFrameByType(GopFrameType gop_frame_type, GopFrame *gop_frame) {
+  gop_frame->update_type = gop_frame_type;
+  switch (gop_frame_type) {
+    case GopFrameType::kRegularKey:
+      gop_frame->is_key_frame = 1;
+      gop_frame->is_arf_frame = 0;
+      gop_frame->is_show_frame = 1;
+      gop_frame->is_golden_frame = 1;
+      gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
+      break;
+    case GopFrameType::kRegularGolden:
+      gop_frame->is_key_frame = 0;
+      gop_frame->is_arf_frame = 0;
+      gop_frame->is_show_frame = 1;
+      gop_frame->is_golden_frame = 1;
+      gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
+      break;
+    case GopFrameType::kRegularArf:
+      gop_frame->is_key_frame = 0;
+      gop_frame->is_arf_frame = 1;
+      gop_frame->is_show_frame = 0;
+      gop_frame->is_golden_frame = 1;
+      gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
+      break;
+    case GopFrameType::kIntermediateArf:
+      gop_frame->is_key_frame = 0;
+      gop_frame->is_arf_frame = 1;
+      gop_frame->is_show_frame = 0;
+      // Only shallow intermediate ARFs (layer depth <= 2) count as golden.
+      gop_frame->is_golden_frame = gop_frame->layer_depth <= 2 ? 1 : 0;
+      gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
+      break;
+    case GopFrameType::kRegularLeaf:
+      gop_frame->is_key_frame = 0;
+      gop_frame->is_arf_frame = 0;
+      gop_frame->is_show_frame = 1;
+      gop_frame->is_golden_frame = 0;
+      gop_frame->encode_ref_mode = EncodeRefMode::kRegular;
+      break;
+    case GopFrameType::kIntermediateOverlay:
+      gop_frame->is_key_frame = 0;
+      gop_frame->is_arf_frame = 0;
+      gop_frame->is_show_frame = 1;
+      gop_frame->is_golden_frame = 0;
+      gop_frame->encode_ref_mode = EncodeRefMode::kShowExisting;
+      break;
+    case GopFrameType::kOverlay:
+      gop_frame->is_key_frame = 0;
+      gop_frame->is_arf_frame = 0;
+      gop_frame->is_show_frame = 1;
+      gop_frame->is_golden_frame = 0;
+      gop_frame->encode_ref_mode = EncodeRefMode::kOverlay;
+      break;
+  }
+}
+
+// Constructs a GopFrame with both GOP-local (coding_idx, order_idx) and
+// global (offset + local) indices, sets its pyramid layer depth from the
+// given depth plus kLayerDepthOffset, and derives the role flags from
+// gop_frame_type via SetGopFrameByType().
+GopFrame GopFrameBasic(int global_coding_idx_offset,
+                       int global_order_idx_offset, int coding_idx,
+                       int order_idx, int depth, int display_idx,
+                       GopFrameType gop_frame_type) {
+  GopFrame gop_frame = {};
+  gop_frame.is_valid = true;
+  gop_frame.coding_idx = coding_idx;
+  gop_frame.order_idx = order_idx;
+  gop_frame.display_idx = display_idx;
+  gop_frame.global_coding_idx = global_coding_idx_offset + coding_idx;
+  gop_frame.global_order_idx = global_order_idx_offset + order_idx;
+  gop_frame.layer_depth = depth + kLayerDepthOffset;
+  // No colocated or update reference assigned yet; the RefFrameManager
+  // fills these in later.
+  gop_frame.colocated_ref_idx = -1;
+  gop_frame.update_ref_idx = -1;
+  SetGopFrameByType(gop_frame_type, &gop_frame);
+  return gop_frame;
+}
+
+// This function creates gop frames with display-order indices from
+// order_start to order_end - 1. The function will recursively introduce
+// intermediate ARFs until the maximum depth is met or the number of regular
+// frames in between two ARFs is less than 3. Then the regular frames will be
+// added into the gop_struct.
+void ConstructGopMultiLayer(GopStruct *gop_struct,
+                            RefFrameManager *ref_frame_manager, int max_depth,
+                            int depth, int order_start, int order_end) {
+  GopFrame gop_frame;
+  int num_frames = order_end - order_start;
+  const int global_coding_idx_offset = gop_struct->global_coding_idx_offset;
+  const int global_order_idx_offset = gop_struct->global_order_idx_offset;
+  // If there are less than kMinIntervalToAddArf frames, stop introducing ARF
+  if (depth < max_depth && num_frames >= kMinIntervalToAddArf) {
+    int order_mid = (order_start + order_end) / 2;
+    // intermediate ARF
+    gop_frame = GopFrameBasic(
+        global_coding_idx_offset, global_order_idx_offset,
+        static_cast<int>(gop_struct->gop_frame_list.size()), order_mid, depth,
+        gop_struct->display_tracker, GopFrameType::kIntermediateArf);
+    ref_frame_manager->UpdateRefFrameTable(&gop_frame);
+    gop_struct->gop_frame_list.push_back(gop_frame);
+    // Left half: frames before the intermediate ARF.
+    ConstructGopMultiLayer(gop_struct, ref_frame_manager, max_depth, depth + 1,
+                           order_start, order_mid);
+    // show existing intermediate ARF
+    gop_frame =
+        GopFrameBasic(global_coding_idx_offset, global_order_idx_offset,
+                      static_cast<int>(gop_struct->gop_frame_list.size()),
+                      order_mid, max_depth, gop_struct->display_tracker,
+                      GopFrameType::kIntermediateOverlay);
+    ref_frame_manager->UpdateRefFrameTable(&gop_frame);
+    gop_struct->gop_frame_list.push_back(gop_frame);
+    ++gop_struct->display_tracker;
+    // Right half: frames after the intermediate ARF.
+    ConstructGopMultiLayer(gop_struct, ref_frame_manager, max_depth, depth + 1,
+                           order_mid + 1, order_end);
+  } else {
+    // regular frame
+    for (int i = order_start; i < order_end; ++i) {
+      gop_frame = GopFrameBasic(
+          global_coding_idx_offset, global_order_idx_offset,
+          static_cast<int>(gop_struct->gop_frame_list.size()), i, max_depth,
+          gop_struct->display_tracker, GopFrameType::kRegularLeaf);
+      ref_frame_manager->UpdateRefFrameTable(&gop_frame);
+      gop_struct->gop_frame_list.push_back(gop_frame);
+      ++gop_struct->display_tracker;
+    }
+  }
+}
+
+// Builds a GOP covering show_frame_count displayed frames, optionally led by
+// a key frame. Uses a multi-layer ARF pyramid (ARF + recursive intermediate
+// ARFs + overlay) when the GOP is long enough; otherwise a flat IPPP
+// structure of regular leaf frames.
+GopStruct ConstructGop(RefFrameManager *ref_frame_manager, int show_frame_count,
+                       bool has_key_frame, int global_coding_idx_offset,
+                       int global_order_idx_offset) {
+  GopStruct gop_struct;
+  gop_struct.show_frame_count = show_frame_count;
+  gop_struct.global_coding_idx_offset = global_coding_idx_offset;
+  gop_struct.global_order_idx_offset = global_order_idx_offset;
+  int order_start = 0;
+  int order_end = show_frame_count - 1;
+
+  // TODO(jingning): Re-enable the use of pyramid coding structure.
+  bool has_arf_frame = show_frame_count > kMinIntervalToAddArf;
+
+  gop_struct.display_tracker = 0;
+
+  GopFrame gop_frame;
+  if (has_key_frame) {
+    const int key_frame_depth = -1;
+    ref_frame_manager->Reset();
+    gop_frame = GopFrameBasic(
+        global_coding_idx_offset, global_order_idx_offset,
+        static_cast<int>(gop_struct.gop_frame_list.size()), order_start,
+        key_frame_depth, gop_struct.display_tracker, GopFrameType::kRegularKey);
+    ref_frame_manager->UpdateRefFrameTable(&gop_frame);
+    gop_struct.gop_frame_list.push_back(gop_frame);
+    order_start++;
+    ++gop_struct.display_tracker;
+  }
+
+  const int arf_depth = 0;
+  if (has_arf_frame) {
+    // Use multi-layer pyramid coding structure.
+    // ARF for the last display frame, coded before the frames it references.
+    gop_frame = GopFrameBasic(
+        global_coding_idx_offset, global_order_idx_offset,
+        static_cast<int>(gop_struct.gop_frame_list.size()), order_end,
+        arf_depth, gop_struct.display_tracker, GopFrameType::kRegularArf);
+    ref_frame_manager->UpdateRefFrameTable(&gop_frame);
+    gop_struct.gop_frame_list.push_back(gop_frame);
+    ConstructGopMultiLayer(&gop_struct, ref_frame_manager,
+                           ref_frame_manager->MaxRefFrame() - 1, arf_depth + 1,
+                           order_start, order_end);
+    // Overlay
+    gop_frame =
+        GopFrameBasic(global_coding_idx_offset, global_order_idx_offset,
+                      static_cast<int>(gop_struct.gop_frame_list.size()),
+                      order_end, ref_frame_manager->MaxRefFrame() - 1,
+                      gop_struct.display_tracker, GopFrameType::kOverlay);
+    ref_frame_manager->UpdateRefFrameTable(&gop_frame);
+    gop_struct.gop_frame_list.push_back(gop_frame);
+    ++gop_struct.display_tracker;
+  } else {
+    // Use IPPP format.
+    for (int i = order_start; i <= order_end; ++i) {
+      gop_frame = GopFrameBasic(
+          global_coding_idx_offset, global_order_idx_offset,
+          static_cast<int>(gop_struct.gop_frame_list.size()), i, arf_depth + 1,
+          gop_struct.display_tracker, GopFrameType::kRegularLeaf);
+      ref_frame_manager->UpdateRefFrameTable(&gop_frame);
+      gop_struct.gop_frame_list.push_back(gop_frame);
+      ++gop_struct.display_tracker;
+    }
+  }
+
+  return gop_struct;
+}
+
+// Validates the ranges of all fields in rc_param; on success stores a copy in
+// rc_param_ and returns AOM_CODEC_OK, otherwise returns
+// AOM_CODEC_INVALID_PARAM with a descriptive message.
+Status AV1RateControlQMode::SetRcParam(const RateControlParam &rc_param) {
+  std::ostringstream error_message;
+  if (rc_param.max_gop_show_frame_count <
+      std::max(4, rc_param.min_gop_show_frame_count)) {
+    error_message << "max_gop_show_frame_count ("
+                  << rc_param.max_gop_show_frame_count
+                  << ") must be at least 4 and may not be less than "
+                     "min_gop_show_frame_count ("
+                  << rc_param.min_gop_show_frame_count << ")";
+    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+  }
+  if (rc_param.ref_frame_table_size < 1 || rc_param.ref_frame_table_size > 8) {
+    error_message << "ref_frame_table_size (" << rc_param.ref_frame_table_size
+                  << ") must be in the range [1, 8].";
+    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+  }
+  if (rc_param.max_ref_frames < 1 || rc_param.max_ref_frames > 7) {
+    error_message << "max_ref_frames (" << rc_param.max_ref_frames
+                  << ") must be in the range [1, 7].";
+    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+  }
+  if (rc_param.base_q_index < 0 || rc_param.base_q_index > 255) {
+    error_message << "base_q_index (" << rc_param.base_q_index
+                  << ") must be in the range [0, 255].";
+    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+  }
+  if (rc_param.frame_width < 16 || rc_param.frame_width > 16384 ||
+      rc_param.frame_height < 16 || rc_param.frame_height > 16384) {
+    error_message << "frame_width (" << rc_param.frame_width
+                  << ") and frame_height (" << rc_param.frame_height
+                  << ") must be in the range [16, 16384].";
+    return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+  }
+  rc_param_ = rc_param;
+  return { AOM_CODEC_OK, "" };
+}
+
+// Threshold for use of the lagging second reference frame. High second ref
+// usage may point to a transient event like a flash or occlusion rather than
+// a real scene cut.
+// We adapt the threshold based on number of frames in this key-frame group so
+// far: it ramps linearly from the minimum up to min + max_delta over the
+// first adapt_upto frames, then stays at the maximum.
+static double GetSecondRefUsageThreshold(int frame_count_so_far) {
+  const int adapt_upto = 32;
+  const double min_second_ref_usage_thresh = 0.085;
+  const double second_ref_usage_thresh_max_delta = 0.035;
+  if (frame_count_so_far >= adapt_upto) {
+    return min_second_ref_usage_thresh + second_ref_usage_thresh_max_delta;
+  }
+  return min_second_ref_usage_thresh +
+         ((double)frame_count_so_far / (adapt_upto - 1)) *
+             second_ref_usage_thresh_max_delta;
+}
+
+// Slide show transition detection.
+// Tests for case where there is very low error either side of the current
+// frame but much higher just for this frame. This can help detect key frames
+// in slide shows even where the slides are pictures of different sizes.
+// Also requires that intra and inter errors are very similar to help eliminate
+// harmful false positives.
+// It will not help if the transition is a fade or other multi-frame effect.
+static bool DetectSlideTransition(const FIRSTPASS_STATS &this_frame,
+                                  const FIRSTPASS_STATS &last_frame,
+                                  const FIRSTPASS_STATS &next_frame) {
+  // Intra / Inter threshold very low
+  constexpr double kVeryLowII = 1.5;
+  // Clean slide transitions we expect a sharp single frame spike in error.
+  constexpr double kErrorSpike = 5.0;
+
+  // TODO(angiebird): Understand the meaning of these conditions.
+  return (this_frame.intra_error < (this_frame.coded_error * kVeryLowII)) &&
+         (this_frame.coded_error > (last_frame.coded_error * kErrorSpike)) &&
+         (this_frame.coded_error > (next_frame.coded_error * kErrorSpike));
+}
+
+// Check if there is a significant intra/inter error change between the current
+// frame and its neighbor. If so, we should further test whether the current
+// frame should be a key frame.
+static bool DetectIntraInterErrorChange(const FIRSTPASS_STATS &this_stats,
+                                        const FIRSTPASS_STATS &last_stats,
+                                        const FIRSTPASS_STATS &next_stats) {
+  // Minimum % intra coding observed in first pass (1.0 = 100%)
+  constexpr double kMinIntraLevel = 0.25;
+  // Minimum ratio between the % of intra coding and inter coding in the first
+  // pass after discounting neutral blocks (discounting neutral blocks in this
+  // way helps catch scene cuts in clips with very flat areas or letter box
+  // format clips with image padding.
+  constexpr double kIntraVsInterRatio = 2.0;
+
+  const double modified_pcnt_inter =
+      this_stats.pcnt_inter - this_stats.pcnt_neutral;
+  const double pcnt_intra_min =
+      std::max(kMinIntraLevel, kIntraVsInterRatio * modified_pcnt_inter);
+
+  // In real scene cuts there is almost always a sharp change in the intra
+  // or inter error score.
+  constexpr double kErrorChangeThreshold = 0.4;
+  const double last_this_error_ratio =
+      fabs(last_stats.coded_error - this_stats.coded_error) /
+      ModifyDivisor(this_stats.coded_error);
+
+  // NOTE(review): despite the name, this compares last vs. this intra_error
+  // (not this vs. next) — verify against the upstream test_candidate_kf()
+  // logic this was rewritten from.
+  const double this_next_error_ratio =
+      fabs(last_stats.intra_error - this_stats.intra_error) /
+      ModifyDivisor(this_stats.intra_error);
+
+  // Maximum threshold for the relative ratio of intra error score vs best
+  // inter error score.
+  constexpr double kThisIntraCodedErrorRatioMax = 1.9;
+  const double this_intra_coded_error_ratio =
+      this_stats.intra_error / ModifyDivisor(this_stats.coded_error);
+
+  // For real scene cuts we expect an improvement in the intra inter error
+  // ratio in the next frame.
+  constexpr double kNextIntraCodedErrorRatioMin = 3.5;
+  const double next_intra_coded_error_ratio =
+      next_stats.intra_error / ModifyDivisor(next_stats.coded_error);
+
+  double pcnt_intra = 1.0 - this_stats.pcnt_inter;
+  return pcnt_intra > pcnt_intra_min &&
+         this_intra_coded_error_ratio < kThisIntraCodedErrorRatioMax &&
+         (last_this_error_ratio > kErrorChangeThreshold ||
+          this_next_error_ratio > kErrorChangeThreshold ||
+          next_intra_coded_error_ratio > kNextIntraCodedErrorRatioMin);
+}
+
+// Check whether the candidate can be a key frame.
+// This is a rewrite of test_candidate_kf().
+// Requires valid stats for the frames immediately before and after the
+// candidate; returns false if they are out of range.
+static bool TestCandidateKey(const FirstpassInfo &first_pass_info,
+                             int candidate_key_idx, int frames_since_prev_key) {
+  const auto &stats_list = first_pass_info.stats_list;
+  const int stats_count = static_cast<int>(stats_list.size());
+  if (candidate_key_idx + 1 >= stats_count || candidate_key_idx - 1 < 0) {
+    return false;
+  }
+  const auto &last_stats = stats_list[candidate_key_idx - 1];
+  const auto &this_stats = stats_list[candidate_key_idx];
+  const auto &next_stats = stats_list[candidate_key_idx + 1];
+
+  // Don't allow key frames too close together.
+  if (frames_since_prev_key < 3) return false;
+  // High second-ref usage suggests a flash/occlusion, not a scene cut.
+  const double second_ref_usage_threshold =
+      GetSecondRefUsageThreshold(frames_since_prev_key);
+  if (this_stats.pcnt_second_ref >= second_ref_usage_threshold) return false;
+  if (next_stats.pcnt_second_ref >= second_ref_usage_threshold) return false;
+
+  // Hard threshold where the first pass chooses intra for almost all blocks.
+  // In such a case even if the frame is not a scene cut coding a key frame
+  // may be a good option.
+  constexpr double kVeryLowInterThreshold = 0.05;
+  if (this_stats.pcnt_inter < kVeryLowInterThreshold ||
+      DetectSlideTransition(this_stats, last_stats, next_stats) ||
+      DetectIntraInterErrorChange(this_stats, last_stats, next_stats)) {
+    double boost_score = 0.0;
+    double decay_accumulator = 1.0;
+
+    // We do "-1" because the candidate key is not counted.
+    int stats_after_this_stats = stats_count - candidate_key_idx - 1;
+
+    // Number of frames required to test for scene cut detection
+    constexpr int kSceneCutKeyTestIntervalMax = 16;
+
+    // Make sure we have enough stats after the candidate key.
+    const int frames_to_test_after_candidate_key =
+        std::min(kSceneCutKeyTestIntervalMax, stats_after_this_stats);
+
+    // Examine how well the key frame predicts subsequent frames.
+    int i;
+    for (i = 1; i <= frames_to_test_after_candidate_key; ++i) {
+      // Get the next frame details
+      const auto &stats = stats_list[candidate_key_idx + i];
+
+      // Cumulative effect of decay in prediction quality.
+      if (stats.pcnt_inter > 0.85) {
+        decay_accumulator *= stats.pcnt_inter;
+      } else {
+        decay_accumulator *= (0.85 + stats.pcnt_inter) / 2.0;
+      }
+
+      constexpr double kBoostFactor = 12.5;
+      double next_iiratio =
+          (kBoostFactor * stats.intra_error / ModifyDivisor(stats.coded_error));
+      next_iiratio = std::min(next_iiratio, 128.0);
+      double boost_score_increment = decay_accumulator * next_iiratio;
+
+      // Keep a running total.
+      boost_score += boost_score_increment;
+
+      // Test various breakout clauses.
+      // TODO(any): Test of intra error should be normalized to an MB.
+      // TODO(angiebird): Investigate the following questions.
+      // Question 1: next_iiratio (intra_error / coded_error) * kBoostFactor
+      //   We know intra_error / coded_error >= 1 and kBoostFactor = 12.5,
+      //   therefore, (intra_error / coded_error) * kBoostFactor will always
+      //   greater than 1.5. Is "next_iiratio < 1.5" always false?
+      // Question 2: Similar to question 1, is "next_iiratio < 3.0" always true?
+      // Question 3: Why do we need to divide 200 with num_mbs_16x16?
+      if ((stats.pcnt_inter < 0.05) || (next_iiratio < 1.5) ||
+          (((stats.pcnt_inter - stats.pcnt_neutral) < 0.20) &&
+           (next_iiratio < 3.0)) ||
+          (boost_score_increment < 3.0) ||
+          (stats.intra_error <
+           (200.0 / static_cast<double>(first_pass_info.num_mbs_16x16)))) {
+        break;
+      }
+    }
+
+    // If there is tolerable prediction for at least the next 3 frames then
+    // break out else discard this potential key frame and move on
+    const int count_for_tolerable_prediction = 3;
+    if (boost_score > 30.0 && (i > count_for_tolerable_prediction)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Compute key frame locations from first_pass_info. Returns a list of frame
+// indices (in display order) that should be coded as key frames; index 0 is
+// always included.
+std::vector<int> GetKeyFrameList(const FirstpassInfo &first_pass_info) {
+  std::vector<int> key_frame_list;
+  key_frame_list.push_back(0);  // The first frame is always a key frame
+  int candidate_key_idx = 1;
+  while (candidate_key_idx <
+         static_cast<int>(first_pass_info.stats_list.size())) {
+    const int frames_since_prev_key = candidate_key_idx - key_frame_list.back();
+    // Check for a scene cut.
+    const bool scenecut_detected = TestCandidateKey(
+        first_pass_info, candidate_key_idx, frames_since_prev_key);
+    if (scenecut_detected) {
+      key_frame_list.push_back(candidate_key_idx);
+    }
+    ++candidate_key_idx;
+  }
+  return key_frame_list;
+}
+
+// Initialize GF_GROUP_STATS: additive accumulators start at 0.0,
+// multiplicative decay/zero-motion rates start at 1.0.
+static void InitGFStats(GF_GROUP_STATS *gf_stats) {
+  gf_stats->gf_group_err = 0.0;
+  gf_stats->gf_group_raw_error = 0.0;
+  gf_stats->gf_group_skip_pct = 0.0;
+  gf_stats->gf_group_inactive_zone_rows = 0.0;
+
+  gf_stats->mv_ratio_accumulator = 0.0;
+  gf_stats->decay_accumulator = 1.0;
+  gf_stats->zero_motion_accumulator = 1.0;
+  gf_stats->loop_decay_rate = 1.0;
+  gf_stats->last_loop_decay_rate = 1.0;
+  gf_stats->this_frame_mv_in_out = 0.0;
+  gf_stats->mv_in_out_accumulator = 0.0;
+  gf_stats->abs_mv_in_out_accumulator = 0.0;
+
+  gf_stats->avg_sr_coded_error = 0.0;
+  gf_stats->avg_pcnt_second_ref = 0.0;
+  gf_stats->avg_new_mv_count = 0.0;
+  gf_stats->avg_wavelet_energy = 0.0;
+  gf_stats->avg_raw_err_stdev = 0.0;
+  gf_stats->non_zero_stdev_count = 0;
+}
+
+// Returns the index of the region whose [start, last] interval contains
+// frame_idx, or -1 if no region contains it.
+static int FindRegionIndex(const std::vector<REGIONS> &regions, int frame_idx) {
+  for (int k = 0; k < static_cast<int>(regions.size()); k++) {
+    if (regions[k].start <= frame_idx && regions[k].last >= frame_idx) {
+      return k;
+    }
+  }
+  return -1;
+}
+
+// This function detects a flash through the high relative pcnt_second_ref
+// score in the frame following a flash frame. The offset passed in should
+// reflect this. Returns false when there is no frame after `index`.
+static bool DetectFlash(const std::vector<FIRSTPASS_STATS> &stats_list,
+                        int index) {
+  int next_index = index + 1;
+  if (next_index >= static_cast<int>(stats_list.size())) return false;
+  const FIRSTPASS_STATS &next_frame = stats_list[next_index];
+
+  // What we are looking for here is a situation where there is a
+  // brief break in prediction (such as a flash) but subsequent frames
+  // are reasonably well predicted by an earlier (pre flash) frame.
+  // The recovery after a flash is indicated by a high pcnt_second_ref
+  // compared to pcnt_inter.
+  return next_frame.pcnt_second_ref > next_frame.pcnt_inter &&
+         next_frame.pcnt_second_ref >= 0.5;
+}
+
+#define MIN_SHRINK_LEN 6
+
+// This function takes in a suggested gop interval from cur_start to cur_last,
+// analyzes firstpass stats and region stats and then returns a better gop cut
+// location.
+// TODO(b/231517281): Simplify the indices once we have an unit test.
+// We are using four indices here, order_index, cur_start, cur_last, and
+// frames_since_key. Ideally, only three indices are needed.
+// 1) start_index = order_index + cur_start
+// 2) end_index = order_index + cur_end
+// 3) key_index
+int FindBetterGopCut(const std::vector<FIRSTPASS_STATS> &stats_list,
+                     const std::vector<REGIONS> &regions_list,
+                     int min_gop_show_frame_count, int max_gop_show_frame_count,
+                     int order_index, int cur_start, int cur_last,
+                     int frames_since_key) {
+  // only try shrinking if interval smaller than active_max_gf_interval
+  if (cur_last - cur_start > max_gop_show_frame_count ||
+      cur_start >= cur_last) {
+    return cur_last;
+  }
+  int num_regions = static_cast<int>(regions_list.size());
+  int num_stats = static_cast<int>(stats_list.size());
+  const int min_shrink_int = std::max(MIN_SHRINK_LEN, min_gop_show_frame_count);
+
+  // find the region indices of where the first and last frame belong.
+  int k_start = FindRegionIndex(regions_list, cur_start + frames_since_key);
+  int k_last = FindRegionIndex(regions_list, cur_last + frames_since_key);
+  if (cur_start + frames_since_key == 0) k_start = 0;
+
+  int scenecut_idx = -1;
+  // See if we have a scenecut in between
+  for (int r = k_start + 1; r <= k_last; r++) {
+    if (regions_list[r].type == SCENECUT_REGION &&
+        regions_list[r].last - frames_since_key - cur_start >
+            min_gop_show_frame_count) {
+      scenecut_idx = r;
+      break;
+    }
+  }
+
+  // if the found scenecut is very close to the end, ignore it.
+  if (scenecut_idx >= 0 &&
+      regions_list[num_regions - 1].last - regions_list[scenecut_idx].last <
+          4) {
+    scenecut_idx = -1;
+  }
+
+  if (scenecut_idx != -1) {
+    // If we have a scenecut, then stop at it.
+    // TODO(bohanli): add logic here to stop before the scenecut and for
+    // the next gop start from the scenecut with GF
+    // A "minor" scenecut is one where correlation across the cut is still
+    // fairly high after discounting noise; in that case stop one frame
+    // earlier.
+    int is_minor_sc =
+        (regions_list[scenecut_idx].avg_cor_coeff *
+             (1 - stats_list[order_index + regions_list[scenecut_idx].start -
+                             frames_since_key]
+                          .noise_var /
+                      regions_list[scenecut_idx].avg_intra_err) >
+         0.6);
+    cur_last =
+        regions_list[scenecut_idx].last - frames_since_key - !is_minor_sc;
+  } else {
+    int is_last_analysed =
+        (k_last == num_regions - 1) &&
+        (cur_last + frames_since_key == regions_list[k_last].last);
+    int not_enough_regions =
+        k_last - k_start <= 1 + (regions_list[k_start].type == SCENECUT_REGION);
+    // if we are very close to the end, then do not shrink since it may
+    // introduce intervals that are too short
+    if (!(is_last_analysed && not_enough_regions)) {
+      const double arf_length_factor = 0.1;
+      double best_score = 0;
+      int best_j = -1;
+      const int first_frame = regions_list[0].start - frames_since_key;
+      const int last_frame =
+          regions_list[num_regions - 1].last - frames_since_key;
+      // score of how much the arf helps the whole GOP
+      double base_score = 0.0;
+      // Accumulate base_score over the non-shrinkable prefix of the interval
+      // (the first min_shrink_int frames, which can never become the cut).
+      for (int j = cur_start + 1; j < cur_start + min_shrink_int; j++) {
+        if (order_index + j >= num_stats) break;
+        base_score = (base_score + 1.0) * stats_list[order_index + j].cor_coeff;
+      }
+      int met_blending = 0;   // Whether we have met blending areas before
+      int last_blending = 0;  // Whether the previous frame is blending
+      for (int j = cur_start + min_shrink_int; j <= cur_last; j++) {
+        if (order_index + j >= num_stats) break;
+        base_score = (base_score + 1.0) * stats_list[order_index + j].cor_coeff;
+        int this_reg = FindRegionIndex(regions_list, j + frames_since_key);
+        if (this_reg < 0) continue;
+        // A GOP should include at most 1 blending region.
+        if (regions_list[this_reg].type == BLENDING_REGION) {
+          last_blending = 1;
+          if (met_blending) {
+            break;
+          } else {
+            base_score = 0;
+            continue;
+          }
+        } else {
+          if (last_blending) met_blending = 1;
+          last_blending = 0;
+        }
+
+        // Add the factor of how good the neighborhood is for this
+        // candidate arf.
+        double this_score = arf_length_factor * base_score;
+        double temp_accu_coeff = 1.0;
+        // following frames
+        int count_f = 0;
+        for (int n = j + 1; n <= j + 3 && n <= last_frame; n++) {
+          if (order_index + n >= num_stats) break;
+          temp_accu_coeff *= stats_list[order_index + n].cor_coeff;
+          this_score +=
+              temp_accu_coeff *
+              (1 - stats_list[order_index + n].noise_var /
+                       AOMMAX(regions_list[this_reg].avg_intra_err, 0.001));
+          count_f++;
+        }
+        // preceding frames
+        temp_accu_coeff = 1.0;
+        for (int n = j; n > j - 3 * 2 + count_f && n > first_frame; n--) {
+          if (order_index + n < 0) break;
+          temp_accu_coeff *= stats_list[order_index + n].cor_coeff;
+          this_score +=
+              temp_accu_coeff *
+              (1 - stats_list[order_index + n].noise_var /
+                       AOMMAX(regions_list[this_reg].avg_intra_err, 0.001));
+        }
+
+        if (this_score > best_score) {
+          best_score = this_score;
+          best_j = j;
+        }
+      }
+
+      // For blending areas, move one more frame in case we missed the
+      // first blending frame.
+      // NOTE(review): if the candidate loop found nothing, best_j is still -1
+      // here; FindRegionIndex then returns -1 and the guards below leave
+      // cur_last unchanged — confirm this is the intended fallthrough.
+      int best_reg = FindRegionIndex(regions_list, best_j + frames_since_key);
+      if (best_reg < num_regions - 1 && best_reg > 0) {
+        if (regions_list[best_reg - 1].type == BLENDING_REGION &&
+            regions_list[best_reg + 1].type == BLENDING_REGION) {
+          if (best_j + frames_since_key == regions_list[best_reg].start &&
+              best_j + frames_since_key < regions_list[best_reg].last) {
+            best_j += 1;
+          } else if (best_j + frames_since_key == regions_list[best_reg].last &&
+                     best_j + frames_since_key > regions_list[best_reg].start) {
+            best_j -= 1;
+          }
+        }
+      }
+
+      if (cur_last - best_j < 2) best_j = cur_last;
+      if (best_j > 0 && best_score > 0.1) cur_last = best_j;
+      // if cannot find anything, just cut at the original place.
+    }
+  }
+
+  return cur_last;
+}
+
+// Function to test for a condition where a complex transition is followed
+// by a static section. For example in slide shows where there is a fade
+// between slides. This is to help with more optimal kf and gf positioning.
+static bool DetectTransitionToStill(
+    const std::vector<FIRSTPASS_STATS> &stats_list, int next_stats_index,
+    int min_gop_show_frame_count, int frame_interval, int still_interval,
+    double loop_decay_rate, double last_decay_rate) {
+  // Break clause to detect very still sections after motion
+  // For example a static image after a fade or other transition
+  // instead of a clean scene cut.
+  if (frame_interval > min_gop_show_frame_count && loop_decay_rate >= 0.999 &&
+      last_decay_rate < 0.9) {
+    int stats_count = static_cast<int>(stats_list.size());
+    int stats_left = stats_count - next_stats_index;
+    if (stats_left >= still_interval) {
+      // Look ahead a few frames to see if static condition persists...
+      int j;
+      for (j = 0; j < still_interval; ++j) {
+        const FIRSTPASS_STATS &stats = stats_list[next_stats_index + j];
+        if (stats.pcnt_inter - stats.pcnt_motion < 0.999) break;
+      }
+      // Only if it does do we signal a transition to still.
+      return j == still_interval;
+    }
+  }
+  return false;
+}
+
+// Decides whether the GOP should be cut at candidate_cut_idx. Returns 1 to
+// cut (transition to still, or enough motion accumulated past the minimum
+// interval), 0 to keep scanning.
+static int DetectGopCut(const std::vector<FIRSTPASS_STATS> &stats_list,
+                        int start_idx, int candidate_cut_idx, int next_key_idx,
+                        int flash_detected, int min_gop_show_frame_count,
+                        int max_gop_show_frame_count, int frame_width,
+                        int frame_height, const GF_GROUP_STATS &gf_stats) {
+  (void)max_gop_show_frame_count;
+  const int candidate_gop_size = candidate_cut_idx - start_idx;
+
+  if (!flash_detected) {
+    // Break clause to detect very still sections after motion. For example,
+    // a static image after a fade or other transition.
+    if (DetectTransitionToStill(stats_list, start_idx, min_gop_show_frame_count,
+                                candidate_gop_size, 5, gf_stats.loop_decay_rate,
+                                gf_stats.last_loop_decay_rate)) {
+      return 1;
+    }
+    const double arf_abs_zoom_thresh = 4.4;
+    // Motion breakout threshold for loop below depends on image size.
+    const double mv_ratio_accumulator_thresh =
+        (frame_height + frame_width) / 4.0;
+    // Some conditions to breakout after min interval.
+    if (candidate_gop_size >= min_gop_show_frame_count &&
+        // If possible don't break very close to a kf
+        (next_key_idx - candidate_cut_idx >= min_gop_show_frame_count) &&
+        // NOTE(review): only breaks out on odd candidate sizes — presumably
+        // to favor even GOP lengths for the ARF pyramid; confirm intent.
+        (candidate_gop_size & 0x01) &&
+        (gf_stats.mv_ratio_accumulator > mv_ratio_accumulator_thresh ||
+         gf_stats.abs_mv_in_out_accumulator > arf_abs_zoom_thresh)) {
+      return 1;
+    }
+  }
+
+  // TODO(b/231489624): Check if we need this part.
+  // If almost totally static, we will not use the max GF length later,
+  // so we can continue for more frames.
+  // if ((candidate_gop_size >= active_max_gf_interval + 1) &&
+  //     !is_almost_static(gf_stats->zero_motion_accumulator,
+  //                       twopass->kf_zeromotion_pct, cpi->ppi->lap_enabled)) {
+  //   return 0;
+  // }
+  return 0;
+}
+
+/*!\brief Determine the length of future GF groups.
+ *
+ * \ingroup gf_group_algo
+ * This function decides the gf group length of future frames in batch
+ *
+ * \param[in]    rc_param         Rate control parameters
+ * \param[in]    stats_list       List of first pass stats
+ * \param[in]    regions_list     List of regions from av1_identify_regions
+ * \param[in]    order_index      Index of current frame in stats_list
+ * \param[in]    frames_since_key Number of frames since the last key frame
+ * \param[in]    frames_to_key    Number of frames to the next key frame
+ *
+ * \return Returns a vector of decided GF group lengths.
+ */
+static std::vector<int> PartitionGopIntervals(
+    const RateControlParam &rc_param,
+    const std::vector<FIRSTPASS_STATS> &stats_list,
+    const std::vector<REGIONS> &regions_list, int order_index,
+    int frames_since_key, int frames_to_key) {
+  int i = 0;
+  // If cpi->gf_state.arf_gf_boost_lst is 0, we are starting with a KF or GF.
+  int cur_start = 0;
+  // Each element is the last frame of the previous GOP. If there are n GOPs,
+  // you need n + 1 cuts to find the durations. So cut_pos starts out with -1,
+  // which is the last frame of the previous GOP.
+  std::vector<int> cut_pos(1, -1);
+  // 0 = keep scanning, 1 = motion/still breakout, 2 = forced cut (max length
+  // reached or next key frame reached).
+  int cut_here = 0;
+  GF_GROUP_STATS gf_stats;
+  InitGFStats(&gf_stats);
+  int num_stats = static_cast<int>(stats_list.size());
+
+  while (i + order_index < num_stats) {
+    // reaches next key frame, break here
+    if (i >= frames_to_key - 1) {
+      cut_here = 2;
+    } else if (i - cur_start >= rc_param.max_gop_show_frame_count) {
+      // reached maximum len, but nothing special yet (almost static)
+      // let's look at the next interval
+      cut_here = 2;
+    } else {
+      // Test for the case where there is a brief flash but the prediction
+      // quality back to an earlier frame is then restored.
+      const int gop_start_idx = cur_start + order_index;
+      const int candidate_gop_cut_idx = i + order_index;
+      const int next_key_idx = frames_to_key + order_index;
+      const bool flash_detected =
+          DetectFlash(stats_list, candidate_gop_cut_idx);
+
+      // TODO(bohanli): remove redundant accumulations here, or unify
+      // this and the ones in define_gf_group
+      const FIRSTPASS_STATS *stats = &stats_list[candidate_gop_cut_idx];
+      av1_accumulate_next_frame_stats(stats, flash_detected, frames_since_key,
+                                      i, &gf_stats, rc_param.frame_width,
+                                      rc_param.frame_height);
+
+      // TODO(angiebird): Can we simplify this part? Looks like we are going to
+      // change the gop cut index with FindBetterGopCut() anyway.
+      cut_here = DetectGopCut(
+          stats_list, gop_start_idx, candidate_gop_cut_idx, next_key_idx,
+          flash_detected, rc_param.min_gop_show_frame_count,
+          rc_param.max_gop_show_frame_count, rc_param.frame_width,
+          rc_param.frame_height, gf_stats);
+    }
+
+    if (!cut_here) {
+      ++i;
+      continue;
+    }
+
+    // the current last frame in the gf group
+    int original_last = cut_here > 1 ? i : i - 1;
+    int cur_last = FindBetterGopCut(
+        stats_list, regions_list, rc_param.min_gop_show_frame_count,
+        rc_param.max_gop_show_frame_count, order_index, cur_start,
+        original_last, frames_since_key);
+    // only try shrinking if interval smaller than active_max_gf_interval
+    cut_pos.push_back(cur_last);
+
+    // reset pointers to the shrunken location
+    cur_start = cur_last;
+    // If the next GOP would start inside a scenecut region, start it one
+    // frame later.
+    int cur_region_idx =
+        FindRegionIndex(regions_list, cur_start + 1 + frames_since_key);
+    if (cur_region_idx >= 0)
+      if (regions_list[cur_region_idx].type == SCENECUT_REGION) cur_start++;
+
+    // reset accumulators
+    InitGFStats(&gf_stats);
+    i = cur_last + 1;
+
+    if (cut_here == 2 && i >= frames_to_key) break;
+  }
+
+  std::vector<int> gf_intervals;
+  // save intervals
+  for (size_t n = 1; n < cut_pos.size(); n++) {
+    gf_intervals.push_back(cut_pos[n] - cut_pos[n - 1]);
+  }
+
+  return gf_intervals;
+}
+
// Splits the whole sequence into GOPs: first into key-frame intervals based on
// the first pass stats, then each interval into GF groups via
// PartitionGopIntervals().
StatusOr<GopStructList> AV1RateControlQMode::DetermineGopInfo(
    const FirstpassInfo &firstpass_info) {
  const int stats_size = static_cast<int>(firstpass_info.stats_list.size());
  GopStructList gop_list;
  RefFrameManager ref_frame_manager(rc_param_.ref_frame_table_size,
                                    rc_param_.max_ref_frames);

  // Make a copy of the first pass stats, and analyze them
  FirstpassInfo fp_info_copy = firstpass_info;
  av1_mark_flashes(fp_info_copy.stats_list.data(),
                   fp_info_copy.stats_list.data() + stats_size);
  av1_estimate_noise(fp_info_copy.stats_list.data(),
                     fp_info_copy.stats_list.data() + stats_size);
  av1_estimate_coeff(fp_info_copy.stats_list.data(),
                     fp_info_copy.stats_list.data() + stats_size);

  // Offsets accumulated across GOPs so each GopStruct knows its absolute
  // coding-order and display-order position in the whole sequence.
  int global_coding_idx_offset = 0;
  int global_order_idx_offset = 0;
  std::vector<int> key_frame_list = GetKeyFrameList(fp_info_copy);
  key_frame_list.push_back(stats_size);  // a sentinel value
  // Process one key-frame interval at a time; the sentinel makes the last
  // interval end at the end of the stats list.
  for (size_t ki = 0; ki + 1 < key_frame_list.size(); ++ki) {
    int frames_to_key = key_frame_list[ki + 1] - key_frame_list[ki];
    int key_order_index = key_frame_list[ki];  // The key frame's display order

    // Identify regions (e.g. scene-cut regions) inside this key-frame
    // interval to guide the GOP partitioning.
    std::vector<REGIONS> regions_list(MAX_FIRSTPASS_ANALYSIS_FRAMES);
    int total_regions = 0;
    av1_identify_regions(fp_info_copy.stats_list.data() + key_order_index,
                         frames_to_key, 0, regions_list.data(), &total_regions);
    regions_list.resize(total_regions);
    std::vector<int> gf_intervals = PartitionGopIntervals(
        rc_param_, fp_info_copy.stats_list, regions_list, key_order_index,
        /*frames_since_key=*/0, frames_to_key);
    for (size_t gi = 0; gi < gf_intervals.size(); ++gi) {
      // Only the first GOP of a key-frame interval contains the key frame.
      const bool has_key_frame = gi == 0;
      const int show_frame_count = gf_intervals[gi];
      GopStruct gop =
          ConstructGop(&ref_frame_manager, show_frame_count, has_key_frame,
                       global_coding_idx_offset, global_order_idx_offset);
      assert(gop.show_frame_count == show_frame_count);
      global_coding_idx_offset += static_cast<int>(gop.gop_frame_list.size());
      global_order_idx_offset += gop.show_frame_count;
      gop_list.push_back(gop);
    }
  }
  return gop_list;
}
+
+TplFrameDepStats CreateTplFrameDepStats(int frame_height, int frame_width,
+ int min_block_size) {
+ const int unit_rows = (frame_height + min_block_size - 1) / min_block_size;
+ const int unit_cols = (frame_width + min_block_size - 1) / min_block_size;
+ TplFrameDepStats frame_dep_stats;
+ frame_dep_stats.unit_size = min_block_size;
+ frame_dep_stats.unit_stats.resize(unit_rows);
+ for (auto &row : frame_dep_stats.unit_stats) {
+ row.resize(unit_cols);
+ }
+ return frame_dep_stats;
+}
+
+TplUnitDepStats TplBlockStatsToDepStats(const TplBlockStats &block_stats,
+ int unit_count) {
+ TplUnitDepStats dep_stats = {};
+ dep_stats.intra_cost = block_stats.intra_cost * 1.0 / unit_count;
+ dep_stats.inter_cost = block_stats.inter_cost * 1.0 / unit_count;
+ // In rare case, inter_cost may be greater than intra_cost.
+ // If so, we need to modify inter_cost such that inter_cost <= intra_cost
+ // because it is required by GetPropagationFraction()
+ dep_stats.inter_cost = std::min(dep_stats.intra_cost, dep_stats.inter_cost);
+ dep_stats.mv = block_stats.mv;
+ dep_stats.ref_frame_index = block_stats.ref_frame_index;
+ return dep_stats;
+}
+
+namespace {
+Status ValidateBlockStats(const TplFrameStats &frame_stats,
+ const TplBlockStats &block_stats,
+ int min_block_size) {
+ if (block_stats.col >= frame_stats.frame_width ||
+ block_stats.row >= frame_stats.frame_height) {
+ std::ostringstream error_message;
+ error_message << "Block position (" << block_stats.col << ", "
+ << block_stats.row
+ << ") is out of range; frame dimensions are "
+ << frame_stats.frame_width << " x "
+ << frame_stats.frame_height;
+ return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+ }
+ if (block_stats.col % min_block_size != 0 ||
+ block_stats.row % min_block_size != 0 ||
+ block_stats.width % min_block_size != 0 ||
+ block_stats.height % min_block_size != 0) {
+ std::ostringstream error_message;
+ error_message
+ << "Invalid block position or dimension, must be a multiple of "
+ << min_block_size << "; col = " << block_stats.col
+ << ", row = " << block_stats.row << ", width = " << block_stats.width
+ << ", height = " << block_stats.height;
+ return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+ }
+ return { AOM_CODEC_OK, "" };
+}
+
+Status ValidateTplStats(const GopStruct &gop_struct,
+ const TplGopStats &tpl_gop_stats) {
+ constexpr char kAdvice[] =
+ "Do the current RateControlParam settings match those used to generate "
+ "the TPL stats?";
+ if (gop_struct.gop_frame_list.size() !=
+ tpl_gop_stats.frame_stats_list.size()) {
+ std::ostringstream error_message;
+ error_message << "Frame count of GopStruct ("
+ << gop_struct.gop_frame_list.size()
+ << ") doesn't match frame count of TPL stats ("
+ << tpl_gop_stats.frame_stats_list.size() << "). " << kAdvice;
+ return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+ }
+ for (int i = 0; i < static_cast<int>(gop_struct.gop_frame_list.size()); ++i) {
+ const bool is_ref_frame = gop_struct.gop_frame_list[i].update_ref_idx >= 0;
+ const bool has_tpl_stats =
+ !tpl_gop_stats.frame_stats_list[i].block_stats_list.empty();
+ if (is_ref_frame && !has_tpl_stats) {
+ std::ostringstream error_message;
+ error_message << "The frame with global_coding_idx "
+ << gop_struct.gop_frame_list[i].global_coding_idx
+ << " is a reference frame, but has no TPL stats. "
+ << kAdvice;
+ return { AOM_CODEC_INVALID_PARAM, error_message.str() };
+ }
+ }
+ return { AOM_CODEC_OK, "" };
+}
+} // namespace
+
// Builds one frame's per-unit dependency stats from its raw TPL block stats,
// without any cross-frame propagation. Returns an empty TplFrameDepStats if
// the frame carries no block stats, or an error status if any block stats are
// malformed.
StatusOr<TplFrameDepStats> CreateTplFrameDepStatsWithoutPropagation(
    const TplFrameStats &frame_stats) {
  if (frame_stats.block_stats_list.empty()) {
    return TplFrameDepStats();
  }
  const int min_block_size = frame_stats.min_block_size;
  const int unit_rows =
      (frame_stats.frame_height + min_block_size - 1) / min_block_size;
  const int unit_cols =
      (frame_stats.frame_width + min_block_size - 1) / min_block_size;
  TplFrameDepStats frame_dep_stats = CreateTplFrameDepStats(
      frame_stats.frame_height, frame_stats.frame_width, min_block_size);
  for (const TplBlockStats &block_stats : frame_stats.block_stats_list) {
    Status status =
        ValidateBlockStats(frame_stats, block_stats, min_block_size);
    if (!status.ok()) {
      return status;
    }
    const int block_unit_row = block_stats.row / min_block_size;
    const int block_unit_col = block_stats.col / min_block_size;
    // The block must start within the frame boundaries, but it may extend past
    // the right edge or bottom of the frame. Find the number of unit rows and
    // columns in the block which are fully within the frame.
    const int block_unit_rows = std::min(block_stats.height / min_block_size,
                                         unit_rows - block_unit_row);
    const int block_unit_cols = std::min(block_stats.width / min_block_size,
                                         unit_cols - block_unit_col);
    const int unit_count = block_unit_rows * block_unit_cols;
    // Spread the block's costs evenly over every unit it covers.
    TplUnitDepStats unit_stats =
        TplBlockStatsToDepStats(block_stats, unit_count);
    for (int r = 0; r < block_unit_rows; r++) {
      for (int c = 0; c < block_unit_cols; c++) {
        frame_dep_stats.unit_stats[block_unit_row + r][block_unit_col + c] =
            unit_stats;
      }
    }
  }

  // Cache the frame's accumulated inter cost as its rdcost.
  frame_dep_stats.rdcost = TplFrameDepStatsAccumulateInterCost(frame_dep_stats);

  return frame_dep_stats;
}
+
+int GetRefCodingIdxList(const TplUnitDepStats &unit_dep_stats,
+ const RefFrameTable &ref_frame_table,
+ int *ref_coding_idx_list) {
+ int ref_frame_count = 0;
+ for (int i = 0; i < kBlockRefCount; ++i) {
+ ref_coding_idx_list[i] = -1;
+ int ref_frame_index = unit_dep_stats.ref_frame_index[i];
+ if (ref_frame_index != -1) {
+ assert(ref_frame_index < static_cast<int>(ref_frame_table.size()));
+ ref_coding_idx_list[i] = ref_frame_table[ref_frame_index].coding_idx;
+ ref_frame_count++;
+ }
+ }
+ return ref_frame_count;
+}
+
// Returns the overlap area, in pixels, of two size x size blocks whose
// top-left corners are (r0, c0) and (r1, c1); 0 if they do not intersect.
int GetBlockOverlapArea(int r0, int c0, int r1, int c1, int size) {
  const int overlap_height = std::min(r0 + size, r1 + size) - std::max(r0, r1);
  const int overlap_width = std::min(c0 + size, c1 + size) - std::max(c0, c1);
  if (overlap_height < 0 || overlap_width < 0) return 0;
  return overlap_height * overlap_width;
}
+
+// TODO(angiebird): Merge TplFrameDepStatsAccumulateIntraCost and
+// TplFrameDepStatsAccumulate.
+double TplFrameDepStatsAccumulateIntraCost(
+ const TplFrameDepStats &frame_dep_stats) {
+ auto getIntraCost = [](double sum, const TplUnitDepStats &unit) {
+ return sum + unit.intra_cost;
+ };
+ double sum = 0;
+ for (const auto &row : frame_dep_stats.unit_stats) {
+ sum = std::accumulate(row.begin(), row.end(), sum, getIntraCost);
+ }
+ return std::max(sum, 1.0);
+}
+
+double TplFrameDepStatsAccumulateInterCost(
+ const TplFrameDepStats &frame_dep_stats) {
+ auto getInterCost = [](double sum, const TplUnitDepStats &unit) {
+ return sum + unit.inter_cost;
+ };
+ double sum = 0;
+ for (const auto &row : frame_dep_stats.unit_stats) {
+ sum = std::accumulate(row.begin(), row.end(), sum, getInterCost);
+ }
+ return std::max(sum, 1.0);
+}
+
+double TplFrameDepStatsAccumulate(const TplFrameDepStats &frame_dep_stats) {
+ auto getOverallCost = [](double sum, const TplUnitDepStats &unit) {
+ return sum + unit.propagation_cost + unit.intra_cost;
+ };
+ double sum = 0;
+ for (const auto &row : frame_dep_stats.unit_stats) {
+ sum = std::accumulate(row.begin(), row.end(), sum, getOverallCost);
+ }
+ return std::max(sum, 1.0);
+}
+
// This is a generalization of GET_MV_RAWPEL that allows for an arbitrary
// number of fractional bits.
// Rounds a subpel value with `subpel_bits` fractional bits to the nearest
// full-pel value; the magnitude is rounded half-up and the sign is restored
// afterwards.
// TODO(angiebird): Add unit test to this function
int GetFullpelValue(int subpel_value, int subpel_bits) {
  const int half_unit = (1 << subpel_bits) >> 1;
  const int magnitude = (abs(subpel_value) + half_unit) >> subpel_bits;
  return subpel_value >= 0 ? magnitude : -magnitude;
}
+
+double GetPropagationFraction(const TplUnitDepStats &unit_dep_stats) {
+ assert(unit_dep_stats.intra_cost >= unit_dep_stats.inter_cost);
+ return (unit_dep_stats.intra_cost - unit_dep_stats.inter_cost) /
+ ModifyDivisor(unit_dep_stats.intra_cost);
+}
+
// Back-propagates the dependency cost of every unit in frame `coding_idx` to
// the frames it references (looked up through ref_frame_table).
void TplFrameDepStatsPropagate(int coding_idx,
                               const RefFrameTable &ref_frame_table,
                               TplGopDepStats *tpl_gop_dep_stats) {
  assert(!tpl_gop_dep_stats->frame_dep_stats_list.empty());
  TplFrameDepStats *frame_dep_stats =
      &tpl_gop_dep_stats->frame_dep_stats_list[coding_idx];

  // Frames without unit stats (e.g. frames that had no TPL block stats) have
  // nothing to propagate.
  if (frame_dep_stats->unit_stats.empty()) return;

  const int unit_size = frame_dep_stats->unit_size;
  const int frame_unit_rows =
      static_cast<int>(frame_dep_stats->unit_stats.size());
  const int frame_unit_cols =
      static_cast<int>(frame_dep_stats->unit_stats[0].size());
  for (int unit_row = 0; unit_row < frame_unit_rows; ++unit_row) {
    for (int unit_col = 0; unit_col < frame_unit_cols; ++unit_col) {
      TplUnitDepStats &unit_dep_stats =
          frame_dep_stats->unit_stats[unit_row][unit_col];
      int ref_coding_idx_list[kBlockRefCount] = { -1, -1 };
      int ref_frame_count = GetRefCodingIdxList(unit_dep_stats, ref_frame_table,
                                                ref_coding_idx_list);
      if (ref_frame_count == 0) continue;
      for (int i = 0; i < kBlockRefCount; ++i) {
        if (ref_coding_idx_list[i] == -1) continue;
        assert(
            ref_coding_idx_list[i] <
            static_cast<int>(tpl_gop_dep_stats->frame_dep_stats_list.size()));
        TplFrameDepStats &ref_frame_dep_stats =
            tpl_gop_dep_stats->frame_dep_stats_list[ref_coding_idx_list[i]];
        assert(!ref_frame_dep_stats.unit_stats.empty());
        const auto &mv = unit_dep_stats.mv[i];
        const int mv_row = GetFullpelValue(mv.row, mv.subpel_bits);
        const int mv_col = GetFullpelValue(mv.col, mv.subpel_bits);
        // Top-left corner, in pixels, of this unit's prediction block inside
        // the reference frame.
        const int ref_pixel_r = unit_row * unit_size + mv_row;
        const int ref_pixel_c = unit_col * unit_size + mv_col;
        const int ref_unit_row_low =
            (unit_row * unit_size + mv_row) / unit_size;
        const int ref_unit_col_low =
            (unit_col * unit_size + mv_col) / unit_size;

        // The prediction block can straddle up to 2x2 units in the reference
        // frame; visit each candidate unit and weight by the pixel overlap.
        for (int j = 0; j < 2; ++j) {
          for (int k = 0; k < 2; ++k) {
            const int ref_unit_row = ref_unit_row_low + j;
            const int ref_unit_col = ref_unit_col_low + k;
            if (ref_unit_row >= 0 && ref_unit_row < frame_unit_rows &&
                ref_unit_col >= 0 && ref_unit_col < frame_unit_cols) {
              const int overlap_area = GetBlockOverlapArea(
                  ref_pixel_r, ref_pixel_c, ref_unit_row * unit_size,
                  ref_unit_col * unit_size, unit_size);
              const double overlap_ratio =
                  overlap_area * 1.0 / (unit_size * unit_size);
              const double propagation_fraction =
                  GetPropagationFraction(unit_dep_stats);
              // Split the propagated cost evenly among the references, scale
              // by the overlap ratio and by the fraction of the unit's cost
              // not explained by inter prediction.
              const double propagation_ratio =
                  1.0 / ref_frame_count * overlap_ratio * propagation_fraction;
              TplUnitDepStats &ref_unit_stats =
                  ref_frame_dep_stats.unit_stats[ref_unit_row][ref_unit_col];
              ref_unit_stats.propagation_cost +=
                  (unit_dep_stats.intra_cost +
                   unit_dep_stats.propagation_cost) *
                  propagation_ratio;
            }
          }
        }
      }
    }
  }
}
+
// Returns one RefFrameTable snapshot per coding step: element 0 is the state
// before the GOP's first frame is coded, and each following element is the
// state after one more frame (including the frames of any lookahead GOPs).
std::vector<RefFrameTable> AV1RateControlQMode::GetRefFrameTableList(
    const GopStruct &gop_struct,
    const std::vector<LookaheadStats> &lookahead_stats,
    RefFrameTable ref_frame_table) {
  if (gop_struct.global_coding_idx_offset == 0) {
    // For the first GOP, ref_frame_table need not be initialized. This is fine,
    // because the first frame (a key frame) will fully initialize it.
    ref_frame_table.assign(rc_param_.ref_frame_table_size, GopFrameInvalid());
  } else {
    // It's not the first GOP, so ref_frame_table must be valid.
    assert(static_cast<int>(ref_frame_table.size()) ==
           rc_param_.ref_frame_table_size);
    assert(std::all_of(ref_frame_table.begin(), ref_frame_table.end(),
                       std::mem_fn(&GopFrame::is_valid)));
    // Reset the frame processing order of the initial ref_frame_table.
    for (GopFrame &gop_frame : ref_frame_table) gop_frame.coding_idx = -1;
  }

  std::vector<RefFrameTable> ref_frame_table_list;
  ref_frame_table_list.push_back(ref_frame_table);
  for (const GopFrame &gop_frame : gop_struct.gop_frame_list) {
    if (gop_frame.is_key_frame) {
      // A key frame replaces every slot of the table with itself.
      ref_frame_table.assign(rc_param_.ref_frame_table_size, gop_frame);
    } else if (gop_frame.update_ref_idx != -1) {
      assert(gop_frame.update_ref_idx <
             static_cast<int>(ref_frame_table.size()));
      ref_frame_table[gop_frame.update_ref_idx] = gop_frame;
    }
    ref_frame_table_list.push_back(ref_frame_table);
  }

  // Lookahead frames carry coding indices local to their own GOP;
  // gop_size_offset converts them to indices relative to the current GOP.
  int gop_size_offset = static_cast<int>(gop_struct.gop_frame_list.size());

  for (const auto &lookahead_stat : lookahead_stats) {
    // Note: gop_frame is taken by value because coding_idx is adjusted below.
    for (GopFrame gop_frame : lookahead_stat.gop_struct->gop_frame_list) {
      if (gop_frame.is_key_frame) {
        ref_frame_table.assign(rc_param_.ref_frame_table_size, gop_frame);
      } else if (gop_frame.update_ref_idx != -1) {
        assert(gop_frame.update_ref_idx <
               static_cast<int>(ref_frame_table.size()));
        gop_frame.coding_idx += gop_size_offset;
        ref_frame_table[gop_frame.update_ref_idx] = gop_frame;
      }
      ref_frame_table_list.push_back(ref_frame_table);
    }
    gop_size_offset +=
        static_cast<int>(lookahead_stat.gop_struct->gop_frame_list.size());
  }

  return ref_frame_table_list;
}
+
+StatusOr<TplGopDepStats> ComputeTplGopDepStats(
+ const TplGopStats &tpl_gop_stats,
+ const std::vector<LookaheadStats> &lookahead_stats,
+ const std::vector<RefFrameTable> &ref_frame_table_list) {
+ std::vector<const TplFrameStats *> tpl_frame_stats_list_with_lookahead;
+ for (const auto &tpl_frame_stats : tpl_gop_stats.frame_stats_list) {
+ tpl_frame_stats_list_with_lookahead.push_back(&tpl_frame_stats);
+ }
+ for (const auto &lookahead_stat : lookahead_stats) {
+ for (const auto &tpl_frame_stats :
+ lookahead_stat.tpl_gop_stats->frame_stats_list) {
+ tpl_frame_stats_list_with_lookahead.push_back(&tpl_frame_stats);
+ }
+ }
+
+ const int frame_count =
+ static_cast<int>(tpl_frame_stats_list_with_lookahead.size());
+
+ // Create the struct to store TPL dependency stats
+ TplGopDepStats tpl_gop_dep_stats;
+
+ tpl_gop_dep_stats.frame_dep_stats_list.reserve(frame_count);
+ for (int coding_idx = 0; coding_idx < frame_count; coding_idx++) {
+ const StatusOr<TplFrameDepStats> tpl_frame_dep_stats =
+ CreateTplFrameDepStatsWithoutPropagation(
+ *tpl_frame_stats_list_with_lookahead[coding_idx]);
+ if (!tpl_frame_dep_stats.ok()) {
+ return tpl_frame_dep_stats.status();
+ }
+ tpl_gop_dep_stats.frame_dep_stats_list.push_back(
+ std::move(*tpl_frame_dep_stats));
+ }
+
+ // Back propagation
+ for (int coding_idx = frame_count - 1; coding_idx >= 0; coding_idx--) {
+ auto &ref_frame_table = ref_frame_table_list[coding_idx];
+ // TODO(angiebird): Handle/test the case where reference frame
+ // is in the previous GOP
+ TplFrameDepStatsPropagate(coding_idx, ref_frame_table, &tpl_gop_dep_stats);
+ }
+ return tpl_gop_dep_stats;
+}
+
// Computes one q index per 64x64 superblock, in raster order, by comparing
// each superblock's propagated (mc_dep) cost against its intra cost.
static std::vector<uint8_t> SetupDeltaQ(const TplFrameDepStats &frame_dep_stats,
                                        int frame_width, int frame_height,
                                        int base_qindex,
                                        double frame_importance) {
  // TODO(jianj) : Add support to various superblock sizes.
  const int sb_size = 64;
  const int delta_q_res = 4;
  const int num_unit_per_sb = sb_size / frame_dep_stats.unit_size;
  const int sb_rows = (frame_height + sb_size - 1) / sb_size;
  const int sb_cols = (frame_width + sb_size - 1) / sb_size;
  const int unit_rows = (frame_height + frame_dep_stats.unit_size - 1) /
                        frame_dep_stats.unit_size;
  const int unit_cols =
      (frame_width + frame_dep_stats.unit_size - 1) / frame_dep_stats.unit_size;
  std::vector<uint8_t> superblock_q_indices;
  // Calculate delta_q offset for each superblock.
  for (int sb_row = 0; sb_row < sb_rows; ++sb_row) {
    for (int sb_col = 0; sb_col < sb_cols; ++sb_col) {
      // Accumulate the intra and propagated costs of the units inside this
      // superblock, clamped to the frame boundary.
      double intra_cost = 0;
      double mc_dep_cost = 0;
      const int unit_row_start = sb_row * num_unit_per_sb;
      const int unit_row_end =
          std::min((sb_row + 1) * num_unit_per_sb, unit_rows);
      const int unit_col_start = sb_col * num_unit_per_sb;
      const int unit_col_end =
          std::min((sb_col + 1) * num_unit_per_sb, unit_cols);
      // A simplified version of av1_get_q_for_deltaq_objective()
      for (int unit_row = unit_row_start; unit_row < unit_row_end; ++unit_row) {
        for (int unit_col = unit_col_start; unit_col < unit_col_end;
             ++unit_col) {
          const TplUnitDepStats &unit_dep_stat =
              frame_dep_stats.unit_stats[unit_row][unit_col];
          intra_cost += unit_dep_stat.intra_cost;
          mc_dep_cost += unit_dep_stat.propagation_cost;
        }
      }

      // beta compares this superblock's intra/mc_dep cost ratio against the
      // frame-level importance; it feeds av1_get_deltaq_offset() below.
      double beta = 1.0;
      if (mc_dep_cost > 0 && intra_cost > 0) {
        const double r0 = 1 / frame_importance;
        const double rk = intra_cost / mc_dep_cost;
        beta = r0 / rk;
        assert(beta > 0.0);
      }
      // Clamp the offset to just under +/- 9 delta_q_res steps, clamp the
      // resulting q index to [MINQ, MAXQ], then snap it onto the delta-q
      // resolution grid relative to base_qindex.
      int offset = av1_get_deltaq_offset(AOM_BITS_8, base_qindex, beta);
      offset = std::min(offset, delta_q_res * 9 - 1);
      offset = std::max(offset, -delta_q_res * 9 + 1);
      int qindex = offset + base_qindex;
      qindex = std::min(qindex, MAXQ);
      qindex = std::max(qindex, MINQ);
      qindex = av1_adjust_q_from_delta_q_res(delta_q_res, base_qindex, qindex);
      superblock_q_indices.push_back(static_cast<uint8_t>(qindex));
    }
  }

  return superblock_q_indices;
}
+
// Maps each qindex to its nearest centroid by absolute distance; ties keep
// the earlier centroid, matching std::min_element semantics.
static std::unordered_map<int, double> FindKMeansClusterMap(
    const std::vector<uint8_t> &qindices,
    const std::vector<double> &centroids) {
  std::unordered_map<int, double> cluster_map;
  for (const uint8_t qindex : qindices) {
    double best_centroid = centroids.front();
    double best_distance = fabs(best_centroid - qindex);
    for (const double centroid : centroids) {
      const double distance = fabs(centroid - qindex);
      if (distance < best_distance) {
        best_distance = distance;
        best_centroid = centroid;
      }
    }
    cluster_map.insert({ qindex, best_centroid });
  }
  return cluster_map;
}

namespace internal {

// Runs Lloyd's k-means over the qindex values with k clusters, seeding the
// centroids with the first k distinct qindices. Returns a map from each
// qindex to the (truncated) centroid of its final cluster.
std::unordered_map<int, int> KMeans(std::vector<uint8_t> qindices, int k) {
  // Seed the centroids with the first k distinct qindex values.
  std::vector<double> centroids;
  std::unordered_set<int> seen;
  for (const uint8_t qindex : qindices) {
    if (!seen.insert(qindex).second) continue;  // Duplicate value.
    centroids.push_back(qindex);
    if (static_cast<int>(centroids.size()) >= k) break;
  }

  // Iterate assignment + update steps until the centroids stop moving.
  std::unordered_map<int, double> qindex_to_centroid;
  for (;;) {
    // Assignment step: attach each qindex to its nearest centroid.
    qindex_to_centroid = FindKMeansClusterMap(qindices, centroids);
    // Update step: recompute each centroid as the mean of its members.
    std::unordered_map<double, std::vector<int>> members;
    for (const auto &entry : qindex_to_centroid) {
      members[entry.second].push_back(entry.first);
    }
    bool converged = true;
    std::vector<double> updated_centroids;
    for (const auto &cluster : members) {
      double total = 0.0;
      for (const int qindex : cluster.second) {
        total += qindex;
      }
      const double mean = total / cluster.second.size();
      updated_centroids.push_back(mean);
      if (mean != cluster.first) converged = false;
    }
    if (converged) break;
    centroids = updated_centroids;
  }

  // Truncate the centroids to integers for the caller.
  std::unordered_map<int, int> cluster_map;
  for (const auto &entry : qindex_to_centroid) {
    cluster_map.insert({ entry.first, static_cast<int>(entry.second) });
  }
  return cluster_map;
}
}  // namespace internal
+
+static int GetRDMult(const GopFrame &gop_frame, int q_index) {
+ // TODO(angiebird):
+ // 1) Check if these rdmult rules are good in our use case.
+ // 2) Support high-bit-depth mode
+ if (gop_frame.is_golden_frame) {
+ // Assume ARF_UPDATE/GF_UPDATE share the same remult rule.
+ return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, GF_UPDATE, q_index);
+ } else if (gop_frame.is_key_frame) {
+ return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, KF_UPDATE, q_index);
+ } else {
+ // Assume LF_UPDATE/OVERLAY_UPDATE/INTNL_OVERLAY_UPDATE/INTNL_ARF_UPDATE
+ // share the same remult rule.
+ return av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, LF_UPDATE, q_index);
+ }
+}
+
+StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithNoStats(
+ const GopStruct &gop_struct) {
+ GopEncodeInfo gop_encode_info;
+ const int frame_count = static_cast<int>(gop_struct.gop_frame_list.size());
+ for (int i = 0; i < frame_count; i++) {
+ FrameEncodeParameters param;
+ const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
+ // Use constant QP for TPL pass encoding. Keep the functionality
+ // that allows QP changes across sub-gop.
+ param.q_index = rc_param_.base_q_index;
+ param.rdmult = av1_compute_rd_mult_based_on_qindex(AOM_BITS_8, LF_UPDATE,
+ rc_param_.base_q_index);
+ // TODO(jingning): gop_frame is needed in two pass tpl later.
+ (void)gop_frame;
+
+ if (rc_param_.tpl_pass_index) {
+ if (gop_frame.update_type == GopFrameType::kRegularGolden ||
+ gop_frame.update_type == GopFrameType::kRegularKey ||
+ gop_frame.update_type == GopFrameType::kRegularArf) {
+ double qstep_ratio = 1 / 3.0;
+ param.q_index = av1_get_q_index_from_qstep_ratio(
+ rc_param_.base_q_index, qstep_ratio, AOM_BITS_8);
+ if (rc_param_.base_q_index) param.q_index = AOMMAX(param.q_index, 1);
+ }
+ }
+ gop_encode_info.param_list.push_back(param);
+ }
+ return gop_encode_info;
+}
+
// First-pass-stats-based GOP parameter derivation; currently delegates to the
// no-stats path and ignores firstpass_info.
StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithFp(
    const GopStruct &gop_struct,
    const FirstpassInfo &firstpass_info AOM_UNUSED) {
  // TODO(b/260859962): This is currently a placeholder. Should use the fp
  // stats to calculate frame-level qp.
  return GetGopEncodeInfoWithNoStats(gop_struct);
}
+
+StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfoWithTpl(
+ const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
+ const std::vector<LookaheadStats> &lookahead_stats,
+ const RefFrameTable &ref_frame_table_snapshot_init) {
+ const std::vector<RefFrameTable> ref_frame_table_list = GetRefFrameTableList(
+ gop_struct, lookahead_stats, ref_frame_table_snapshot_init);
+
+ GopEncodeInfo gop_encode_info;
+ gop_encode_info.final_snapshot = ref_frame_table_list.back();
+ StatusOr<TplGopDepStats> gop_dep_stats = ComputeTplGopDepStats(
+ tpl_gop_stats, lookahead_stats, ref_frame_table_list);
+ if (!gop_dep_stats.ok()) {
+ return gop_dep_stats.status();
+ }
+ const int frame_count =
+ static_cast<int>(tpl_gop_stats.frame_stats_list.size());
+ const int active_worst_quality = rc_param_.base_q_index;
+ int active_best_quality = rc_param_.base_q_index;
+ for (int i = 0; i < frame_count; i++) {
+ FrameEncodeParameters param;
+ const GopFrame &gop_frame = gop_struct.gop_frame_list[i];
+
+ if (gop_frame.update_type == GopFrameType::kOverlay ||
+ gop_frame.update_type == GopFrameType::kIntermediateOverlay ||
+ gop_frame.update_type == GopFrameType::kRegularLeaf) {
+ param.q_index = rc_param_.base_q_index;
+ } else if (gop_frame.update_type == GopFrameType::kRegularGolden ||
+ gop_frame.update_type == GopFrameType::kRegularKey ||
+ gop_frame.update_type == GopFrameType::kRegularArf) {
+ const TplFrameDepStats &frame_dep_stats =
+ gop_dep_stats->frame_dep_stats_list[i];
+ const double cost_without_propagation =
+ TplFrameDepStatsAccumulateIntraCost(frame_dep_stats);
+ const double cost_with_propagation =
+ TplFrameDepStatsAccumulate(frame_dep_stats);
+ const double frame_importance =
+ cost_with_propagation / cost_without_propagation;
+ // Imitate the behavior of av1_tpl_get_qstep_ratio()
+ const double qstep_ratio = sqrt(1 / frame_importance);
+ param.q_index = av1_get_q_index_from_qstep_ratio(rc_param_.base_q_index,
+ qstep_ratio, AOM_BITS_8);
+ if (rc_param_.base_q_index) param.q_index = AOMMAX(param.q_index, 1);
+ active_best_quality = param.q_index;
+
+ if (rc_param_.max_distinct_q_indices_per_frame > 1) {
+ std::vector<uint8_t> superblock_q_indices = SetupDeltaQ(
+ frame_dep_stats, rc_param_.frame_width, rc_param_.frame_height,
+ param.q_index, frame_importance);
+ std::unordered_map<int, int> qindex_centroids = internal::KMeans(
+ superblock_q_indices, rc_param_.max_distinct_q_indices_per_frame);
+ for (size_t i = 0; i < superblock_q_indices.size(); ++i) {
+ const int curr_sb_qindex =
+ qindex_centroids.find(superblock_q_indices[i])->second;
+ const int delta_q_res = 4;
+ const int adjusted_qindex =
+ param.q_index +
+ (curr_sb_qindex - param.q_index) / delta_q_res * delta_q_res;
+ const int rd_mult = GetRDMult(gop_frame, adjusted_qindex);
+ param.superblock_encode_params.push_back(
+ { static_cast<uint8_t>(adjusted_qindex), rd_mult });
+ }
+ }
+ } else {
+ // Intermediate ARFs
+ assert(gop_frame.layer_depth >= 1);
+ const int depth_factor = 1 << (gop_frame.layer_depth - 1);
+ param.q_index =
+ (active_worst_quality * (depth_factor - 1) + active_best_quality) /
+ depth_factor;
+ }
+ param.rdmult = GetRDMult(gop_frame, param.q_index);
+ gop_encode_info.param_list.push_back(param);
+ }
+ return gop_encode_info;
+}
+
// Entry point for the TPL pass: derives encode parameters from first pass
// stats by delegating to GetGopEncodeInfoWithFp().
StatusOr<GopEncodeInfo> AV1RateControlQMode::GetTplPassGopEncodeInfo(
    const GopStruct &gop_struct, const FirstpassInfo &firstpass_info) {
  return GetGopEncodeInfoWithFp(gop_struct, firstpass_info);
}
+
+StatusOr<GopEncodeInfo> AV1RateControlQMode::GetGopEncodeInfo(
+ const GopStruct &gop_struct, const TplGopStats &tpl_gop_stats,
+ const std::vector<LookaheadStats> &lookahead_stats,
+ const FirstpassInfo &firstpass_info AOM_UNUSED,
+ const RefFrameTable &ref_frame_table_snapshot_init) {
+ // When TPL stats are not valid, use first pass stats.
+ Status status = ValidateTplStats(gop_struct, tpl_gop_stats);
+ if (!status.ok()) {
+ return status;
+ }
+
+ for (const auto &lookahead_stat : lookahead_stats) {
+ Status status = ValidateTplStats(*lookahead_stat.gop_struct,
+ *lookahead_stat.tpl_gop_stats);
+ if (!status.ok()) {
+ return status;
+ }
+ }
+
+ // TODO(b/260859962): Currently firstpass stats are used as an alternative,
+ // but we could also combine it with tpl results in the future for more
+ // stable qp determination.
+ return GetGopEncodeInfoWithTpl(gop_struct, tpl_gop_stats, lookahead_stats,
+ ref_frame_table_snapshot_init);
+}
+
+} // namespace aom