From 63a496ee6ad51e1de41cbf698d766fa5a04e88a8 Mon Sep 17 00:00:00 2001 From: Michael Butler Date: Thu, 1 Dec 2022 17:50:37 -0800 Subject: Add additional bounds checks to NNAPI FMQ deserialize utility functions This CL adds the following additional bounds checks: * Adds additional checks of the index of the std::vector before accessing the element at the index * Changes the array index operator [] to the checked std::vector::at method Bug: 256589724 Test: mma Merged-In: I3461c9e33b64e7d44bb3b430c8eb00d794669037 Change-Id: I3461c9e33b64e7d44bb3b430c8eb00d794669037 (cherry picked from commit 9525bc4a6a63acf97513288dbdf8c48b5382c8d8) Merged-In: I3461c9e33b64e7d44bb3b430c8eb00d794669037 --- nn/common/ExecutionBurstController.cpp | 20 +++++++++++-------- nn/common/ExecutionBurstServer.cpp | 36 ++++++++++++++++++++-------------- 2 files changed, 33 insertions(+), 23 deletions(-) (limited to 'nn') diff --git a/nn/common/ExecutionBurstController.cpp b/nn/common/ExecutionBurstController.cpp index 8463df895..1415e641e 100644 --- a/nn/common/ExecutionBurstController.cpp +++ b/nn/common/ExecutionBurstController.cpp @@ -157,13 +157,14 @@ std::optional, Timing>> d size_t index = 0; // validate packet information - if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) { + if (index >= data.size() || + data.at(index).getDiscriminator() != discriminator::packetInformation) { LOG(ERROR) << "FMQ Result packet ill-formed"; return std::nullopt; } // unpackage packet information - const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation(); + const FmqResultDatum::PacketInformation& packetInfo = data.at(index).packetInformation(); index++; const uint32_t packetSize = packetInfo.packetSize; const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus; @@ -178,13 +179,14 @@ std::optional, Timing>> d // unpackage operands for (size_t operand = 0; operand < numberOfOperands; ++operand) { // validate operand information - if (data[index].getDiscriminator() != discriminator::operandInformation) { + if (index >= data.size() || + data.at(index).getDiscriminator() != discriminator::operandInformation) { LOG(ERROR) << "FMQ Result packet ill-formed"; return std::nullopt; } // unpackage operand information - const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation(); + const FmqResultDatum::OperandInformation& operandInfo = data.at(index).operandInformation(); index++; const bool isSufficient = operandInfo.isSufficient; const uint32_t numberOfDimensions = operandInfo.numberOfDimensions; @@ -194,13 +196,14 @@ std::optional, Timing>> d dimensions.reserve(numberOfDimensions); for (size_t i = 0; i < numberOfDimensions; ++i) { // validate dimension - if (data[index].getDiscriminator() != discriminator::operandDimensionValue) { + if (index >= data.size() || + data.at(index).getDiscriminator() != discriminator::operandDimensionValue) { LOG(ERROR) << "FMQ Result packet ill-formed"; return std::nullopt; } // unpackage dimension - const uint32_t dimension = data[index].operandDimensionValue(); + const uint32_t dimension = data.at(index).operandDimensionValue(); index++; // store result @@ -212,13 +215,14 @@ std::optional, Timing>> d } // validate execution timing - if (data[index].getDiscriminator() != discriminator::executionTiming) { + if (index >= data.size() || + data.at(index).getDiscriminator() != discriminator::executionTiming) { LOG(ERROR) << "FMQ Result packet ill-formed"; return std::nullopt; } // unpackage execution timing - const Timing timing = data[index].executionTiming(); + const Timing timing = data.at(index).executionTiming(); index++; // validate packet information diff --git a/nn/common/ExecutionBurstServer.cpp b/nn/common/ExecutionBurstServer.cpp index 583ebf553..8c9123268 100644 --- a/nn/common/ExecutionBurstServer.cpp +++ b/nn/common/ExecutionBurstServer.cpp @@ -168,13 +168,14 @@ std::optional, MeasureTiming>> de size_t index = 0; // validate packet information - if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) { + if (index >= data.size() || + data.at(index).getDiscriminator() != discriminator::packetInformation) { LOG(ERROR) << "FMQ Request packet ill-formed"; return std::nullopt; } // unpackage packet information - const FmqRequestDatum::PacketInformation& packetInfo = data[index].packetInformation(); + const FmqRequestDatum::PacketInformation& packetInfo = data.at(index).packetInformation(); index++; const uint32_t packetSize = packetInfo.packetSize; const uint32_t numberOfInputOperands = packetInfo.numberOfInputOperands; @@ -192,14 +193,15 @@ std::optional, MeasureTiming>> de inputs.reserve(numberOfInputOperands); for (size_t operand = 0; operand < numberOfInputOperands; ++operand) { // validate input operand information - if (data[index].getDiscriminator() != discriminator::inputOperandInformation) { + if (index >= data.size() || + data.at(index).getDiscriminator() != discriminator::inputOperandInformation) { LOG(ERROR) << "FMQ Request packet ill-formed"; return std::nullopt; } // unpackage operand information const FmqRequestDatum::OperandInformation& operandInfo = - data[index].inputOperandInformation(); + data.at(index).inputOperandInformation(); index++; const bool hasNoValue = operandInfo.hasNoValue; const DataLocation location = operandInfo.location; @@ -210,13 +212,14 @@ std::optional, MeasureTiming>> de dimensions.reserve(numberOfDimensions); for (size_t i = 0; i < numberOfDimensions; ++i) { // validate dimension - if (data[index].getDiscriminator() != discriminator::inputOperandDimensionValue) { + if (index >= data.size() || + data.at(index).getDiscriminator() != discriminator::inputOperandDimensionValue) { LOG(ERROR) << "FMQ Request packet ill-formed"; return std::nullopt; } // unpackage dimension - const uint32_t dimension = data[index].inputOperandDimensionValue(); + const uint32_t dimension = data.at(index).inputOperandDimensionValue(); index++; // store result @@ -233,14 +236,15 @@ std::optional, MeasureTiming>> de outputs.reserve(numberOfOutputOperands); for (size_t operand = 0; operand < numberOfOutputOperands; ++operand) { // validate output operand information - if (data[index].getDiscriminator() != discriminator::outputOperandInformation) { + if (index >= data.size() || + data.at(index).getDiscriminator() != discriminator::outputOperandInformation) { LOG(ERROR) << "FMQ Request packet ill-formed"; return std::nullopt; } // unpackage operand information const FmqRequestDatum::OperandInformation& operandInfo = - data[index].outputOperandInformation(); + data.at(index).outputOperandInformation(); index++; const bool hasNoValue = operandInfo.hasNoValue; const DataLocation location = operandInfo.location; @@ -251,13 +255,14 @@ std::optional, MeasureTiming>> de dimensions.reserve(numberOfDimensions); for (size_t i = 0; i < numberOfDimensions; ++i) { // validate dimension - if (data[index].getDiscriminator() != discriminator::outputOperandDimensionValue) { + if (index >= data.size() || + data.at(index).getDiscriminator() != discriminator::outputOperandDimensionValue) { LOG(ERROR) << "FMQ Request packet ill-formed"; return std::nullopt; } // unpackage dimension - const uint32_t dimension = data[index].outputOperandDimensionValue(); + const uint32_t dimension = data.at(index).outputOperandDimensionValue(); index++; // store result @@ -274,13 +279,14 @@ std::optional, MeasureTiming>> de slots.reserve(numberOfPools); for (size_t pool = 0; pool < numberOfPools; ++pool) { // validate input operand information - if (data[index].getDiscriminator() != discriminator::poolIdentifier) { + if (index >= data.size() || + data.at(index).getDiscriminator() != discriminator::poolIdentifier) { LOG(ERROR) << "FMQ Request packet ill-formed"; return std::nullopt; } // unpackage operand information - const int32_t poolId = data[index].poolIdentifier(); + const int32_t poolId = data.at(index).poolIdentifier(); index++; // store result @@ -288,18 +294,18 @@ std::optional, MeasureTiming>> de } // validate measureTiming - if (data[index].getDiscriminator() != discriminator::measureTiming) { + if (index >= data.size() || data.at(index).getDiscriminator() != discriminator::measureTiming) { LOG(ERROR) << "FMQ Request packet ill-formed"; return std::nullopt; } // unpackage measureTiming - const MeasureTiming measure = data[index].measureTiming(); + const MeasureTiming measure = data.at(index).measureTiming(); index++; // validate packet information if (index != packetSize) { - LOG(ERROR) << "FMQ Result packet ill-formed"; + LOG(ERROR) << "FMQ Request packet ill-formed"; return std::nullopt; } -- cgit v1.2.3