diff options
Diffstat (limited to 'internal/block_params.h')
-rw-r--r-- | internal/block_params.h | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/internal/block_params.h b/internal/block_params.h index b2fc3ff..aedd261 100644 --- a/internal/block_params.h +++ b/internal/block_params.h @@ -43,13 +43,12 @@ struct BlockParams { int l2_depth; template <typename KernelFormat> - void Init(int rows, int cols, int depth, int num_threads, - int l1_bytes_to_use, int l2_bytes_to_use, float l2_rhs_factor) { + void Init(int rows, int cols, int depth, int num_threads, int l1_bytes_to_use, + int l2_bytes_to_use, float l2_rhs_factor) { FindL2BlockSizes<KernelFormat>(rows, cols, depth, num_threads, - l2_bytes_to_use, l2_rhs_factor, - &l2_rows, &l2_cols, &l2_depth); - FindL1BlockSizes<KernelFormat>(l2_rows, l2_cols, l2_depth, - l1_bytes_to_use, + l2_bytes_to_use, l2_rhs_factor, &l2_rows, + &l2_cols, &l2_depth); + FindL1BlockSizes<KernelFormat>(l2_rows, l2_cols, l2_depth, l1_bytes_to_use, &l1_rows, &l1_cols, &l1_depth); } @@ -61,6 +60,10 @@ struct BlockParams { int l2_rows = 0; int l2_cols = 0; int l2_depth = 0; + + int per_thread_rows = + std::max(1, RoundUp<KernelFormat::kRows>(rows) / num_threads); + // No L2 blocking in the depth dimension at the moment. // Too much loss of accuracy due to storing intermediate results in // low precision. @@ -81,15 +84,15 @@ struct BlockParams { // dimension concerns only the LHS. Blocking only RHS matrix for L2 enhances // the performance on x86. if (l2_rhs_factor == 1.0f) { - l2_rows = RoundUp<KernelFormat::kRows>(rows); + l2_rows = RoundUp<KernelFormat::kRows>(per_thread_rows); } else { int max_cache_friendly_l2_rows = std::max(1, (l2_bytes_to_use - l2_depth * l2_cols) / (num_threads * (l2_depth + 4 * l2_cols))); - int min_l2_rows_blocks = - std::max(1, CeilQuotient(rows, max_cache_friendly_l2_rows)); - l2_rows = - RoundUp<KernelFormat::kRows>(CeilQuotient(rows, min_l2_rows_blocks)); + int min_l2_rows_blocks = std::max( + 1, CeilQuotient(per_thread_rows, max_cache_friendly_l2_rows)); + l2_rows = RoundUp<KernelFormat::kRows>( + CeilQuotient(per_thread_rows, min_l2_rows_blocks)); } *out_l2_rows = l2_rows; |