aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--BUILD1
-rw-r--r--block_map.cc18
-rw-r--r--block_map.h4
-rw-r--r--trace.cc431
-rw-r--r--trace.h24
-rw-r--r--trmul.cc24
6 files changed, 218 insertions, 284 deletions
diff --git a/BUILD b/BUILD
index 5fd7936..33e8f14 100644
--- a/BUILD
+++ b/BUILD
@@ -195,7 +195,6 @@ cc_library(
deps = [
":block_map",
":check_macros",
- ":common",
":side_pair",
":time",
],
diff --git a/block_map.cc b/block_map.cc
index 17fe906..3b41ac8 100644
--- a/block_map.cc
+++ b/block_map.cc
@@ -202,15 +202,15 @@ void MakeBlockMap(int rows, int cols, int depth, int kernel_rows,
block_map->large_blocks[Side::kRhs] = missc;
}
-void GetBlockMatrixCoords(Side side, const BlockMap& block_map,
- const SidePair<int>& block, int* start, int* end) {
+void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
+ int* start, int* end) {
gemmlowp::ScopedProfilingLabel label("GetBlockMatrixCoords");
- const int b = block[side];
- *start =
- b * block_map.small_block_dims[side] +
- std::min(b, block_map.large_blocks[side]) * block_map.kernel_dims[side];
- *end = *start + block_map.small_block_dims[side] +
- (b < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0);
+ *start = block * block_map.small_block_dims[side] +
+ std::min(block, block_map.large_blocks[side]) *
+ block_map.kernel_dims[side];
+ *end =
+ *start + block_map.small_block_dims[side] +
+ (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0);
RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]);
RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]);
@@ -222,7 +222,7 @@ void GetBlockMatrixCoords(Side side, const BlockMap& block_map,
void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair<int>& block,
SidePair<int>* start, SidePair<int>* end) {
for (Side side : {Side::kLhs, Side::kRhs}) {
- GetBlockMatrixCoords(side, block_map, block, &(*start)[side],
+ GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side],
&(*end)[side]);
}
}
diff --git a/block_map.h b/block_map.h
index 415cac7..deb2b38 100644
--- a/block_map.h
+++ b/block_map.h
@@ -115,8 +115,8 @@ void GetBlockByIndex(const BlockMap& block_map, int index,
// position in the matrix that the BlockMap refers to in the dimension
// referred to by `side`: along rows if side==kLhs, along columns if
// side==kRhs.
-void GetBlockMatrixCoords(Side side, const BlockMap& block_map,
- const SidePair<int>& block, int* start, int* end);
+void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block,
+ int* start, int* end);
// Given a block position in the grid, returns its actual
// position in the matrix that the BlockMap refers to in terms of
diff --git a/trace.cc b/trace.cc
index 6704303..8e28b47 100644
--- a/trace.cc
+++ b/trace.cc
@@ -24,200 +24,153 @@ limitations under the License.
#include "block_map.h"
#include "check_macros.h"
-#include "common.h"
+#include "side_pair.h"
#include "time.h"
namespace ruy {
#ifdef RUY_TRACE
-struct BlockTraceEntry {
- std::uint32_t thread_id = 0;
- TimePoint time_reserved;
- TimePoint time_computed_coords;
- SidePair<TimePoint> time_packed;
- TimePoint time_finished;
+enum class TraceEvent : std::uint8_t {
+ kNone,
+ kThreadStart,
+ kThreadLoopStart,
+ kThreadEnd,
+ kBlockReserved,
+ kBlockPackedLhs,
+ kBlockPackedRhs,
+ kBlockFinished
};
-struct ThreadTraceEntry {
- TimePoint time_start;
- TimePoint time_loop_start;
- TimePoint time_end;
+struct TraceEntry {
+ TimePoint time_point;
+ TraceEvent event;
+ // ruy-internal thread id i.e. contiguous index into array of threads,
+ // with 0 designating the main thread.
+ std::uint16_t thread_id = 0;
+ // Additional parameters whose meaning depends on the 'event' type.
+ std::uint32_t params[1];
};
struct Trace {
- enum class LifeStage {
- kInitial,
- kRecordingRootFields,
- kRecordingBlockAndThreadFields,
- kComplete
- };
- void StartRecordingBlockAndThreadFields(const BlockMap& block_map_,
- int thread_count_) {
- RUY_DCHECK(life_stage == LifeStage::kRecordingRootFields);
- block_map = block_map_;
- thread_count = thread_count_;
- int num_blocks = NumBlocks(block_map);
- if (num_blocks > block_entries.size()) {
- block_entries.resize(NumBlocks(block_map));
- }
- if (thread_count > thread_entries.size()) {
- thread_entries.resize(thread_count);
- }
- life_stage = LifeStage::kRecordingBlockAndThreadFields;
- }
BlockMap block_map;
int thread_count = 0;
- std::vector<BlockTraceEntry> block_entries;
- std::vector<ThreadTraceEntry> thread_entries;
+ // During recording, to avoid having to use locks or atomics, we let
+ // each thread append to its own specific vector.
+ std::vector<std::vector<TraceEntry>> thread_specific_entries;
+ // Global vector of entries into which we coalesce thread_specific_entries
+ // after recording is finished, when dumping a trace. See
+ // AggregateThreadSpecificEntries.
+ std::vector<TraceEntry> entries;
TimePoint time_start;
TimePoint time_execute;
TimePoint time_end;
- LifeStage life_stage = LifeStage::kInitial;
};
-struct ProcessedTrace {
- enum class Event : std::uint8_t {
- kNone,
- kThreadStart,
- kThreadLoopStart,
- kThreadEnd,
- kBlockReserved,
- kBlockComputedCoords,
- kBlockPackedLhs,
- kBlockPackedRhs,
- kBlockFinished
- };
- struct Entry {
- Event event = Event::kNone;
- std::uint32_t thread_id = 0;
- std::uint32_t block_id = 0;
- TimePoint time;
- };
+namespace {
- BlockMap block_map;
- int thread_count = 0;
- TimePoint time_start;
- TimePoint time_execute;
- TimePoint time_end;
- std::vector<Entry> entries;
- void Add(Event event, std::uint32_t thread_id, std::uint32_t block_id,
- TimePoint time) {
- // If the time point is still in its default-constructed state,
- // that means we didn't record it.
- if (!time.time_since_epoch().count()) {
- return;
+// Coalesce Trace::thread_specific_entries into Trace::entries.
+void AggregateThreadSpecificEntries(Trace* trace) {
+ RUY_CHECK(trace->entries.empty());
+ for (auto& thread_specific_entries_vector : trace->thread_specific_entries) {
+ for (const TraceEntry& entry : thread_specific_entries_vector) {
+ trace->entries.push_back(entry);
}
- Entry entry;
- entry.event = event;
- entry.thread_id = thread_id;
- entry.block_id = block_id;
- entry.time = time;
- entries.push_back(entry);
+ thread_specific_entries_vector.clear();
}
- void Process(const Trace& trace) {
- thread_count = trace.thread_count;
- block_map = trace.block_map;
- time_start = trace.time_start;
- time_execute = trace.time_execute;
- time_end = trace.time_end;
- entries.clear();
- for (int i = 0; i < trace.thread_count; i++) {
- const auto& entry = trace.thread_entries[i];
- Add(Event::kThreadStart, i, 0, entry.time_start);
- Add(Event::kThreadLoopStart, i, 0, entry.time_loop_start);
- Add(Event::kThreadEnd, i, 0, entry.time_end);
- }
- std::uint32_t num_blocks = NumBlocks(block_map);
- for (int i = 0; i < num_blocks; i++) {
- const auto& entry = trace.block_entries[i];
- Add(Event::kBlockReserved, entry.thread_id, i, entry.time_reserved);
- Add(Event::kBlockComputedCoords, entry.thread_id, i,
- entry.time_computed_coords);
- Add(Event::kBlockPackedLhs, entry.thread_id, i,
- entry.time_packed[Side::kLhs]);
- Add(Event::kBlockPackedRhs, entry.thread_id, i,
- entry.time_packed[Side::kRhs]);
- Add(Event::kBlockFinished, entry.thread_id, i, entry.time_finished);
- }
- std::sort(entries.begin(), entries.end(),
- [](const Entry& a, const Entry& b) -> bool {
- return a.time < b.time ||
- (a.time == b.time &&
- static_cast<int>(a.event) < static_cast<int>(b.event));
- });
+}
+
+// Sort Trace::entries by ascending time. In case of equal timepoints,
+// sort by some semi-arbitrary ordering of event types.
+void Sort(Trace* trace) {
+ std::sort(std::begin(trace->entries), std::end(trace->entries),
+ [](const TraceEntry& a, const TraceEntry& b) -> bool {
+ return a.time_point < b.time_point ||
+ (a.time_point == b.time_point &&
+ static_cast<int>(a.event) < static_cast<int>(b.event));
+ });
+}
+
+// Dump a trace. Assumes that AggregateThreadSpecificEntries and Sort have
+// already been called on it.
+void Dump(const Trace& trace) {
+ const char* trace_filename = getenv("RUY_TRACE_FILE");
+ FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr;
+ if (!trace_file) {
+ fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename,
+ errno);
+ RUY_CHECK(false);
}
- void Dump() {
- const char* trace_filename = getenv("RUY_TRACE_FILE");
- FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr;
- if (!trace_file) {
- fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename,
- errno);
- RUY_CHECK(false);
- }
- fprintf(trace_file, "thread_count:%d\n", thread_count);
- fprintf(trace_file, "num_blocks:%d\n", NumBlocks(block_map));
- fprintf(trace_file, "rows:%d\n", block_map.rows);
- fprintf(trace_file, "cols:%d\n", block_map.cols);
- fprintf(trace_file, "Execute: %.9f\n",
- ToSeconds(time_execute - time_start));
- for (const Entry& entry : entries) {
- double time = ToSeconds(entry.time - time_start);
- switch (entry.event) {
- case Event::kThreadStart:
- fprintf(trace_file, "ThreadStart: %.9f, %d\n", time, entry.thread_id);
- break;
- case Event::kThreadLoopStart:
- fprintf(trace_file, "ThreadLoopStart: %.9f, %d\n", time,
- entry.thread_id);
- break;
- case Event::kThreadEnd:
- fprintf(trace_file, "ThreadEnd: %.9f, %d\n", time, entry.thread_id);
- break;
- case Event::kBlockReserved: {
- std::uint16_t block_r, block_c;
- int start_r, start_c, end_r, end_c;
- GetBlockByIndex(block_map, entry.block_id, &block_r, &block_c);
- GetBlockMatrixCoords(block_map, block_r, block_c, &start_r, &start_c,
- &end_r, &end_c);
- fprintf(trace_file, "BlockReserved: %.9f, %d, %d, %d, %d, %d, %d\n",
- time, entry.thread_id, entry.block_id, start_r, start_c,
- end_r, end_c);
- break;
- }
- case Event::kBlockComputedCoords:
- fprintf(trace_file, "BlockComputedCoords: %.9f, %d, %d\n", time,
- entry.thread_id, entry.block_id);
- break;
- case Event::kBlockPackedLhs:
- fprintf(trace_file, "BlockPackedLhs: %.9f, %d, %d\n", time,
- entry.thread_id, entry.block_id);
- break;
- case Event::kBlockPackedRhs:
- fprintf(trace_file, "BlockPackedRhs: %.9f, %d, %d\n", time,
- entry.thread_id, entry.block_id);
- break;
- case Event::kBlockFinished:
- fprintf(trace_file, "BlockFinished: %.9f, %d, %d\n", time,
- entry.thread_id, entry.block_id);
- break;
- default:
- RUY_CHECK(false);
+ fprintf(trace_file, "thread_count:%d\n", trace.thread_count);
+ fprintf(trace_file, "rows:%d\n", trace.block_map.dims[Side::kLhs]);
+ fprintf(trace_file, "cols:%d\n", trace.block_map.dims[Side::kRhs]);
+ fprintf(trace_file, "Execute: %.9f\n",
+ ToSeconds(trace.time_execute - trace.time_start));
+ for (const TraceEntry& entry : trace.entries) {
+ double time = ToSeconds(entry.time_point - trace.time_start);
+ switch (entry.event) {
+ case TraceEvent::kThreadStart:
+ fprintf(trace_file, "ThreadStart: %.9f, %d\n", time, entry.thread_id);
+ break;
+ case TraceEvent::kThreadLoopStart:
+ fprintf(trace_file, "ThreadLoopStart: %.9f, %d\n", time,
+ entry.thread_id);
+ break;
+ case TraceEvent::kThreadEnd:
+ fprintf(trace_file, "ThreadEnd: %.9f, %d\n", time, entry.thread_id);
+ break;
+ case TraceEvent::kBlockReserved: {
+ std::uint32_t block_id = entry.params[0];
+ SidePair<int> block;
+ GetBlockByIndex(trace.block_map, block_id, &block);
+ SidePair<int> start, end;
+ GetBlockMatrixCoords(trace.block_map, block, &start, &end);
+ fprintf(trace_file,
+ "BlockReserved: %.9f, %d, %d, %d, %d, %d, %d, %d, %d\n", time,
+ entry.thread_id, block_id, block[Side::kLhs], block[Side::kRhs],
+ start[Side::kLhs], start[Side::kRhs], end[Side::kLhs],
+ end[Side::kRhs]);
+ break;
}
- }
- fprintf(trace_file, "End: %.9f\n", ToSeconds(time_end - time_start));
- if (trace_filename) {
- fclose(trace_file);
+ case TraceEvent::kBlockPackedLhs: {
+ std::uint32_t block = entry.params[0];
+ int start, end;
+ GetBlockMatrixCoords(Side::kLhs, trace.block_map, block, &start, &end);
+ fprintf(trace_file, "BlockPackedLhs: %.9f, %d, %d, %d, %d\n", time,
+ entry.thread_id, block, start, end);
+ break;
+ }
+ case TraceEvent::kBlockPackedRhs: {
+ std::uint32_t block = entry.params[0];
+ int start, end;
+ GetBlockMatrixCoords(Side::kRhs, trace.block_map, block, &start, &end);
+ fprintf(trace_file, "BlockPackedRhs: %.9f, %d, %d, %d, %d\n", time,
+ entry.thread_id, block, start, end);
+ break;
+ }
+ case TraceEvent::kBlockFinished: {
+ std::uint32_t block_id = entry.params[0];
+ SidePair<int> block;
+ GetBlockByIndex(trace.block_map, block_id, &block);
+ fprintf(trace_file, "BlockFinished: %.9f, %d, %d, %d, %d\n", time,
+ entry.thread_id, block_id, block[Side::kLhs],
+ block[Side::kRhs]);
+ break;
+ }
+ default:
+ RUY_CHECK(false);
}
}
-};
-
-void DumpTrace(const Trace& trace) {
- ProcessedTrace processed_trace;
- processed_trace.Process(trace);
- processed_trace.Dump();
+ fprintf(trace_file, "End: %.9f\n",
+ ToSeconds(trace.time_end - trace.time_start));
+ if (trace_filename) {
+ fclose(trace_file);
+ }
}
+} // anonymous namespace
+
+// Get a Trace object to record to, or null if tracing is not enabled.
Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) {
if (!tracing->initialized) {
tracing->initialized = true;
@@ -254,122 +207,114 @@ Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) {
return tracing->trace;
}
+// The trace recorded on a context is finalized and dumped by
+// this TracingContext destructor.
+//
+// The idea of dumping on context destructor is that typically one wants to
+// run many matrix multiplications, e.g. to hit a steady state in terms of
+// performance characteristics, but only trace the last repetition of the
+// workload, when that steady state was attained.
TracingContext::~TracingContext() {
if (trace) {
- DumpTrace(*trace);
+ AggregateThreadSpecificEntries(trace);
+ Sort(trace);
+ Dump(*trace);
}
delete trace;
}
-void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) {
- if (trace) {
- RUY_DCHECK(trace->life_stage ==
- Trace::LifeStage::kRecordingBlockAndThreadFields);
- relaxed_atomic_store(&trace->block_entries[thread_id].thread_id, thread_id);
- TimePoint now = Clock::now();
- relaxed_atomic_store(&trace->block_entries[thread_id].time_reserved, now);
- relaxed_atomic_store(&trace->thread_entries[thread_id].time_start, now);
- }
-}
-
-void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) {
+void TraceRecordStart(Trace* trace) {
if (trace) {
- RUY_DCHECK(trace->life_stage ==
- Trace::LifeStage::kRecordingBlockAndThreadFields);
- TimePoint now = Clock::now();
- relaxed_atomic_store(&trace->thread_entries[thread_id].time_loop_start,
- now);
+ trace->time_start = Clock::now();
}
}
-void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id,
- Trace* trace) {
+void TraceRecordExecute(const BlockMap& block_map, int thread_count,
+ Trace* trace) {
if (trace) {
- RUY_DCHECK(trace->life_stage ==
- Trace::LifeStage::kRecordingBlockAndThreadFields);
- // This is typically called on the next block id just obtained by atomic
- // increment; this may be out of range.
- if (block_id < trace->block_entries.size()) {
- relaxed_atomic_store(&trace->block_entries[block_id].thread_id,
- thread_id);
- TimePoint now = Clock::now();
- relaxed_atomic_store(&trace->block_entries[block_id].time_reserved, now);
+ trace->time_execute = Clock::now();
+ trace->block_map = block_map;
+ trace->thread_count = thread_count;
+ trace->thread_specific_entries.resize(thread_count);
+ for (int thread = 0; thread < thread_count; thread++) {
+ trace->thread_specific_entries[thread].clear();
+ // Reserve some large size to avoid frequent heap allocations
+ // affecting the recorded timings.
+ trace->thread_specific_entries[thread].reserve(16384);
}
}
}
-void TraceRecordBlockCoordsComputed(std::uint32_t block_id, Trace* trace) {
- if (trace) {
- RUY_DCHECK(trace->life_stage ==
- Trace::LifeStage::kRecordingBlockAndThreadFields);
- TimePoint now = Clock::now();
- relaxed_atomic_store(&trace->block_entries[block_id].time_computed_coords,
- now);
- }
-}
-
-void TraceRecordBlockPacked(Side side, std::uint32_t block_id, Trace* trace) {
+void TraceRecordEnd(Trace* trace) {
if (trace) {
- RUY_DCHECK(trace->life_stage ==
- Trace::LifeStage::kRecordingBlockAndThreadFields);
- TimePoint now = Clock::now();
- relaxed_atomic_store(&trace->block_entries[block_id].time_packed[side],
- now);
+ trace->time_end = Clock::now();
}
}
-void TraceRecordBlockFinished(std::uint32_t block_id, Trace* trace) {
+void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) {
if (trace) {
- RUY_DCHECK(trace->life_stage ==
- Trace::LifeStage::kRecordingBlockAndThreadFields);
- TimePoint now = Clock::now();
- relaxed_atomic_store(&trace->block_entries[block_id].time_finished, now);
+ TraceEntry entry;
+ entry.event = TraceEvent::kThreadStart;
+ entry.time_point = Clock::now();
+ entry.thread_id = thread_id;
+ trace->thread_specific_entries[thread_id].push_back(entry);
}
}
-void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) {
+void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) {
if (trace) {
- RUY_DCHECK(trace->life_stage ==
- Trace::LifeStage::kRecordingBlockAndThreadFields);
- TimePoint now = Clock::now();
- relaxed_atomic_store(&trace->thread_entries[thread_id].time_end, now);
+ TraceEntry entry;
+ entry.event = TraceEvent::kThreadLoopStart;
+ entry.time_point = Clock::now();
+ entry.thread_id = thread_id;
+ trace->thread_specific_entries[thread_id].push_back(entry);
}
}
-void TraceRecordStart(Trace* trace) {
+void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id,
+ Trace* trace) {
if (trace) {
- RUY_DCHECK(trace->life_stage == Trace::LifeStage::kInitial ||
- trace->life_stage == Trace::LifeStage::kComplete);
- TimePoint now = Clock::now();
- relaxed_atomic_store(&trace->time_start, now);
- trace->life_stage = Trace::LifeStage::kRecordingRootFields;
+ TraceEntry entry;
+ entry.event = TraceEvent::kBlockReserved;
+ entry.time_point = Clock::now();
+ entry.thread_id = thread_id;
+ entry.params[0] = block_id;
+ trace->thread_specific_entries[thread_id].push_back(entry);
}
}
-void TraceRecordExecute(Trace* trace) {
+void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block,
+ Trace* trace) {
if (trace) {
- RUY_DCHECK(trace->life_stage == Trace::LifeStage::kRecordingRootFields);
- TimePoint now = Clock::now();
- relaxed_atomic_store(&trace->time_execute, now);
+ TraceEntry entry;
+ entry.event = side == Side::kLhs ? TraceEvent::kBlockPackedLhs
+ : TraceEvent::kBlockPackedRhs;
+ entry.time_point = Clock::now();
+ entry.thread_id = thread_id;
+ entry.params[0] = block;
+ trace->thread_specific_entries[thread_id].push_back(entry);
}
}
-void TraceRecordEnd(Trace* trace) {
+void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id,
+ Trace* trace) {
if (trace) {
- RUY_DCHECK(trace->life_stage ==
- Trace::LifeStage::kRecordingBlockAndThreadFields);
- TimePoint now = Clock::now();
- relaxed_atomic_store(&trace->time_end, now);
- trace->life_stage = Trace::LifeStage::kComplete;
+ TraceEntry entry;
+ entry.event = TraceEvent::kBlockFinished;
+ entry.time_point = Clock::now();
+ entry.thread_id = thread_id;
+ entry.params[0] = block_id;
+ trace->thread_specific_entries[thread_id].push_back(entry);
}
}
-void TraceStartRecordingBlockAndThreadFields(const BlockMap& block_map,
- int thread_count, Trace* trace) {
+void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) {
if (trace) {
- RUY_DCHECK(trace->life_stage == Trace::LifeStage::kRecordingRootFields);
- trace->StartRecordingBlockAndThreadFields(block_map, thread_count);
- trace->life_stage = Trace::LifeStage::kRecordingBlockAndThreadFields;
+ TraceEntry entry;
+ entry.event = TraceEvent::kThreadEnd;
+ entry.time_point = Clock::now();
+ entry.thread_id = thread_id;
+ trace->thread_specific_entries[thread_id].push_back(entry);
}
}
diff --git a/trace.h b/trace.h
index e8664de..79381ef 100644
--- a/trace.h
+++ b/trace.h
@@ -22,7 +22,6 @@ limitations under the License.
#include <vector>
#include "block_map.h"
-#include "side_pair.h"
namespace ruy {
@@ -40,22 +39,20 @@ struct TracingContext {
~TracingContext();
};
-void DumpTrace(const Trace& trace);
-
Trace* NewTraceOrNull(TracingContext* context, int rows, int depth, int cols);
void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace);
void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace);
void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id,
Trace* trace);
-void TraceRecordBlockCoordsComputed(std::uint32_t block_id, Trace* trace);
-void TraceRecordBlockPacked(Side side, std::uint32_t block_id, Trace* trace);
-void TraceRecordBlockFinished(std::uint32_t block_id, Trace* trace);
+void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block,
+ Trace* trace);
+void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id,
+ Trace* trace);
void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace);
void TraceRecordStart(Trace* trace);
-void TraceRecordExecute(Trace* trace);
+void TraceRecordExecute(const BlockMap& block_map, int thread_count,
+ Trace* trace);
void TraceRecordEnd(Trace* trace);
-void TraceStartRecordingBlockAndThreadFields(const BlockMap& block_map,
- int thread_count, Trace* trace);
#else
@@ -65,15 +62,12 @@ inline Trace* NewTraceOrNull(TracingContext*, int, int, int) { return nullptr; }
inline void TraceRecordThreadStart(std::uint32_t, Trace*) {}
inline void TraceRecordThreadLoopStart(std::uint32_t, Trace*) {}
inline void TraceRecordBlockReserved(std::uint32_t, std::uint32_t, Trace*) {}
-inline void TraceRecordBlockCoordsComputed(std::uint32_t, Trace*) {}
-inline void TraceRecordBlockPacked(Side, std::uint32_t, Trace*) {}
-inline void TraceRecordBlockFinished(std::uint32_t, Trace*) {}
+inline void TraceRecordBlockPacked(std::uint32_t, Side, int, Trace*) {}
+inline void TraceRecordBlockFinished(std::uint32_t, std::uint32_t, Trace*) {}
inline void TraceRecordThreadEnd(std::uint32_t, Trace*) {}
inline void TraceRecordStart(Trace*) {}
-inline void TraceRecordExecute(Trace*) {}
+inline void TraceRecordExecute(const BlockMap&, int, Trace*) {}
inline void TraceRecordEnd(Trace*) {}
-inline void TraceStartRecordingBlockAndThreadFields(const BlockMap&, int,
- Trace*) {}
#endif
diff --git a/trmul.cc b/trmul.cc
index 535ad45..a41372a 100644
--- a/trmul.cc
+++ b/trmul.cc
@@ -90,12 +90,11 @@ struct TrMulTask final : Task {
GetBlockByIndex(block_map, block_id, &block);
// Get coordinates of the current block to handle, in matrix space.
GetBlockMatrixCoords(block_map, block, &start, &end);
- TraceRecordBlockCoordsComputed(block_id, trace);
// Maybe pack the current LHS/RHS block, if not already packed.
- EnsurePacked(block_id, local_packed, block, start, end, tuning);
+ EnsurePacked(local_packed, block, start, end, tuning);
// Actually do matrix multiplication work
params->RunKernel(tuning, start, end);
- TraceRecordBlockFinished(block_id, trace);
+ TraceRecordBlockFinished(thread_id, block_id, trace);
// Move on to the next block as obtained by the atomic increment
// at the start of this while loop iteration.
block_id = next_block_id;
@@ -107,8 +106,8 @@ struct TrMulTask final : Task {
}
private:
- bool TryEnsurePacked(Side side, int block_id, bool* local_packed, int block,
- int start, int end, Tuning tuning) {
+ bool TryEnsurePacked(Side side, bool* local_packed, int block, int start,
+ int end, Tuning tuning) {
if (local_packed && !local_packed[block]) {
PackingStatus not_started = PackingStatus::kNotStarted;
std::atomic<PackingStatus>& status = packing_status[side][block];
@@ -119,7 +118,7 @@ struct TrMulTask final : Task {
// changed it to kInProgress as we are about to handle the packing
// ourselves.
params->RunPack(side, tuning, start, end);
- TraceRecordBlockPacked(side, block_id, trace);
+ TraceRecordBlockPacked(thread_id, side, block, trace);
status.store(PackingStatus::kFinished, std::memory_order_release);
} else if (status.load(std::memory_order_acquire) ==
PackingStatus::kInProgress) {
@@ -133,15 +132,15 @@ struct TrMulTask final : Task {
return true;
}
- void EnsurePacked(int block_id, const SidePair<bool*> local_packed,
+ void EnsurePacked(const SidePair<bool*> local_packed,
const SidePair<int>& block, const SidePair<int>& start,
const SidePair<int>& end, Tuning tuning) {
while (true) {
bool both_sides_packed = true;
for (Side side : {Side::kLhs, Side::kRhs}) {
both_sides_packed &=
- TryEnsurePacked(side, block_id, local_packed[side], block[side],
- start[side], end[side], tuning);
+ TryEnsurePacked(side, local_packed[side], block[side], start[side],
+ end[side], tuning);
}
if (both_sides_packed) {
break;
@@ -282,9 +281,7 @@ void TrMul(TrMulParams* params, Context* context) {
}
// Do the computation.
- TraceRecordExecute(trace);
- TraceStartRecordingBlockAndThreadFields(block_map, thread_count, trace);
-
+ TraceRecordExecute(block_map, thread_count, trace);
context->workers_pool.Execute(thread_count, tasks);
// Finish up.
@@ -292,9 +289,8 @@ void TrMul(TrMulParams* params, Context* context) {
tasks[i].~TrMulTask();
}
- TraceRecordEnd(trace);
-
allocator->FreeAll();
+ TraceRecordEnd(trace);
}
} // namespace ruy