aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSpencer Fricke <115671160+spencer-lunarg@users.noreply.github.com>2024-02-22 07:52:13 +0900
committerGitHub <noreply@github.com>2024-02-21 17:52:13 -0500
commit1b643eac5d4062bbec48b912a1332e6909802479 (patch)
tree05264183273afe3fac8ceca4a4e138ad346e71d1
parentdc6676445be97ab19d8191fee019af62e2aaf774 (diff)
downloadSPIRV-Tools-1b643eac5d4062bbec48b912a1332e6909802479.tar.gz
spirv-val: Make Constant evaluation consistent (#5587)
Bring 64-bit evaluation in line with 32-bit evaluation.
-rw-r--r--source/val/validate_builtins.cpp2
-rw-r--r--source/val/validate_composites.cpp4
-rw-r--r--source/val/validate_extensions.cpp4
-rw-r--r--source/val/validate_image.cpp4
-rw-r--r--source/val/validate_memory.cpp17
-rw-r--r--source/val/validate_non_uniform.cpp13
-rw-r--r--source/val/validate_type.cpp43
-rw-r--r--source/val/validation_state.cpp47
-rw-r--r--source/val/validation_state.h12
-rw-r--r--test/val/val_id_test.cpp2
10 files changed, 78 insertions, 70 deletions
diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp
index 42fbc52a..a7e9942a 100644
--- a/source/val/validate_builtins.cpp
+++ b/source/val/validate_builtins.cpp
@@ -1120,7 +1120,7 @@ spv_result_t BuiltInsValidator::ValidateF32ArrHelper(
if (num_components != 0) {
uint64_t actual_num_components = 0;
- if (!_.GetConstantValUint64(type_inst->word(3), &actual_num_components)) {
+ if (!_.EvalConstantValUint64(type_inst->word(3), &actual_num_components)) {
assert(0 && "Array type definition is corrupt");
}
if (actual_num_components != num_components) {
diff --git a/source/val/validate_composites.cpp b/source/val/validate_composites.cpp
index ed043b68..26486dac 100644
--- a/source/val/validate_composites.cpp
+++ b/source/val/validate_composites.cpp
@@ -94,7 +94,7 @@ spv_result_t GetExtractInsertValueType(ValidationState_t& _,
break;
}
- if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) {
+ if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) {
assert(0 && "Array type definition is corrupt");
}
if (component_index >= array_size) {
@@ -289,7 +289,7 @@ spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
}
uint64_t array_size = 0;
- if (!_.GetConstantValUint64(array_inst->word(3), &array_size)) {
+ if (!_.EvalConstantValUint64(array_inst->word(3), &array_size)) {
assert(0 && "Array type definition is corrupt");
}
diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp
index 0334b606..7b73c9c6 100644
--- a/source/val/validate_extensions.cpp
+++ b/source/val/validate_extensions.cpp
@@ -3100,7 +3100,7 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) {
uint32_t vector_count = inst->word(6);
uint64_t const_val;
- if (!_.GetConstantValUint64(vector_count, &const_val)) {
+ if (!_.EvalConstantValUint64(vector_count, &const_val)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< ext_inst_name()
<< ": Vector Count must be 32-bit integer OpConstant";
@@ -3191,7 +3191,7 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) {
uint32_t component_count = inst->word(6);
if (vulkanDebugInfo) {
uint64_t const_val;
- if (!_.GetConstantValUint64(component_count, &const_val)) {
+ if (!_.EvalConstantValUint64(component_count, &const_val)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< ext_inst_name()
<< ": Component Count must be 32-bit integer OpConstant";
diff --git a/source/val/validate_image.cpp b/source/val/validate_image.cpp
index 46a32f24..543f345e 100644
--- a/source/val/validate_image.cpp
+++ b/source/val/validate_image.cpp
@@ -495,7 +495,7 @@ spv_result_t ValidateImageOperands(ValidationState_t& _,
}
uint64_t array_size = 0;
- if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) {
+ if (!_.EvalConstantValUint64(type_inst->word(3), &array_size)) {
assert(0 && "Array type definition is corrupt");
}
@@ -1210,7 +1210,7 @@ spv_result_t ValidateImageTexelPointer(ValidationState_t& _,
if (info.multisampled == 0) {
uint64_t ms = 0;
- if (!_.GetConstantValUint64(inst->GetOperandAs<uint32_t>(4), &ms) ||
+ if (!_.EvalConstantValUint64(inst->GetOperandAs<uint32_t>(4), &ms) ||
ms != 0) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Sample for Image with MS 0 to be a valid <id> for "
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index 5b25eeb3..41dd71e9 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -1374,22 +1374,18 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
case spv::Op::OpTypeStruct: {
// In case of structures, there is an additional constraint on the
// index: the index must be an OpConstant.
- if (spv::Op::OpConstant != cur_word_instr->opcode()) {
+ int64_t cur_index;
+ if (!_.EvalConstantValInt64(cur_word, &cur_index)) {
return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
<< "The <id> passed to " << instr_name
<< " to index into a "
"structure must be an OpConstant.";
}
- // Get the index value from the OpConstant (word 3 of OpConstant).
- // OpConstant could be a signed integer. But it's okay to treat it as
- // unsigned because a negative constant int would never be seen as
- // correct as a struct offset, since structs can't have more than 2
- // billion members.
- const uint32_t cur_index = cur_word_instr->word(3);
+
// The index points to the struct member we want, therefore, the index
// should be less than the number of struct members.
- const uint32_t num_struct_members =
- static_cast<uint32_t>(type_pointee->words().size() - 2);
+ const int64_t num_struct_members =
+ static_cast<int64_t>(type_pointee->words().size() - 2);
if (cur_index >= num_struct_members) {
return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
<< "Index is out of bounds: " << instr_name
@@ -1400,7 +1396,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
<< num_struct_members - 1 << ".";
}
// Struct members IDs start at word 2 of OpTypeStruct.
- auto structMemberId = type_pointee->word(cur_index + 2);
+ const size_t word_index = static_cast<size_t>(cur_index) + 2;
+ auto structMemberId = type_pointee->word(word_index);
type_pointee = _.FindDef(structMemberId);
break;
}
diff --git a/source/val/validate_non_uniform.cpp b/source/val/validate_non_uniform.cpp
index 74449e9d..75967d2f 100644
--- a/source/val/validate_non_uniform.cpp
+++ b/source/val/validate_non_uniform.cpp
@@ -389,20 +389,25 @@ spv_result_t ValidateGroupNonUniformRotateKHR(ValidationState_t& _,
if (inst->words().size() > 6) {
const uint32_t cluster_size_op_id = inst->GetOperandAs<uint32_t>(5);
- const uint32_t cluster_size_type = _.GetTypeId(cluster_size_op_id);
+ const Instruction* cluster_size_inst = _.FindDef(cluster_size_op_id);
+ const uint32_t cluster_size_type =
+ cluster_size_inst ? cluster_size_inst->type_id() : 0;
if (!_.IsUnsignedIntScalarType(cluster_size_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "ClusterSize must be a scalar of integer type, whose "
"Signedness operand is 0.";
}
- uint64_t cluster_size;
- if (!_.GetConstantValUint64(cluster_size_op_id, &cluster_size)) {
+ if (!spvOpcodeIsConstant(cluster_size_inst->opcode())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "ClusterSize must come from a constant instruction.";
}
- if ((cluster_size == 0) || ((cluster_size & (cluster_size - 1)) != 0)) {
+ uint64_t cluster_size;
+ const bool valid_const =
+ _.EvalConstantValUint64(cluster_size_op_id, &cluster_size);
+ if (valid_const &&
+ ((cluster_size == 0) || ((cluster_size & (cluster_size - 1)) != 0))) {
return _.diag(SPV_WARNING, inst)
<< "Behavior is undefined unless ClusterSize is at least 1 and a "
"power of 2.";
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index 7edd12ff..cb26a527 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -24,21 +24,6 @@ namespace spvtools {
namespace val {
namespace {
-// Returns, as an int64_t, the literal value from an OpConstant or the
-// default value of an OpSpecConstant, assuming it is an integral type.
-// For signed integers, relies the rule that literal value is sign extended
-// to fill out to word granularity. Assumes that the constant value
-// has
-int64_t ConstantLiteralAsInt64(uint32_t width,
- const std::vector<uint32_t>& const_words) {
- const uint32_t lo_word = const_words[3];
- if (width <= 32) return int32_t(lo_word);
- assert(width <= 64);
- assert(const_words.size() > 4);
- const uint32_t hi_word = const_words[4]; // Must exist, per spec.
- return static_cast<int64_t>(uint64_t(lo_word) | uint64_t(hi_word) << 32);
-}
-
// Validates that type declarations are unique, unless multiple declarations
// of the same data type are allowed by the specification.
// (see section 2.8 Types and Variables)
@@ -252,29 +237,17 @@ spv_result_t ValidateTypeArray(ValidationState_t& _, const Instruction* inst) {
<< " is not a constant integer type.";
}
- switch (length->opcode()) {
- case spv::Op::OpSpecConstant:
- case spv::Op::OpConstant: {
- auto& type_words = const_result_type->words();
- const bool is_signed = type_words[3] > 0;
- const uint32_t width = type_words[2];
- const int64_t ivalue = ConstantLiteralAsInt64(width, length->words());
- if (ivalue == 0 || (ivalue < 0 && is_signed)) {
- return _.diag(SPV_ERROR_INVALID_ID, inst)
- << "OpTypeArray Length <id> " << _.getIdName(length_id)
- << " default value must be at least 1: found " << ivalue;
- }
- } break;
- case spv::Op::OpConstantNull:
+ int64_t length_value;
+ if (_.EvalConstantValInt64(length_id, &length_value)) {
+ auto& type_words = const_result_type->words();
+ const bool is_signed = type_words[3] > 0;
+ if (length_value == 0 || (length_value < 0 && is_signed)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpTypeArray Length <id> " << _.getIdName(length_id)
- << " default value must be at least 1.";
- case spv::Op::OpSpecConstantOp:
- // Assume it's OK, rather than try to evaluate the operation.
- break;
- default:
- assert(0 && "bug in spvOpcodeIsConstant() or result type isn't int");
+ << " default value must be at least 1: found " << length_value;
+ }
}
+
return SPV_SUCCESS;
}
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 25b374de..fa5ae3e0 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -1209,7 +1209,7 @@ bool ValidationState_t::IsCooperativeMatrixAType(uint32_t id) const {
if (!IsCooperativeMatrixKHRType(id)) return false;
const Instruction* inst = FindDef(id);
uint64_t matrixUse = 0;
- if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+ if (EvalConstantValUint64(inst->word(6), &matrixUse)) {
return matrixUse ==
static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixAKHR);
}
@@ -1220,7 +1220,7 @@ bool ValidationState_t::IsCooperativeMatrixBType(uint32_t id) const {
if (!IsCooperativeMatrixKHRType(id)) return false;
const Instruction* inst = FindDef(id);
uint64_t matrixUse = 0;
- if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+ if (EvalConstantValUint64(inst->word(6), &matrixUse)) {
return matrixUse ==
static_cast<uint64_t>(spv::CooperativeMatrixUse::MatrixBKHR);
}
@@ -1230,7 +1230,7 @@ bool ValidationState_t::IsCooperativeMatrixAccType(uint32_t id) const {
if (!IsCooperativeMatrixKHRType(id)) return false;
const Instruction* inst = FindDef(id);
uint64_t matrixUse = 0;
- if (GetConstantValUint64(inst->word(6), &matrixUse)) {
+ if (EvalConstantValUint64(inst->word(6), &matrixUse)) {
return matrixUse == static_cast<uint64_t>(
spv::CooperativeMatrixUse::MatrixAccumulatorKHR);
}
@@ -1340,20 +1340,23 @@ uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
}
-bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
+bool ValidationState_t::EvalConstantValUint64(uint32_t id,
+ uint64_t* val) const {
const Instruction* inst = FindDef(id);
if (!inst) {
assert(0 && "Instruction not found");
return false;
}
- if (inst->opcode() != spv::Op::OpConstant &&
- inst->opcode() != spv::Op::OpSpecConstant)
- return false;
-
if (!IsIntScalarType(inst->type_id())) return false;
- if (inst->words().size() == 4) {
+ if (inst->opcode() == spv::Op::OpConstantNull) {
+ *val = 0;
+ } else if (inst->opcode() != spv::Op::OpConstant) {
+ // Spec constant values cannot be evaluated so don't consider constant for
+ // static validation
+ return false;
+ } else if (inst->words().size() == 4) {
*val = inst->word(3);
} else {
assert(inst->words().size() == 5);
@@ -1363,6 +1366,32 @@ bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
return true;
}
+bool ValidationState_t::EvalConstantValInt64(uint32_t id, int64_t* val) const {
+ const Instruction* inst = FindDef(id);
+ if (!inst) {
+ assert(0 && "Instruction not found");
+ return false;
+ }
+
+ if (!IsIntScalarType(inst->type_id())) return false;
+
+ if (inst->opcode() == spv::Op::OpConstantNull) {
+ *val = 0;
+ } else if (inst->opcode() != spv::Op::OpConstant) {
+ // Spec constant values cannot be evaluated so don't consider constant for
+ // static validation
+ return false;
+ } else if (inst->words().size() == 4) {
+ *val = int32_t(inst->word(3));
+ } else {
+ assert(inst->words().size() == 5);
+ const uint32_t lo_word = inst->word(3);
+ const uint32_t hi_word = inst->word(4);
+ *val = static_cast<int64_t>(uint64_t(lo_word) | uint64_t(hi_word) << 32);
+ }
+ return true;
+}
+
std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
uint32_t id) const {
const Instruction* const inst = FindDef(id);
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 46a8cbfa..27acdcc2 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -648,10 +648,6 @@ class ValidationState_t {
const std::function<bool(const Instruction*)>& f,
bool traverse_all_types = true) const;
- // Gets value from OpConstant and OpSpecConstant as uint64.
- // Returns false on failure (no instruction, wrong instruction, not int).
- bool GetConstantValUint64(uint32_t id, uint64_t* val) const;
-
// Returns type_id if id has type or zero otherwise.
uint32_t GetTypeId(uint32_t id) const;
@@ -726,6 +722,14 @@ class ValidationState_t {
pointer_to_storage_image_.insert(type_id);
}
+ // Tries to evaluate a any scalar integer OpConstant as uint64.
+ // OpConstantNull is defined as zero for scalar int (will return true)
+ // OpSpecConstant* return false since their values cannot be relied upon
+ // during validation.
+ bool EvalConstantValUint64(uint32_t id, uint64_t* val) const;
+ // Same as EvalConstantValUint64 but returns a signed int
+ bool EvalConstantValInt64(uint32_t id, int64_t* val) const;
+
// Tries to evaluate a 32-bit signed or unsigned scalar integer constant.
// Returns tuple <is_int32, is_const_int32, value>.
// OpSpecConstant* return |is_const_int32| as false since their values cannot
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 7acac563..e236134b 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -1056,7 +1056,7 @@ TEST_P(ValidateIdWithMessage, OpTypeArrayLengthNull) {
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr(make_message("OpTypeArray Length <id> '2[%2]' default "
- "value must be at least 1.")));
+ "value must be at least 1: found 0")));
}
TEST_P(ValidateIdWithMessage, OpTypeArrayLengthSpecConst) {