diff options
author | dan sinclair <dsinclair@chromium.org> | 2019-08-03 13:51:11 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-08-03 13:51:11 -0400 |
commit | e9e2a52ea8282b482c8e4026ed23e986c1fb6177 (patch) | |
tree | f66927b5ef85fdec9e6f4ea3553562883bc557f1 /src/buffer.cc | |
parent | 5bb9c644ec17a298d901152e9a8e10f8d08c2ee3 (diff) | |
download | amber-e9e2a52ea8282b482c8e4026ed23e986c1fb6177.tar.gz |
Add root mean square error buffer comparison (#602)
This CL adds a `RMSE_BUFFER` comparator which can be used with two
buffers to verify their values are within a given tolerance using the
root mean square error comparison method.
Issue #600
Diffstat (limited to 'src/buffer.cc')
-rw-r--r-- | src/buffer.cc | 92 |
1 files changed, 92 insertions, 0 deletions
diff --git a/src/buffer.cc b/src/buffer.cc index 617ff7a..f160923 100644 --- a/src/buffer.cc +++ b/src/buffer.cc @@ -15,6 +15,7 @@ #include "src/buffer.h" #include <cassert> +#include <cmath> #include <cstring> namespace amber { @@ -53,6 +54,45 @@ T* ValuesAs(uint8_t* values) { return reinterpret_cast<T*>(values); } +template <typename T> +double Sub(const uint8_t* buf1, const uint8_t* buf2) { + return static_cast<double>(*reinterpret_cast<const T*>(buf1) - + *reinterpret_cast<const T*>(buf2)); +} + +double CalculateDiff(const Format::Component& comp, + const uint8_t* buf1, + const uint8_t* buf2) { + if (comp.IsInt8()) + return Sub<int8_t>(buf1, buf2); + if (comp.IsInt16()) + return Sub<int16_t>(buf1, buf2); + if (comp.IsInt32()) + return Sub<int32_t>(buf1, buf2); + if (comp.IsInt64()) + return Sub<int64_t>(buf1, buf2); + if (comp.IsUint8()) + return Sub<uint8_t>(buf1, buf2); + if (comp.IsUint16()) + return Sub<uint16_t>(buf1, buf2); + if (comp.IsUint32()) + return Sub<uint32_t>(buf1, buf2); + if (comp.IsUint64()) + return Sub<uint64_t>(buf1, buf2); + // TOOD(dsinclair): Handle float16 ... + if (comp.IsFloat16()) { + assert(false && "Float16 suppport not implemented"); + return 0.0; + } + if (comp.IsFloat()) + return Sub<float>(buf1, buf2); + if (comp.IsDouble()) + return Sub<double>(buf1, buf2); + + assert(false && "NOTREACHED"); + return 0.0; +} + } // namespace Buffer::Buffer() = default; @@ -111,6 +151,58 @@ Result Buffer::IsEqual(Buffer* buffer) const { return {}; } +std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const { + std::vector<double> diffs; + + auto* buf_1_ptr = GetValues<uint8_t>(); + auto* buf_2_ptr = buffer->GetValues<uint8_t>(); + auto comps = format_->GetComponents(); + + for (size_t i = 0; i < ElementCount(); ++i) { + for (size_t j = 0; j < format_->ColumnCount(); ++j) { + auto* buf_1_row_ptr = buf_1_ptr; + auto* buf_2_row_ptr = buf_2_ptr; + for (size_t k = 0; k < format_->RowCount(); ++k) { + diffs.push_back(CalculateDiff(comps[k], buf_1_row_ptr, buf_2_row_ptr)); + + buf_1_row_ptr += comps[k].SizeInBytes(); + buf_2_row_ptr += comps[k].SizeInBytes(); + } + buf_1_ptr += format_->SizeInBytesPerRow(); + buf_2_ptr += format_->SizeInBytesPerRow(); + } + } + + return diffs; +} + +Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const { + if (!buffer->format_->Equal(format_.get())) + return Result{"Buffers have a different format"}; + if (buffer->element_count_ != element_count_) + return Result{"Buffers have a different size"}; + if (buffer->width_ != width_) + return Result{"Buffers have a different width"}; + if (buffer->height_ != height_) + return Result{"Buffers have a different height"}; + if (buffer->ValueCount() != ValueCount()) + return Result{"Buffers have a different number of values"}; + + auto diffs = CalculateDiffs(buffer); + double sum = 0.0; + for (const auto val : diffs) + sum += (val * val); + + sum /= diffs.size(); + double rmse = std::sqrt(sum); + if (rmse > static_cast<double>(tolerance)) { + return Result("Root Mean Square Error of " + std::to_string(rmse) + + " is greater then tolerance of " + std::to_string(tolerance)); + } + + return {}; +} + Result Buffer::SetData(const std::vector<Value>& data) { return SetDataWithOffset(data, 0); } |