aboutsummaryrefslogtreecommitdiff
path: root/src/buffer.cc
diff options
context:
space:
mode:
authordan sinclair <dsinclair@chromium.org>2019-08-03 13:51:11 -0400
committerGitHub <noreply@github.com>2019-08-03 13:51:11 -0400
commite9e2a52ea8282b482c8e4026ed23e986c1fb6177 (patch)
treef66927b5ef85fdec9e6f4ea3553562883bc557f1 /src/buffer.cc
parent5bb9c644ec17a298d901152e9a8e10f8d08c2ee3 (diff)
downloadamber-e9e2a52ea8282b482c8e4026ed23e986c1fb6177.tar.gz
Add root mean square error buffer comparison (#602)
This CL adds a `RMSE_BUFFER` comparator which can be used with two buffers to verify their values are within a given tolerance using the root mean square error comparison method. Issue #600
Diffstat (limited to 'src/buffer.cc')
-rw-r--r--src/buffer.cc92
1 files changed, 92 insertions, 0 deletions
diff --git a/src/buffer.cc b/src/buffer.cc
index 617ff7a..f160923 100644
--- a/src/buffer.cc
+++ b/src/buffer.cc
@@ -15,6 +15,7 @@
#include "src/buffer.h"
#include <cassert>
+#include <cmath>
#include <cstring>
namespace amber {
@@ -53,6 +54,45 @@ T* ValuesAs(uint8_t* values) {
return reinterpret_cast<T*>(values);
}
+template <typename T>
+double Sub(const uint8_t* buf1, const uint8_t* buf2) {
+ return static_cast<double>(*reinterpret_cast<const T*>(buf1) -
+ *reinterpret_cast<const T*>(buf2));
+}
+
+double CalculateDiff(const Format::Component& comp,
+ const uint8_t* buf1,
+ const uint8_t* buf2) {
+ if (comp.IsInt8())
+ return Sub<int8_t>(buf1, buf2);
+ if (comp.IsInt16())
+ return Sub<int16_t>(buf1, buf2);
+ if (comp.IsInt32())
+ return Sub<int32_t>(buf1, buf2);
+ if (comp.IsInt64())
+ return Sub<int64_t>(buf1, buf2);
+ if (comp.IsUint8())
+ return Sub<uint8_t>(buf1, buf2);
+ if (comp.IsUint16())
+ return Sub<uint16_t>(buf1, buf2);
+ if (comp.IsUint32())
+ return Sub<uint32_t>(buf1, buf2);
+ if (comp.IsUint64())
+ return Sub<uint64_t>(buf1, buf2);
+ // TOOD(dsinclair): Handle float16 ...
+ if (comp.IsFloat16()) {
+ assert(false && "Float16 suppport not implemented");
+ return 0.0;
+ }
+ if (comp.IsFloat())
+ return Sub<float>(buf1, buf2);
+ if (comp.IsDouble())
+ return Sub<double>(buf1, buf2);
+
+ assert(false && "NOTREACHED");
+ return 0.0;
+}
+
} // namespace
Buffer::Buffer() = default;
@@ -111,6 +151,58 @@ Result Buffer::IsEqual(Buffer* buffer) const {
return {};
}
+std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const {
+ std::vector<double> diffs;
+
+ auto* buf_1_ptr = GetValues<uint8_t>();
+ auto* buf_2_ptr = buffer->GetValues<uint8_t>();
+ auto comps = format_->GetComponents();
+
+ for (size_t i = 0; i < ElementCount(); ++i) {
+ for (size_t j = 0; j < format_->ColumnCount(); ++j) {
+ auto* buf_1_row_ptr = buf_1_ptr;
+ auto* buf_2_row_ptr = buf_2_ptr;
+ for (size_t k = 0; k < format_->RowCount(); ++k) {
+ diffs.push_back(CalculateDiff(comps[k], buf_1_row_ptr, buf_2_row_ptr));
+
+ buf_1_row_ptr += comps[k].SizeInBytes();
+ buf_2_row_ptr += comps[k].SizeInBytes();
+ }
+ buf_1_ptr += format_->SizeInBytesPerRow();
+ buf_2_ptr += format_->SizeInBytesPerRow();
+ }
+ }
+
+ return diffs;
+}
+
+Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
+ if (!buffer->format_->Equal(format_.get()))
+ return Result{"Buffers have a different format"};
+ if (buffer->element_count_ != element_count_)
+ return Result{"Buffers have a different size"};
+ if (buffer->width_ != width_)
+ return Result{"Buffers have a different width"};
+ if (buffer->height_ != height_)
+ return Result{"Buffers have a different height"};
+ if (buffer->ValueCount() != ValueCount())
+ return Result{"Buffers have a different number of values"};
+
+ auto diffs = CalculateDiffs(buffer);
+ double sum = 0.0;
+ for (const auto val : diffs)
+ sum += (val * val);
+
+ sum /= diffs.size();
+ double rmse = std::sqrt(sum);
+ if (rmse > static_cast<double>(tolerance)) {
+ return Result("Root Mean Square Error of " + std::to_string(rmse) +
+ " is greater then tolerance of " + std::to_string(tolerance));
+ }
+
+ return {};
+}
+
Result Buffer::SetData(const std::vector<Value>& data) {
return SetDataWithOffset(data, 0);
}