aboutsummaryrefslogtreecommitdiff
path: root/trmul.cc
diff options
context:
space:
mode:
Diffstat (limited to 'trmul.cc')
-rw-r--r--trmul.cc24
1 files changed, 10 insertions, 14 deletions
diff --git a/trmul.cc b/trmul.cc
index 535ad45..a41372a 100644
--- a/trmul.cc
+++ b/trmul.cc
@@ -90,12 +90,11 @@ struct TrMulTask final : Task {
GetBlockByIndex(block_map, block_id, &block);
// Get coordinates of the current block to handle, in matrix space.
GetBlockMatrixCoords(block_map, block, &start, &end);
- TraceRecordBlockCoordsComputed(block_id, trace);
// Maybe pack the current LHS/RHS block, if not already packed.
- EnsurePacked(block_id, local_packed, block, start, end, tuning);
+ EnsurePacked(local_packed, block, start, end, tuning);
// Actually do matrix multiplication work
params->RunKernel(tuning, start, end);
- TraceRecordBlockFinished(block_id, trace);
+ TraceRecordBlockFinished(thread_id, block_id, trace);
// Move on to the next block as obtained by the atomic increment
// at the start of this while loop iteration.
block_id = next_block_id;
@@ -107,8 +106,8 @@ struct TrMulTask final : Task {
}
private:
- bool TryEnsurePacked(Side side, int block_id, bool* local_packed, int block,
- int start, int end, Tuning tuning) {
+ bool TryEnsurePacked(Side side, bool* local_packed, int block, int start,
+ int end, Tuning tuning) {
if (local_packed && !local_packed[block]) {
PackingStatus not_started = PackingStatus::kNotStarted;
std::atomic<PackingStatus>& status = packing_status[side][block];
@@ -119,7 +118,7 @@ struct TrMulTask final : Task {
// changed it to kInProgress as we are about to handle the packing
// ourselves.
params->RunPack(side, tuning, start, end);
- TraceRecordBlockPacked(side, block_id, trace);
+ TraceRecordBlockPacked(thread_id, side, block, trace);
status.store(PackingStatus::kFinished, std::memory_order_release);
} else if (status.load(std::memory_order_acquire) ==
PackingStatus::kInProgress) {
@@ -133,15 +132,15 @@ struct TrMulTask final : Task {
return true;
}
- void EnsurePacked(int block_id, const SidePair<bool*> local_packed,
+ void EnsurePacked(const SidePair<bool*> local_packed,
const SidePair<int>& block, const SidePair<int>& start,
const SidePair<int>& end, Tuning tuning) {
while (true) {
bool both_sides_packed = true;
for (Side side : {Side::kLhs, Side::kRhs}) {
both_sides_packed &=
- TryEnsurePacked(side, block_id, local_packed[side], block[side],
- start[side], end[side], tuning);
+ TryEnsurePacked(side, local_packed[side], block[side], start[side],
+ end[side], tuning);
}
if (both_sides_packed) {
break;
@@ -282,9 +281,7 @@ void TrMul(TrMulParams* params, Context* context) {
}
// Do the computation.
- TraceRecordExecute(trace);
- TraceStartRecordingBlockAndThreadFields(block_map, thread_count, trace);
-
+ TraceRecordExecute(block_map, thread_count, trace);
context->workers_pool.Execute(thread_count, tasks);
// Finish up.
@@ -292,9 +289,8 @@ void TrMul(TrMulParams* params, Context* context) {
tasks[i].~TrMulTask();
}
- TraceRecordEnd(trace);
-
allocator->FreeAll();
+ TraceRecordEnd(trace);
}
} // namespace ruy