diff options
-rw-r--r-- | BUILD | 1 | ||||
-rw-r--r-- | block_map.cc | 18 | ||||
-rw-r--r-- | block_map.h | 4 | ||||
-rw-r--r-- | trace.cc | 431 | ||||
-rw-r--r-- | trace.h | 24 | ||||
-rw-r--r-- | trmul.cc | 24 |
6 files changed, 218 insertions, 284 deletions
@@ -195,7 +195,6 @@ cc_library( deps = [ ":block_map", ":check_macros", - ":common", ":side_pair", ":time", ], diff --git a/block_map.cc b/block_map.cc index 17fe906..3b41ac8 100644 --- a/block_map.cc +++ b/block_map.cc @@ -202,15 +202,15 @@ void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, block_map->large_blocks[Side::kRhs] = missc; } -void GetBlockMatrixCoords(Side side, const BlockMap& block_map, - const SidePair<int>& block, int* start, int* end) { +void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, + int* start, int* end) { gemmlowp::ScopedProfilingLabel label("GetBlockMatrixCoords"); - const int b = block[side]; - *start = - b * block_map.small_block_dims[side] + - std::min(b, block_map.large_blocks[side]) * block_map.kernel_dims[side]; - *end = *start + block_map.small_block_dims[side] + - (b < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0); + *start = block * block_map.small_block_dims[side] + + std::min(block, block_map.large_blocks[side]) * + block_map.kernel_dims[side]; + *end = + *start + block_map.small_block_dims[side] + + (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0); RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]); RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]); @@ -222,7 +222,7 @@ void GetBlockMatrixCoords(Side side, const BlockMap& block_map, void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block, SidePair<int>* start, SidePair<int>* end) { for (Side side : {Side::kLhs, Side::kRhs}) { - GetBlockMatrixCoords(side, block_map, block, &(*start)[side], + GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side], &(*end)[side]); } } diff --git a/block_map.h b/block_map.h index 415cac7..deb2b38 100644 --- a/block_map.h +++ b/block_map.h @@ -115,8 +115,8 @@ void GetBlockByIndex(const BlockMap& block_map, int index, // position in the matrix that the BlockMap refers to in the dimension // referred to by `side`: along rows if side==kLhs, along columns if // side==kRhs. -void GetBlockMatrixCoords(Side side, const BlockMap& block_map, - const SidePair<int>& block, int* start, int* end); +void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, + int* start, int* end); // Given a block position in the grid, returns its actual // position in the matrix that the BlockMap refers to in terms of @@ -24,200 +24,153 @@ limitations under the License. #include "block_map.h" #include "check_macros.h" -#include "common.h" +#include "side_pair.h" #include "time.h" namespace ruy { #ifdef RUY_TRACE -struct BlockTraceEntry { - std::uint32_t thread_id = 0; - TimePoint time_reserved; - TimePoint time_computed_coords; - SidePair<TimePoint> time_packed; - TimePoint time_finished; +enum class TraceEvent : std::uint8_t { + kNone, + kThreadStart, + kThreadLoopStart, + kThreadEnd, + kBlockReserved, + kBlockPackedLhs, + kBlockPackedRhs, + kBlockFinished }; -struct ThreadTraceEntry { - TimePoint time_start; - TimePoint time_loop_start; - TimePoint time_end; +struct TraceEntry { + TimePoint time_point; + TraceEvent event; + // ruy-internal thread id i.e. contiguous index into array of threads, + // with 0 designating the main thread. + std::uint16_t thread_id = 0; + // Additional parameters whose meaning depends on the 'event' type. + std::uint32_t params[1]; }; struct Trace { - enum class LifeStage { - kInitial, - kRecordingRootFields, - kRecordingBlockAndThreadFields, - kComplete - }; - void StartRecordingBlockAndThreadFields(const BlockMap& block_map_, - int thread_count_) { - RUY_DCHECK(life_stage == LifeStage::kRecordingRootFields); - block_map = block_map_; - thread_count = thread_count_; - int num_blocks = NumBlocks(block_map); - if (num_blocks > block_entries.size()) { - block_entries.resize(NumBlocks(block_map)); - } - if (thread_count > thread_entries.size()) { - thread_entries.resize(thread_count); - } - life_stage = LifeStage::kRecordingBlockAndThreadFields; - } BlockMap block_map; int thread_count = 0; - std::vector<BlockTraceEntry> block_entries; - std::vector<ThreadTraceEntry> thread_entries; + // During recording, to avoid having to use locks or atomics, we let + // each thread append to its own specific vector. + std::vector<std::vector<TraceEntry>> thread_specific_entries; + // Global vector of entries into which we coalesce thread_specific_entries + // after recording is finished, when dumping a trace. See + // AggregateThreadSpecificEntries. + std::vector<TraceEntry> entries; TimePoint time_start; TimePoint time_execute; TimePoint time_end; - LifeStage life_stage = LifeStage::kInitial; }; -struct ProcessedTrace { - enum class Event : std::uint8_t { - kNone, - kThreadStart, - kThreadLoopStart, - kThreadEnd, - kBlockReserved, - kBlockComputedCoords, - kBlockPackedLhs, - kBlockPackedRhs, - kBlockFinished - }; - struct Entry { - Event event = Event::kNone; - std::uint32_t thread_id = 0; - std::uint32_t block_id = 0; - TimePoint time; - }; +namespace { - BlockMap block_map; - int thread_count = 0; - TimePoint time_start; - TimePoint time_execute; - TimePoint time_end; - std::vector<Entry> entries; - void Add(Event event, std::uint32_t thread_id, std::uint32_t block_id, - TimePoint time) { - // If the time point is still in its default-constructed state, - // that means we didn't record it. - if (!time.time_since_epoch().count()) { - return; +// Coalesce Trace::thread_specific_entries into Trace::entries. +void AggregateThreadSpecificEntries(Trace* trace) { + RUY_CHECK(trace->entries.empty()); + for (auto& thread_specific_entries_vector : trace->thread_specific_entries) { + for (const TraceEntry& entry : thread_specific_entries_vector) { + trace->entries.push_back(entry); } - Entry entry; - entry.event = event; - entry.thread_id = thread_id; - entry.block_id = block_id; - entry.time = time; - entries.push_back(entry); + thread_specific_entries_vector.clear(); } - void Process(const Trace& trace) { - thread_count = trace.thread_count; - block_map = trace.block_map; - time_start = trace.time_start; - time_execute = trace.time_execute; - time_end = trace.time_end; - entries.clear(); - for (int i = 0; i < trace.thread_count; i++) { - const auto& entry = trace.thread_entries[i]; - Add(Event::kThreadStart, i, 0, entry.time_start); - Add(Event::kThreadLoopStart, i, 0, entry.time_loop_start); - Add(Event::kThreadEnd, i, 0, entry.time_end); - } - std::uint32_t num_blocks = NumBlocks(block_map); - for (int i = 0; i < num_blocks; i++) { - const auto& entry = trace.block_entries[i]; - Add(Event::kBlockReserved, entry.thread_id, i, entry.time_reserved); - Add(Event::kBlockComputedCoords, entry.thread_id, i, - entry.time_computed_coords); - Add(Event::kBlockPackedLhs, entry.thread_id, i, - entry.time_packed[Side::kLhs]); - Add(Event::kBlockPackedRhs, entry.thread_id, i, - entry.time_packed[Side::kRhs]); - Add(Event::kBlockFinished, entry.thread_id, i, entry.time_finished); - } - std::sort(entries.begin(), entries.end(), - [](const Entry& a, const Entry& b) -> bool { - return a.time < b.time || - (a.time == b.time && - static_cast<int>(a.event) < static_cast<int>(b.event)); - }); +} + +// Sort Trace::entries by ascending time. In case of equal timepoints, +// sort by some semi-arbitrary ordering of event types. +void Sort(Trace* trace) { + std::sort(std::begin(trace->entries), std::end(trace->entries), + [](const TraceEntry& a, const TraceEntry& b) -> bool { + return a.time_point < b.time_point || + (a.time_point == b.time_point && + static_cast<int>(a.event) < static_cast<int>(b.event)); + }); +} + +// Dump a trace. Assumes that AggregateThreadSpecificEntries and Sort have +// already been called on it. +void Dump(const Trace& trace) { + const char* trace_filename = getenv("RUY_TRACE_FILE"); + FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr; + if (!trace_file) { + fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename, + errno); + RUY_CHECK(false); } - void Dump() { - const char* trace_filename = getenv("RUY_TRACE_FILE"); - FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr; - if (!trace_file) { - fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename, - errno); - RUY_CHECK(false); - } - fprintf(trace_file, "thread_count:%d\n", thread_count); - fprintf(trace_file, "num_blocks:%d\n", NumBlocks(block_map)); - fprintf(trace_file, "rows:%d\n", block_map.rows); - fprintf(trace_file, "cols:%d\n", block_map.cols); - fprintf(trace_file, "Execute: %.9f\n", - ToSeconds(time_execute - time_start)); - for (const Entry& entry : entries) { - double time = ToSeconds(entry.time - time_start); - switch (entry.event) { - case Event::kThreadStart: - fprintf(trace_file, "ThreadStart: %.9f, %d\n", time, entry.thread_id); - break; - case Event::kThreadLoopStart: - fprintf(trace_file, "ThreadLoopStart: %.9f, %d\n", time, - entry.thread_id); - break; - case Event::kThreadEnd: - fprintf(trace_file, "ThreadEnd: %.9f, %d\n", time, entry.thread_id); - break; - case Event::kBlockReserved: { - std::uint16_t block_r, block_c; - int start_r, start_c, end_r, end_c; - GetBlockByIndex(block_map, entry.block_id, &block_r, &block_c); - GetBlockMatrixCoords(block_map, block_r, block_c, &start_r, &start_c, - &end_r, &end_c); - fprintf(trace_file, "BlockReserved: %.9f, %d, %d, %d, %d, %d, %d\n", - time, entry.thread_id, entry.block_id, start_r, start_c, - end_r, end_c); - break; - } - case Event::kBlockComputedCoords: - fprintf(trace_file, "BlockComputedCoords: %.9f, %d, %d\n", time, - entry.thread_id, entry.block_id); - break; - case Event::kBlockPackedLhs: - fprintf(trace_file, "BlockPackedLhs: %.9f, %d, %d\n", time, - entry.thread_id, entry.block_id); - break; - case Event::kBlockPackedRhs: - fprintf(trace_file, "BlockPackedRhs: %.9f, %d, %d\n", time, - entry.thread_id, entry.block_id); - break; - case Event::kBlockFinished: - fprintf(trace_file, "BlockFinished: %.9f, %d, %d\n", time, - entry.thread_id, entry.block_id); - break; - default: - RUY_CHECK(false); + fprintf(trace_file, "thread_count:%d\n", trace.thread_count); + fprintf(trace_file, "rows:%d\n", trace.block_map.dims[Side::kLhs]); + fprintf(trace_file, "cols:%d\n", trace.block_map.dims[Side::kRhs]); + fprintf(trace_file, "Execute: %.9f\n", + ToSeconds(trace.time_execute - trace.time_start)); + for (const TraceEntry& entry : trace.entries) { + double time = ToSeconds(entry.time_point - trace.time_start); + switch (entry.event) { + case TraceEvent::kThreadStart: + fprintf(trace_file, "ThreadStart: %.9f, %d\n", time, entry.thread_id); + break; + case TraceEvent::kThreadLoopStart: + fprintf(trace_file, "ThreadLoopStart: %.9f, %d\n", time, + entry.thread_id); + break; + case TraceEvent::kThreadEnd: + fprintf(trace_file, "ThreadEnd: %.9f, %d\n", time, entry.thread_id); + break; + case TraceEvent::kBlockReserved: { + std::uint32_t block_id = entry.params[0]; + SidePair<int> block; + GetBlockByIndex(trace.block_map, block_id, &block); + SidePair<int> start, end; + GetBlockMatrixCoords(trace.block_map, block, &start, &end); + fprintf(trace_file, + "BlockReserved: %.9f, %d, %d, %d, %d, %d, %d, %d, %d\n", time, + entry.thread_id, block_id, block[Side::kLhs], block[Side::kRhs], + start[Side::kLhs], start[Side::kRhs], end[Side::kLhs], + end[Side::kRhs]); + break; } - } - fprintf(trace_file, "End: %.9f\n", ToSeconds(time_end - time_start)); - if (trace_filename) { - fclose(trace_file); + case TraceEvent::kBlockPackedLhs: { + std::uint32_t block = entry.params[0]; + int start, end; + GetBlockMatrixCoords(Side::kLhs, trace.block_map, block, &start, &end); + fprintf(trace_file, "BlockPackedLhs: %.9f, %d, %d, %d, %d\n", time, + entry.thread_id, block, start, end); + break; + } + case TraceEvent::kBlockPackedRhs: { + std::uint32_t block = entry.params[0]; + int start, end; + GetBlockMatrixCoords(Side::kRhs, trace.block_map, block, &start, &end); + fprintf(trace_file, "BlockPackedRhs: %.9f, %d, %d, %d, %d\n", time, + entry.thread_id, block, start, end); + break; + } + case TraceEvent::kBlockFinished: { + std::uint32_t block_id = entry.params[0]; + SidePair<int> block; + GetBlockByIndex(trace.block_map, block_id, &block); + fprintf(trace_file, "BlockFinished: %.9f, %d, %d, %d, %d\n", time, + entry.thread_id, block_id, block[Side::kLhs], + block[Side::kRhs]); + break; + } + default: + RUY_CHECK(false); } } -}; - -void DumpTrace(const Trace& trace) { - ProcessedTrace processed_trace; - processed_trace.Process(trace); - processed_trace.Dump(); + fprintf(trace_file, "End: %.9f\n", + ToSeconds(trace.time_end - trace.time_start)); + if (trace_filename) { + fclose(trace_file); + } } +} // anonymous namespace + +// Get a Trace object to record to, or null of tracing is not enabled. Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) { if (!tracing->initialized) { tracing->initialized = true; @@ -254,122 +207,114 @@ Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) { return tracing->trace; } +// The trace recorded on a context is finalized and dumped by +// this TracingContext destructor. +// +// The idea of dumping on context destructor is that typically one wants to +// run many matrix multiplications, e.g. to hit a steady state in terms of +// performance characteristics, but only trace the last repetition of the +// workload, when that steady state was attained. TracingContext::~TracingContext() { if (trace) { - DumpTrace(*trace); + AggregateThreadSpecificEntries(trace); + Sort(trace); + Dump(*trace); } delete trace; } -void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) { - if (trace) { - RUY_DCHECK(trace->life_stage == - Trace::LifeStage::kRecordingBlockAndThreadFields); - relaxed_atomic_store(&trace->block_entries[thread_id].thread_id, thread_id); - TimePoint now = Clock::now(); - relaxed_atomic_store(&trace->block_entries[thread_id].time_reserved, now); - relaxed_atomic_store(&trace->thread_entries[thread_id].time_start, now); - } -} - -void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) { +void TraceRecordStart(Trace* trace) { if (trace) { - RUY_DCHECK(trace->life_stage == - Trace::LifeStage::kRecordingBlockAndThreadFields); - TimePoint now = Clock::now(); - relaxed_atomic_store(&trace->thread_entries[thread_id].time_loop_start, - now); + trace->time_start = Clock::now(); } } -void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace) { +void TraceRecordExecute(const BlockMap& block_map, int thread_count, + Trace* trace) { if (trace) { - RUY_DCHECK(trace->life_stage == - Trace::LifeStage::kRecordingBlockAndThreadFields); - // This is typically called on the next block id just obtained by atomic - // increment; this may be out of range. - if (block_id < trace->block_entries.size()) { - relaxed_atomic_store(&trace->block_entries[block_id].thread_id, - thread_id); - TimePoint now = Clock::now(); - relaxed_atomic_store(&trace->block_entries[block_id].time_reserved, now); + trace->time_execute = Clock::now(); + trace->block_map = block_map; + trace->thread_count = thread_count; + trace->thread_specific_entries.resize(thread_count); + for (int thread = 0; thread < thread_count; thread++) { + trace->thread_specific_entries[thread].clear(); + // Reserve some large size to avoid frequent heap allocations + // affecting the recorded timings. + trace->thread_specific_entries[thread].reserve(16384); } } } -void TraceRecordBlockCoordsComputed(std::uint32_t block_id, Trace* trace) { - if (trace) { - RUY_DCHECK(trace->life_stage == - Trace::LifeStage::kRecordingBlockAndThreadFields); - TimePoint now = Clock::now(); - relaxed_atomic_store(&trace->block_entries[block_id].time_computed_coords, - now); - } -} - -void TraceRecordBlockPacked(Side side, std::uint32_t block_id, Trace* trace) { +void TraceRecordEnd(Trace* trace) { if (trace) { - RUY_DCHECK(trace->life_stage == - Trace::LifeStage::kRecordingBlockAndThreadFields); - TimePoint now = Clock::now(); - relaxed_atomic_store(&trace->block_entries[block_id].time_packed[side], - now); + trace->time_end = Clock::now(); } } -void TraceRecordBlockFinished(std::uint32_t block_id, Trace* trace) { +void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) { if (trace) { - RUY_DCHECK(trace->life_stage == - Trace::LifeStage::kRecordingBlockAndThreadFields); - TimePoint now = Clock::now(); - relaxed_atomic_store(&trace->block_entries[block_id].time_finished, now); + TraceEntry entry; + entry.event = TraceEvent::kThreadStart; + entry.time_point = Clock::now(); + entry.thread_id = thread_id; + trace->thread_specific_entries[thread_id].push_back(entry); } } -void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) { +void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) { if (trace) { - RUY_DCHECK(trace->life_stage == - Trace::LifeStage::kRecordingBlockAndThreadFields); - TimePoint now = Clock::now(); - relaxed_atomic_store(&trace->thread_entries[thread_id].time_end, now); + TraceEntry entry; + entry.event = TraceEvent::kThreadLoopStart; + entry.time_point = Clock::now(); + entry.thread_id = thread_id; + trace->thread_specific_entries[thread_id].push_back(entry); } } -void TraceRecordStart(Trace* trace) { +void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace) { if (trace) { - RUY_DCHECK(trace->life_stage == Trace::LifeStage::kInitial || - trace->life_stage == Trace::LifeStage::kComplete); - TimePoint now = Clock::now(); - relaxed_atomic_store(&trace->time_start, now); - trace->life_stage = Trace::LifeStage::kRecordingRootFields; + TraceEntry entry; + entry.event = TraceEvent::kBlockReserved; + entry.time_point = Clock::now(); + entry.thread_id = thread_id; + entry.params[0] = block_id; + trace->thread_specific_entries[thread_id].push_back(entry); } } -void TraceRecordExecute(Trace* trace) { +void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, + Trace* trace) { if (trace) { - RUY_DCHECK(trace->life_stage == Trace::LifeStage::kRecordingRootFields); - TimePoint now = Clock::now(); - relaxed_atomic_store(&trace->time_execute, now); + TraceEntry entry; + entry.event = side == Side::kLhs ? TraceEvent::kBlockPackedLhs + : TraceEvent::kBlockPackedRhs; + entry.time_point = Clock::now(); + entry.thread_id = thread_id; + entry.params[0] = block; + trace->thread_specific_entries[thread_id].push_back(entry); } } -void TraceRecordEnd(Trace* trace) { +void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace) { if (trace) { - RUY_DCHECK(trace->life_stage == - Trace::LifeStage::kRecordingBlockAndThreadFields); - TimePoint now = Clock::now(); - relaxed_atomic_store(&trace->time_end, now); - trace->life_stage = Trace::LifeStage::kComplete; + TraceEntry entry; + entry.event = TraceEvent::kBlockFinished; + entry.time_point = Clock::now(); + entry.thread_id = thread_id; + entry.params[0] = block_id; + trace->thread_specific_entries[thread_id].push_back(entry); } } -void TraceStartRecordingBlockAndThreadFields(const BlockMap& block_map, - int thread_count, Trace* trace) { +void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) { if (trace) { - RUY_DCHECK(trace->life_stage == Trace::LifeStage::kRecordingRootFields); - trace->StartRecordingBlockAndThreadFields(block_map, thread_count); - trace->life_stage = Trace::LifeStage::kRecordingBlockAndThreadFields; + TraceEntry entry; + entry.event = TraceEvent::kThreadEnd; + entry.time_point = Clock::now(); + entry.thread_id = thread_id; + trace->thread_specific_entries[thread_id].push_back(entry); } } @@ -22,7 +22,6 @@ limitations under the License. #include <vector> #include "block_map.h" -#include "side_pair.h" namespace ruy { @@ -40,22 +39,20 @@ struct TracingContext { ~TracingContext(); }; -void DumpTrace(const Trace& trace); - Trace* NewTraceOrNull(TracingContext* context, int rows, int depth, int cols); void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace); void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace); void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, Trace* trace); -void TraceRecordBlockCoordsComputed(std::uint32_t block_id, Trace* trace); -void TraceRecordBlockPacked(Side side, std::uint32_t block_id, Trace* trace); -void TraceRecordBlockFinished(std::uint32_t block_id, Trace* trace); +void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, + Trace* trace); +void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace); void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace); void TraceRecordStart(Trace* trace); -void TraceRecordExecute(Trace* trace); +void TraceRecordExecute(const BlockMap& block_map, int thread_count, + Trace* trace); void TraceRecordEnd(Trace* trace); -void TraceStartRecordingBlockAndThreadFields(const BlockMap& block_map, - int thread_count, Trace* trace); #else @@ -65,15 +62,12 @@ inline Trace* NewTraceOrNull(TracingContext*, int, int, int) { return nullptr; } inline void TraceRecordThreadStart(std::uint32_t, Trace*) {} inline void TraceRecordThreadLoopStart(std::uint32_t, Trace*) {} inline void TraceRecordBlockReserved(std::uint32_t, std::uint32_t, Trace*) {} -inline void TraceRecordBlockCoordsComputed(std::uint32_t, Trace*) {} -inline void TraceRecordBlockPacked(Side, std::uint32_t, Trace*) {} -inline void TraceRecordBlockFinished(std::uint32_t, Trace*) {} +inline void TraceRecordBlockPacked(std::uint32_t, Side, int, Trace*) {} +inline void TraceRecordBlockFinished(std::uint32_t, std::uint32_t, Trace*) {} inline void TraceRecordThreadEnd(std::uint32_t, Trace*) {} inline void TraceRecordStart(Trace*) {} -inline void TraceRecordExecute(Trace*) {} +inline void TraceRecordExecute(const BlockMap&, int, Trace*) {} inline void TraceRecordEnd(Trace*) {} -inline void TraceStartRecordingBlockAndThreadFields(const BlockMap&, int, - Trace*) {} #endif @@ -90,12 +90,11 @@ struct TrMulTask final : Task { GetBlockByIndex(block_map, block_id, &block); // Get coordinates of the current block to handle, in matrix space. GetBlockMatrixCoords(block_map, block, &start, &end); - TraceRecordBlockCoordsComputed(block_id, trace); // Maybe pack the current LHS/RHS block, if not already packed. - EnsurePacked(block_id, local_packed, block, start, end, tuning); + EnsurePacked(local_packed, block, start, end, tuning); // Actually do matrix multiplication work params->RunKernel(tuning, start, end); - TraceRecordBlockFinished(block_id, trace); + TraceRecordBlockFinished(thread_id, block_id, trace); // Move on to the next block as obtained by the atomic increment // at the start of this while loop iteration. block_id = next_block_id; @@ -107,8 +106,8 @@ struct TrMulTask final : Task { } private: - bool TryEnsurePacked(Side side, int block_id, bool* local_packed, int block, - int start, int end, Tuning tuning) { + bool TryEnsurePacked(Side side, bool* local_packed, int block, int start, + int end, Tuning tuning) { if (local_packed && !local_packed[block]) { PackingStatus not_started = PackingStatus::kNotStarted; std::atomic<PackingStatus>& status = packing_status[side][block]; @@ -119,7 +118,7 @@ struct TrMulTask final : Task { // changed it to kInProgress as we are about to handle the packing // ourselves. params->RunPack(side, tuning, start, end); - TraceRecordBlockPacked(side, block_id, trace); + TraceRecordBlockPacked(thread_id, side, block, trace); status.store(PackingStatus::kFinished, std::memory_order_release); } else if (status.load(std::memory_order_acquire) == PackingStatus::kInProgress) { @@ -133,15 +132,15 @@ struct TrMulTask final : Task { return true; } - void EnsurePacked(int block_id, const SidePair<bool*> local_packed, + void EnsurePacked(const SidePair<bool*> local_packed, const SidePair<int>& block, const SidePair<int>& start, const SidePair<int>& end, Tuning tuning) { while (true) { bool both_sides_packed = true; for (Side side : {Side::kLhs, Side::kRhs}) { both_sides_packed &= - TryEnsurePacked(side, block_id, local_packed[side], block[side], - start[side], end[side], tuning); + TryEnsurePacked(side, local_packed[side], block[side], start[side], + end[side], tuning); } if (both_sides_packed) { break; @@ -282,9 +281,7 @@ void TrMul(TrMulParams* params, Context* context) { } // Do the computation. - TraceRecordExecute(trace); - TraceStartRecordingBlockAndThreadFields(block_map, thread_count, trace); - + TraceRecordExecute(block_map, thread_count, trace); context->workers_pool.Execute(thread_count, tasks); // Finish up. @@ -292,9 +289,8 @@ void TrMul(TrMulParams* params, Context* context) { tasks[i].~TrMulTask(); } - TraceRecordEnd(trace); - allocator->FreeAll(); + TraceRecordEnd(trace); } } // namespace ruy |