aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShiyu Liu <jamessliu2020@gmail.com>2021-06-02 11:46:56 -0500
committerGitHub <noreply@github.com>2021-06-02 17:46:56 +0100
commit26cdce984fe723dd922cb5aae0f9c8b4e8b9ea19 (patch)
tree8324a4fe08514fd053cbead62d27756ec4cb9b2b
parentc853a91144b2054da2c22584f45debda2baa7a69 (diff)
downloadSPIRV-Tools-26cdce984fe723dd922cb5aae0f9c8b4e8b9ea19.tar.gz
spirv-fuzz: add tests for full coverage of TransformationAccessChain (#4304)
Fixes #4286 by achieving full coverage of the transformation.
-rw-r--r--source/fuzz/transformation_access_chain.cpp16
-rw-r--r--source/fuzz/transformation_access_chain.h8
-rw-r--r--test/fuzz/transformation_access_chain_test.cpp63
3 files changed, 76 insertions, 11 deletions
diff --git a/source/fuzz/transformation_access_chain.cpp b/source/fuzz/transformation_access_chain.cpp
index 97d4e041..3fe9e656 100644
--- a/source/fuzz/transformation_access_chain.cpp
+++ b/source/fuzz/transformation_access_chain.cpp
@@ -122,7 +122,7 @@ bool TransformationAccessChain::IsApplicable(
bool successful;
std::tie(successful, index_value) =
- GetIndexValue(ir_context, index_id, subobject_type_id);
+ GetStructIndexValue(ir_context, index_id, subobject_type_id);
if (!successful) {
return false;
@@ -247,7 +247,7 @@ void TransformationAccessChain::Apply(
// It is a struct: we need to retrieve the integer value.
index_value =
- GetIndexValue(ir_context, index_id, subobject_type_id).second;
+ GetStructIndexValue(ir_context, index_id, subobject_type_id).second;
new_index_id = index_id;
@@ -363,9 +363,12 @@ protobufs::Transformation TransformationAccessChain::ToMessage() const {
return result;
}
-std::pair<bool, uint32_t> TransformationAccessChain::GetIndexValue(
+std::pair<bool, uint32_t> TransformationAccessChain::GetStructIndexValue(
opt::IRContext* ir_context, uint32_t index_id,
uint32_t object_type_id) const {
+ assert(ir_context->get_def_use_mgr()->GetDef(object_type_id)->opcode() ==
+ SpvOpTypeStruct &&
+ "Precondition: the type must be a struct type.");
if (!ValidIndexToComposite(ir_context, index_id, object_type_id)) {
return {false, 0};
}
@@ -374,10 +377,9 @@ std::pair<bool, uint32_t> TransformationAccessChain::GetIndexValue(
uint32_t bound = fuzzerutil::GetBoundForCompositeIndex(
*ir_context->get_def_use_mgr()->GetDef(object_type_id), ir_context);
- // The index must be a constant
- if (!spvOpcodeIsConstant(index_instruction->opcode())) {
- return {false, 0};
- }
+ // Ensure that the index given must represent a constant.
+ assert(spvOpcodeIsConstant(index_instruction->opcode()) &&
+ "A non-constant index should already have been rejected.");
// The index must be in bounds.
uint32_t value = index_instruction->GetSingleWordInOperand(0);
diff --git a/source/fuzz/transformation_access_chain.h b/source/fuzz/transformation_access_chain.h
index 5582de39..4e4fd2b6 100644
--- a/source/fuzz/transformation_access_chain.h
+++ b/source/fuzz/transformation_access_chain.h
@@ -83,13 +83,13 @@ class TransformationAccessChain : public Transformation {
private:
// Returns {false, 0} in each of the following cases:
// - |index_id| does not correspond to a 32-bit integer constant
- // - the object being indexed is not a composite type
+ // - |object_type_id| must be a struct type
// - the constant at |index_id| is out of bounds.
// Otherwise, returns {true, value}, where value is the value of the constant
// at |index_id|.
- std::pair<bool, uint32_t> GetIndexValue(opt::IRContext* ir_context,
- uint32_t index_id,
- uint32_t object_type_id) const;
+ std::pair<bool, uint32_t> GetStructIndexValue(opt::IRContext* ir_context,
+ uint32_t index_id,
+ uint32_t object_type_id) const;
// Returns true if |index_id| corresponds, in the given context, to a 32-bit
// integer which can be used to index an object of the type specified by
diff --git a/test/fuzz/transformation_access_chain_test.cpp b/test/fuzz/transformation_access_chain_test.cpp
index 5c431279..e7919816 100644
--- a/test/fuzz/transformation_access_chain_test.cpp
+++ b/test/fuzz/transformation_access_chain_test.cpp
@@ -127,6 +127,16 @@ TEST(TransformationAccessChainTest, BasicTest) {
transformation_context.GetFactManager()->AddFactValueOfPointeeIsIrrelevant(
54);
+ // Check the case where the index type is not a 32-bit integer.
+ TransformationAccessChain invalid_index_example1(
+ 101, 28, {29}, MakeInstructionDescriptor(42, SpvOpReturn, 0));
+
+ // Since the index is not a 32-bit integer type but a 32-bit float type,
+ // ValidIndexComposite should return false and thus the transformation is not
+ // applicable.
+ ASSERT_FALSE(invalid_index_example1.IsApplicable(context.get(),
+ transformation_context));
+
// Bad: id is not fresh
ASSERT_FALSE(TransformationAccessChain(
43, 43, {80}, MakeInstructionDescriptor(24, SpvOpLoad, 0))
@@ -304,6 +314,20 @@ TEST(TransformationAccessChainTest, BasicTest) {
ASSERT_FALSE(
transformation_context.GetFactManager()->PointeeValueIsIrrelevant(107));
}
+ {
+ // Check the case where the access chain's base pointer has the irrelevant
+ // pointee fact; the resulting access chain should inherit this fact.
+ TransformationAccessChain transformation(
+ 107, 54, {}, MakeInstructionDescriptor(24, SpvOpLoad, 0));
+ ASSERT_TRUE(
+ transformation.IsApplicable(context.get(), transformation_context));
+ ApplyAndCheckFreshIds(transformation, context.get(),
+ &transformation_context);
+ ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(
+ context.get(), validator_options, kConsoleMessageConsumer));
+ ASSERT_TRUE(
+ transformation_context.GetFactManager()->PointeeValueIsIrrelevant(54));
+ }
std::string after_transformation = R"(
OpCapability Shader
@@ -383,6 +407,7 @@ TEST(TransformationAccessChainTest, BasicTest) {
%23 = OpConvertFToS %10 %22
%100 = OpAccessChain %70 %43 %80
%106 = OpAccessChain %11 %14
+ %107 = OpAccessChain %53 %54
%24 = OpLoad %10 %14
%25 = OpIAdd %10 %23 %24
OpReturnValue %25
@@ -391,6 +416,44 @@ TEST(TransformationAccessChainTest, BasicTest) {
ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
}
+TEST(TransformationAccessChainTest, StructIndexMustBeConstant) {
+ std::string shader = R"(
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %4 "main"
+ OpExecutionMode %4 OriginUpperLeft
+ OpSource ESSL 320
+ %2 = OpTypeVoid
+ %3 = OpTypeFunction %2
+ %6 = OpTypeInt 32 1
+ %20 = OpUndef %6
+ %7 = OpTypeStruct %6 %6
+ %8 = OpTypePointer Function %7
+ %10 = OpConstant %6 0
+ %11 = OpConstant %6 2
+ %12 = OpTypePointer Function %6
+ %4 = OpFunction %2 None %3
+ %5 = OpLabel
+ %9 = OpVariable %8 Function
+ OpReturn
+ OpFunctionEnd
+ )";
+
+ const auto env = SPV_ENV_UNIVERSAL_1_4;
+ const auto consumer = nullptr;
+ const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
+ spvtools::ValidatorOptions validator_options;
+ ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
+ kConsoleMessageConsumer));
+ TransformationContext transformation_context(
+ MakeUnique<FactManager>(context.get()), validator_options);
+ // Bad: %9 is a pointer to a struct, but %20 is not a constant.
+ ASSERT_FALSE(TransformationAccessChain(
+ 100, 9, {20}, MakeInstructionDescriptor(9, SpvOpReturn, 0))
+ .IsApplicable(context.get(), transformation_context));
+}
+
TEST(TransformationAccessChainTest, IsomorphicStructs) {
std::string shader = R"(
OpCapability Shader