aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Butler <butlermichael@google.com>2022-11-14 13:16:07 -0800
committerMichael Butler <butlermichael@google.com>2022-12-05 20:23:24 +0000
commit6d7c5d22ada585428b607105a8d4753a902306a6 (patch)
tree30e21c853048a38edef74490bbae9fa084f7b4c3
parent2426493071139546c79b990f415abb470302d35c (diff)
downloadNeuralNetworks-6d7c5d22ada585428b607105a8d4753a902306a6.tar.gz
Add additional bounds checks to NNAPI FMQ deserialize utility functions
This CL adds the following additional bounds checks: * Adds additional checks of the index of the std::vector before accessing the element at the index * Changes the array index operator [] to the checked std::vector::at method Bug: 256589724 Test: mma Merged-In: I3461c9e33b64e7d44bb3b430c8eb00d794669037 Change-Id: I3461c9e33b64e7d44bb3b430c8eb00d794669037
-rw-r--r--common/ExecutionBurstController.cpp20
-rw-r--r--common/ExecutionBurstServer.cpp36
2 files changed, 33 insertions, 23 deletions
diff --git a/common/ExecutionBurstController.cpp b/common/ExecutionBurstController.cpp
index ac49448a3..b4f80fd5c 100644
--- a/common/ExecutionBurstController.cpp
+++ b/common/ExecutionBurstController.cpp
@@ -156,13 +156,14 @@ deserialize(const std::vector<FmqResultDatum>& data) {
size_t index = 0;
// validate packet information
- if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
+ if (index >= data.size() ||
+ data.at(index).getDiscriminator() != discriminator::packetInformation) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
// unpackage packet information
- const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
+ const FmqResultDatum::PacketInformation& packetInfo = data.at(index).packetInformation();
index++;
const uint32_t packetSize = packetInfo.packetSize;
const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
@@ -177,13 +178,14 @@ deserialize(const std::vector<FmqResultDatum>& data) {
// unpackage operands
for (size_t operand = 0; operand < numberOfOperands; ++operand) {
// validate operand information
- if (data[index].getDiscriminator() != discriminator::operandInformation) {
+ if (index >= data.size() ||
+ data.at(index).getDiscriminator() != discriminator::operandInformation) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
// unpackage operand information
- const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
+ const FmqResultDatum::OperandInformation& operandInfo = data.at(index).operandInformation();
index++;
const bool isSufficient = operandInfo.isSufficient;
const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
@@ -193,13 +195,14 @@ deserialize(const std::vector<FmqResultDatum>& data) {
dimensions.reserve(numberOfDimensions);
for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension
- if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
+ if (index >= data.size() ||
+ data.at(index).getDiscriminator() != discriminator::operandDimensionValue) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
// unpackage dimension
- const uint32_t dimension = data[index].operandDimensionValue();
+ const uint32_t dimension = data.at(index).operandDimensionValue();
index++;
// store result
@@ -211,13 +214,14 @@ deserialize(const std::vector<FmqResultDatum>& data) {
}
// validate execution timing
- if (data[index].getDiscriminator() != discriminator::executionTiming) {
+ if (index >= data.size() ||
+ data.at(index).getDiscriminator() != discriminator::executionTiming) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
// unpackage execution timing
- const V1_2::Timing timing = data[index].executionTiming();
+ const V1_2::Timing timing = data.at(index).executionTiming();
index++;
// validate packet information
diff --git a/common/ExecutionBurstServer.cpp b/common/ExecutionBurstServer.cpp
index eab8e68ed..d119b2f90 100644
--- a/common/ExecutionBurstServer.cpp
+++ b/common/ExecutionBurstServer.cpp
@@ -171,13 +171,14 @@ std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTimin
size_t index = 0;
// validate packet information
- if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
+ if (index >= data.size() ||
+ data.at(index).getDiscriminator() != discriminator::packetInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
return std::nullopt;
}
// unpackage packet information
- const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation();
+ const FmqRequestDatum::PacketInformation& packetInfo = data.at(index).packetInformation();
index++;
const uint32_t packetSize = packetInfo.packetSize;
const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands;
@@ -195,14 +196,15 @@ std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTimin
inputs.reserve(numberOfInputOperands);
for (size_t operand = 0; operand < numberOfInputOperands; ++operand) {
// validate input operand information
- if (data[index].getDiscriminator() != discriminator::inputOperandInformation) {
+ if (index >= data.size() ||
+ data.at(index).getDiscriminator() != discriminator::inputOperandInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
return std::nullopt;
}
// unpackage operand information
const FmqRequestDatum::OperandInformation& operandInfo =
- data[index].inputOperandInformation();
+ data.at(index).inputOperandInformation();
index++;
const bool hasNoValue = operandInfo.hasNoValue;
const V1_0::DataLocation location = operandInfo.location;
@@ -213,13 +215,14 @@ std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTimin
dimensions.reserve(numberOfDimensions);
for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension
- if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) {
+ if (index >= data.size() ||
+ data.at(index).getDiscriminator() != discriminator::inputOperandDimensionValue) {
LOG(ERROR) << "FMQ Request packet ill-formed";
return std::nullopt;
}
// unpackage dimension
- const uint32_t dimension = data[index].inputOperandDimensionValue();
+ const uint32_t dimension = data.at(index).inputOperandDimensionValue();
index++;
// store result
@@ -236,14 +239,15 @@ std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTimin
outputs.reserve(numberOfOutputOperands);
for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) {
// validate output operand information
- if (data[index].getDiscriminator() != discriminator::outputOperandInformation) {
+ if (index >= data.size() ||
+ data.at(index).getDiscriminator() != discriminator::outputOperandInformation) {
LOG(ERROR) << "FMQ Request packet ill-formed";
return std::nullopt;
}
// unpackage operand information
const FmqRequestDatum::OperandInformation& operandInfo =
- data[index].outputOperandInformation();
+ data.at(index).outputOperandInformation();
index++;
const bool hasNoValue = operandInfo.hasNoValue;
const V1_0::DataLocation location = operandInfo.location;
@@ -254,13 +258,14 @@ std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTimin
dimensions.reserve(numberOfDimensions);
for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension
- if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) {
+ if (index >= data.size() ||
+ data.at(index).getDiscriminator() != discriminator::outputOperandDimensionValue) {
LOG(ERROR) << "FMQ Request packet ill-formed";
return std::nullopt;
}
// unpackage dimension
- const uint32_t dimension = data[index].outputOperandDimensionValue();
+ const uint32_t dimension = data.at(index).outputOperandDimensionValue();
index++;
// store result
@@ -277,13 +282,14 @@ std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTimin
slots.reserve(numberOfPools);
for (size_t pool = 0; pool < numberOfPools; ++pool) {
// validate input operand information
- if (data[index].getDiscriminator() != discriminator::poolIdentifier) {
+ if (index >= data.size() ||
+ data.at(index).getDiscriminator() != discriminator::poolIdentifier) {
LOG(ERROR) << "FMQ Request packet ill-formed";
return std::nullopt;
}
// unpackage operand information
- const int32_t poolId = data[index].poolIdentifier();
+ const int32_t poolId = data.at(index).poolIdentifier();
index++;
// store result
@@ -291,18 +297,18 @@ std::optional<std::tuple<V1_0::Request, std::vector<int32_t>, V1_2::MeasureTimin
}
// validate measureTiming
- if (data[index].getDiscriminator() != discriminator::measureTiming) {
+ if (index >= data.size() || data.at(index).getDiscriminator() != discriminator::measureTiming) {
LOG(ERROR) << "FMQ Request packet ill-formed";
return std::nullopt;
}
// unpackage measureTiming
- const V1_2::MeasureTiming measure = data[index].measureTiming();
+ const V1_2::MeasureTiming measure = data.at(index).measureTiming();
index++;
// validate packet information
if (index != packetSize) {
- LOG(ERROR) << "FMQ Result packet ill-formed";
+ LOG(ERROR) << "FMQ Request packet ill-formed";
return std::nullopt;
}