diff options
Diffstat (limited to 'trmul.cc')
-rw-r--r-- | trmul.cc | 24 |
1 files changed, 10 insertions, 14 deletions
@@ -90,12 +90,11 @@ struct TrMulTask final : Task { GetBlockByIndex(block_map, block_id, &block); // Get coordinates of the current block to handle, in matrix space. GetBlockMatrixCoords(block_map, block, &start, &end); - TraceRecordBlockCoordsComputed(block_id, trace); // Maybe pack the current LHS/RHS block, if not already packed. - EnsurePacked(block_id, local_packed, block, start, end, tuning); + EnsurePacked(local_packed, block, start, end, tuning); // Actually do matrix multiplication work params->RunKernel(tuning, start, end); - TraceRecordBlockFinished(block_id, trace); + TraceRecordBlockFinished(thread_id, block_id, trace); // Move on to the next block as obtained by the atomic increment // at the start of this while loop iteration. block_id = next_block_id; @@ -107,8 +106,8 @@ struct TrMulTask final : Task { } private: - bool TryEnsurePacked(Side side, int block_id, bool* local_packed, int block, - int start, int end, Tuning tuning) { + bool TryEnsurePacked(Side side, bool* local_packed, int block, int start, + int end, Tuning tuning) { if (local_packed && !local_packed[block]) { PackingStatus not_started = PackingStatus::kNotStarted; std::atomic<PackingStatus>& status = packing_status[side][block]; @@ -119,7 +118,7 @@ struct TrMulTask final : Task { // changed it to kInProgress as we are about to handle the packing // ourselves. params->RunPack(side, tuning, start, end); - TraceRecordBlockPacked(side, block_id, trace); + TraceRecordBlockPacked(thread_id, side, block, trace); status.store(PackingStatus::kFinished, std::memory_order_release); } else if (status.load(std::memory_order_acquire) == PackingStatus::kInProgress) { @@ -133,15 +132,15 @@ struct TrMulTask final : Task { return true; } - void EnsurePacked(int block_id, const SidePair<bool*> local_packed, + void EnsurePacked(const SidePair<bool*> local_packed, const SidePair<int>& block, const SidePair<int>& start, const SidePair<int>& end, Tuning tuning) { while (true) { bool both_sides_packed = true; for (Side side : {Side::kLhs, Side::kRhs}) { both_sides_packed &= - TryEnsurePacked(side, block_id, local_packed[side], block[side], - start[side], end[side], tuning); + TryEnsurePacked(side, local_packed[side], block[side], start[side], + end[side], tuning); } if (both_sides_packed) { break; @@ -282,9 +281,7 @@ void TrMul(TrMulParams* params, Context* context) { } // Do the computation. - TraceRecordExecute(trace); - TraceStartRecordingBlockAndThreadFields(block_map, thread_count, trace); - + TraceRecordExecute(block_map, thread_count, trace); context->workers_pool.Execute(thread_count, tasks); // Finish up. @@ -292,9 +289,8 @@ void TrMul(TrMulParams* params, Context* context) { tasks[i].~TrMulTask(); } - TraceRecordEnd(trace); - allocator->FreeAll(); + TraceRecordEnd(trace); } } // namespace ruy |