aboutsummaryrefslogtreecommitdiff
path: root/internal/block_params.h
diff options
context:
space:
mode:
Diffstat (limited to 'internal/block_params.h')
-rw-r--r--internal/block_params.h25
1 files changed, 14 insertions, 11 deletions
diff --git a/internal/block_params.h b/internal/block_params.h
index b2fc3ff..aedd261 100644
--- a/internal/block_params.h
+++ b/internal/block_params.h
@@ -43,13 +43,12 @@ struct BlockParams {
int l2_depth;
template <typename KernelFormat>
- void Init(int rows, int cols, int depth, int num_threads,
- int l1_bytes_to_use, int l2_bytes_to_use, float l2_rhs_factor) {
+ void Init(int rows, int cols, int depth, int num_threads, int l1_bytes_to_use,
+ int l2_bytes_to_use, float l2_rhs_factor) {
FindL2BlockSizes<KernelFormat>(rows, cols, depth, num_threads,
- l2_bytes_to_use, l2_rhs_factor,
- &l2_rows, &l2_cols, &l2_depth);
- FindL1BlockSizes<KernelFormat>(l2_rows, l2_cols, l2_depth,
- l1_bytes_to_use,
+ l2_bytes_to_use, l2_rhs_factor, &l2_rows,
+ &l2_cols, &l2_depth);
+ FindL1BlockSizes<KernelFormat>(l2_rows, l2_cols, l2_depth, l1_bytes_to_use,
&l1_rows, &l1_cols, &l1_depth);
}
@@ -61,6 +60,10 @@ struct BlockParams {
int l2_rows = 0;
int l2_cols = 0;
int l2_depth = 0;
+
+ int per_thread_rows =
+ std::max(1, RoundUp<KernelFormat::kRows>(rows) / num_threads);
+
// No L2 blocking in the depth dimension at the moment.
// Too much loss of accuracy due to storing intermediate results in
// low precision.
@@ -81,15 +84,15 @@ struct BlockParams {
// dimension concerns only the LHS. Blocking only RHS matrix for L2 enhances
// the performance on x86.
if (l2_rhs_factor == 1.0f) {
- l2_rows = RoundUp<KernelFormat::kRows>(rows);
+ l2_rows = RoundUp<KernelFormat::kRows>(per_thread_rows);
} else {
int max_cache_friendly_l2_rows =
std::max(1, (l2_bytes_to_use - l2_depth * l2_cols) /
(num_threads * (l2_depth + 4 * l2_cols)));
- int min_l2_rows_blocks =
- std::max(1, CeilQuotient(rows, max_cache_friendly_l2_rows));
- l2_rows =
- RoundUp<KernelFormat::kRows>(CeilQuotient(rows, min_l2_rows_blocks));
+ int min_l2_rows_blocks = std::max(
+ 1, CeilQuotient(per_thread_rows, max_cache_friendly_l2_rows));
+ l2_rows = RoundUp<KernelFormat::kRows>(
+ CeilQuotient(per_thread_rows, min_l2_rows_blocks));
}
*out_l2_rows = l2_rows;