aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--source/val/validate.cpp1
-rw-r--r--source/val/validate.h5
-rw-r--r--source/val/validate_cfg.cpp172
-rw-r--r--source/val/validate_id.cpp181
4 files changed, 177 insertions, 182 deletions
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index bcb54d14..24d8686b 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -336,6 +336,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
if (auto error = ArithmeticsPass(*vstate, &instruction)) return error;
if (auto error = BitwisePass(*vstate, &instruction)) return error;
if (auto error = LogicalsPass(*vstate, &instruction)) return error;
+ if (auto error = ControlFlowPass(*vstate, &instruction)) return error;
if (auto error = DerivativesPass(*vstate, &instruction)) return error;
if (auto error = AtomicsPass(*vstate, &instruction)) return error;
if (auto error = PrimitivesPass(*vstate, &instruction)) return error;
diff --git a/source/val/validate.h b/source/val/validate.h
index 5135a472..5d941ca0 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -116,9 +116,12 @@ void printDominatorList(BasicBlock& block);
/// spec.
spv_result_t ModuleLayoutPass(ValidationState_t& _, const Instruction* inst);
-/// Performs Control Flow Graph validation of a module
+/// Performs Control Flow Graph validation and construction.
spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst);
+/// Validates Control Flow Graph instructions.
+spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst);
+
/// Performs Id and SSA validation of a module
spv_result_t IdPass(ValidationState_t& _, Instruction* inst);
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index e369d06e..167f2273 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -18,6 +18,7 @@
#include <cassert>
#include <functional>
#include <iostream>
+#include <iterator>
#include <map>
#include <string>
#include <tuple>
@@ -27,6 +28,7 @@
#include <vector>
#include "source/cfa.h"
+#include "source/opcode.h"
#include "source/spirv_validator_options.h"
#include "source/val/basic_block.h"
#include "source/val/construct.h"
@@ -35,6 +37,158 @@
namespace spvtools {
namespace val {
+namespace {
+
+spv_result_t ValidatePhi(ValidationState_t& _, const Instruction* inst) {
+ SpvOp type_op = _.GetIdOpcode(inst->type_id());
+ if (!spvOpcodeGeneratesType(type_op)) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpPhi's type <id> " << _.getIdName(inst->type_id())
+ << " is not a type instruction.";
+ }
+
+ auto block = inst->block();
+ size_t num_in_ops = inst->words().size() - 3;
+ if (num_in_ops % 2 != 0) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpPhi does not have an equal number of incoming values and "
+ "basic blocks.";
+ }
+
+ // Create a uniqued vector of predecessor ids for comparison against
+ // incoming values. OpBranchConditional %cond %label %label produces two
+ // predecessors in the CFG.
+ std::vector<uint32_t> pred_ids;
+ std::transform(block->predecessors()->begin(), block->predecessors()->end(),
+ std::back_inserter(pred_ids),
+ [](const BasicBlock* b) { return b->id(); });
+ std::sort(pred_ids.begin(), pred_ids.end());
+ pred_ids.erase(std::unique(pred_ids.begin(), pred_ids.end()), pred_ids.end());
+
+ size_t num_edges = num_in_ops / 2;
+ if (num_edges != pred_ids.size()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpPhi's number of incoming blocks (" << num_edges
+ << ") does not match block's predecessor count ("
+ << block->predecessors()->size() << ").";
+ }
+
+ for (size_t i = 3; i < inst->words().size(); ++i) {
+ auto inc_id = inst->word(i);
+ if (i % 2 == 1) {
+ // Incoming value type must match the phi result type.
+ auto inc_type_id = _.GetTypeId(inc_id);
+ if (inst->type_id() != inc_type_id) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpPhi's result type <id> " << _.getIdName(inst->type_id())
+ << " does not match incoming value <id> " << _.getIdName(inc_id)
+ << " type <id> " << _.getIdName(inc_type_id) << ".";
+ }
+ } else {
+ if (_.GetIdOpcode(inc_id) != SpvOpLabel) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpPhi's incoming basic block <id> " << _.getIdName(inc_id)
+ << " is not an OpLabel.";
+ }
+
+ // Incoming basic block must be an immediate predecessor of the phi's
+ // block.
+ if (!std::binary_search(pred_ids.begin(), pred_ids.end(), inc_id)) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpPhi's incoming basic block <id> " << _.getIdName(inc_id)
+ << " is not a predecessor of <id> " << _.getIdName(block->id())
+ << ".";
+ }
+ }
+ }
+
+ return SPV_SUCCESS;
+}
+
+spv_result_t ValidateBranchConditional(ValidationState_t& _,
+ const Instruction* inst) {
+ // num_operands is either 3 or 5 --- if 5, the last two need to be literal
+ // integers
+ const auto num_operands = inst->operands().size();
+ if (num_operands != 3 && num_operands != 5) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpBranchConditional requires either 3 or 5 parameters";
+ }
+
+ // grab the condition operand and check that it is a bool
+ const auto cond_id = inst->GetOperandAs<uint32_t>(0);
+ const auto cond_op = _.FindDef(cond_id);
+ if (!cond_op || !_.IsBoolScalarType(cond_op->type_id())) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst) << "Condition operand for "
+ "OpBranchConditional must be "
+ "of boolean type";
+ }
+
+ // target operands must be OpLabel
+ // note that we don't need to check that the target labels are in the same
+ // function,
+ // PerformCfgChecks already checks for that
+ const auto true_id = inst->GetOperandAs<uint32_t>(1);
+ const auto true_target = _.FindDef(true_id);
+ if (!true_target || SpvOpLabel != true_target->opcode()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "The 'True Label' operand for OpBranchConditional must be the "
+ "ID of an OpLabel instruction";
+ }
+
+ const auto false_id = inst->GetOperandAs<uint32_t>(2);
+ const auto false_target = _.FindDef(false_id);
+ if (!false_target || SpvOpLabel != false_target->opcode()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "The 'False Label' operand for OpBranchConditional must be the "
+ "ID of an OpLabel instruction";
+ }
+
+ return SPV_SUCCESS;
+}
+
+spv_result_t ValidateReturnValue(ValidationState_t& _,
+ const Instruction* inst) {
+ const auto value_id = inst->GetOperandAs<uint32_t>(0);
+ const auto value = _.FindDef(value_id);
+ if (!value || !value->type_id()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpReturnValue Value <id> '" << _.getIdName(value_id)
+ << "' does not represent a value.";
+ }
+ auto value_type = _.FindDef(value->type_id());
+ if (!value_type || SpvOpTypeVoid == value_type->opcode()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpReturnValue value's type <id> '"
+ << _.getIdName(value->type_id()) << "' is missing or void.";
+ }
+
+ const bool uses_variable_pointer =
+ _.features().variable_pointers ||
+ _.features().variable_pointers_storage_buffer;
+
+ if (_.addressing_model() == SpvAddressingModelLogical &&
+ SpvOpTypePointer == value_type->opcode() && !uses_variable_pointer &&
+ !_.options()->relax_logical_pointer) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpReturnValue value's type <id> '"
+ << _.getIdName(value->type_id())
+ << "' is a pointer, which is invalid in the Logical addressing "
+ "model.";
+ }
+
+ const auto function = inst->function();
+ const auto return_type = _.FindDef(function->GetResultTypeId());
+ if (!return_type || return_type->id() != value_type->id()) {
+ return _.diag(SPV_ERROR_INVALID_ID, inst)
+ << "OpReturnValue Value <id> '" << _.getIdName(value_id)
+ << "'s type does not match OpFunction's return type.";
+ }
+
+ return SPV_SUCCESS;
+}
+
+} // namespace
void printDominatorList(const BasicBlock& b) {
std::cout << b.id() << " is dominated by: ";
@@ -596,5 +750,23 @@ spv_result_t CfgPass(ValidationState_t& _, const Instruction* inst) {
return SPV_SUCCESS;
}
+spv_result_t ControlFlowPass(ValidationState_t& _, const Instruction* inst) {
+ switch (inst->opcode()) {
+ case SpvOpPhi:
+ if (auto error = ValidatePhi(_, inst)) return error;
+ break;
+ case SpvOpBranchConditional:
+ if (auto error = ValidateBranchConditional(_, inst)) return error;
+ break;
+ case SpvOpReturnValue:
+ if (auto error = ValidateReturnValue(_, inst)) return error;
+ break;
+ default:
+ break;
+ }
+
+ return SPV_SUCCESS;
+}
+
} // namespace val
} // namespace spvtools
diff --git a/source/val/validate_id.cpp b/source/val/validate_id.cpp
index c2d35fa3..72883682 100644
--- a/source/val/validate_id.cpp
+++ b/source/val/validate_id.cpp
@@ -784,182 +784,6 @@ bool idUsage::isValid<SpvOpFunctionCall>(const spv_instruction_t* inst,
return true;
}
-template <>
-bool idUsage::isValid<SpvOpPhi>(const spv_instruction_t* inst,
- const spv_opcode_desc /*opcodeEntry*/) {
- auto thisInst = module_.FindDef(inst->words[2]);
- SpvOp typeOp = module_.GetIdOpcode(thisInst->type_id());
- if (!spvOpcodeGeneratesType(typeOp)) {
- DIAG(thisInst) << "OpPhi's type <id> "
- << module_.getIdName(thisInst->type_id())
- << " is not a type instruction.";
- return false;
- }
-
- auto block = thisInst->block();
- size_t numInOps = inst->words.size() - 3;
- if (numInOps % 2 != 0) {
- DIAG(thisInst)
- << "OpPhi does not have an equal number of incoming values and "
- "basic blocks.";
- return false;
- }
-
- // Create a uniqued vector of predecessor ids for comparison against
- // incoming values. OpBranchConditional %cond %label %label produces two
- // predecessors in the CFG.
- std::vector<uint32_t> predIds;
- std::transform(block->predecessors()->begin(), block->predecessors()->end(),
- std::back_inserter(predIds),
- [](const BasicBlock* b) { return b->id(); });
- std::sort(predIds.begin(), predIds.end());
- predIds.erase(std::unique(predIds.begin(), predIds.end()), predIds.end());
-
- size_t numEdges = numInOps / 2;
- if (numEdges != predIds.size()) {
- DIAG(thisInst) << "OpPhi's number of incoming blocks (" << numEdges
- << ") does not match block's predecessor count ("
- << block->predecessors()->size() << ").";
- return false;
- }
-
- for (size_t i = 3; i < inst->words.size(); ++i) {
- auto incId = inst->words[i];
- if (i % 2 == 1) {
- // Incoming value type must match the phi result type.
- auto incTypeId = module_.GetTypeId(incId);
- if (thisInst->type_id() != incTypeId) {
- DIAG(thisInst) << "OpPhi's result type <id> "
- << module_.getIdName(thisInst->type_id())
- << " does not match incoming value <id> "
- << module_.getIdName(incId) << " type <id> "
- << module_.getIdName(incTypeId) << ".";
- return false;
- }
- } else {
- if (module_.GetIdOpcode(incId) != SpvOpLabel) {
- DIAG(thisInst) << "OpPhi's incoming basic block <id> "
- << module_.getIdName(incId) << " is not an OpLabel.";
- return false;
- }
-
- // Incoming basic block must be an immediate predecessor of the phi's
- // block.
- if (!std::binary_search(predIds.begin(), predIds.end(), incId)) {
- DIAG(thisInst) << "OpPhi's incoming basic block <id> "
- << module_.getIdName(incId)
- << " is not a predecessor of <id> "
- << module_.getIdName(block->id()) << ".";
- return false;
- }
- }
- }
-
- return true;
-}
-
-template <>
-bool idUsage::isValid<SpvOpBranchConditional>(const spv_instruction_t* inst,
- const spv_opcode_desc) {
- const size_t numOperands = inst->words.size() - 1;
- const size_t condOperandIndex = 1;
- const size_t targetTrueIndex = 2;
- const size_t targetFalseIndex = 3;
-
- // num_operands is either 3 or 5 --- if 5, the last two need to be literal
- // integers
- if (numOperands != 3 && numOperands != 5) {
- Instruction* fake_inst = nullptr;
- DIAG(fake_inst) << "OpBranchConditional requires either 3 or 5 parameters";
- return false;
- }
-
- bool ret = true;
-
- // grab the condition operand and check that it is a bool
- const auto condOp = module_.FindDef(inst->words[condOperandIndex]);
- if (!condOp || !module_.IsBoolScalarType(condOp->type_id())) {
- DIAG(condOp)
- << "Condition operand for OpBranchConditional must be of boolean type";
- ret = false;
- }
-
- // target operands must be OpLabel
- // note that we don't need to check that the target labels are in the same
- // function,
- // PerformCfgChecks already checks for that
- const auto targetOpTrue = module_.FindDef(inst->words[targetTrueIndex]);
- if (!targetOpTrue || SpvOpLabel != targetOpTrue->opcode()) {
- DIAG(targetOpTrue)
- << "The 'True Label' operand for OpBranchConditional must be the "
- "ID of an OpLabel instruction";
- ret = false;
- }
-
- const auto targetOpFalse = module_.FindDef(inst->words[targetFalseIndex]);
- if (!targetOpFalse || SpvOpLabel != targetOpFalse->opcode()) {
- DIAG(targetOpFalse)
- << "The 'False Label' operand for OpBranchConditional must be the "
- "ID of an OpLabel instruction";
- ret = false;
- }
-
- return ret;
-}
-
-template <>
-bool idUsage::isValid<SpvOpReturnValue>(const spv_instruction_t* inst,
- const spv_opcode_desc) {
- auto valueIndex = 1;
- auto value = module_.FindDef(inst->words[valueIndex]);
- if (!value || !value->type_id()) {
- DIAG(value) << "OpReturnValue Value <id> '"
- << module_.getIdName(inst->words[valueIndex])
- << "' does not represent a value.";
- return false;
- }
- auto valueType = module_.FindDef(value->type_id());
- if (!valueType || SpvOpTypeVoid == valueType->opcode()) {
- DIAG(value) << "OpReturnValue value's type <id> '"
- << module_.getIdName(value->type_id())
- << "' is missing or void.";
- return false;
- }
-
- const bool uses_variable_pointer =
- module_.features().variable_pointers ||
- module_.features().variable_pointers_storage_buffer;
-
- if (addressingModel == SpvAddressingModelLogical &&
- SpvOpTypePointer == valueType->opcode() && !uses_variable_pointer &&
- !module_.options()->relax_logical_pointer) {
- DIAG(value)
- << "OpReturnValue value's type <id> '"
- << module_.getIdName(value->type_id())
- << "' is a pointer, which is invalid in the Logical addressing model.";
- return false;
- }
-
- // NOTE: Find OpFunction
- const spv_instruction_t* function = inst - 1;
- while (firstInst != function) {
- if (SpvOpFunction == function->opcode) break;
- function--;
- }
- if (SpvOpFunction != function->opcode) {
- DIAG(value) << "OpReturnValue is not in a basic block.";
- return false;
- }
- auto returnType = module_.FindDef(function->words[1]);
- if (!returnType || returnType->id() != valueType->id()) {
- DIAG(value) << "OpReturnValue Value <id> '"
- << module_.getIdName(inst->words[valueIndex])
- << "'s type does not match OpFunction's return type.";
- return false;
- }
- return true;
-}
-
#undef DIAG
bool idUsage::isValid(const spv_instruction_t* inst) {
@@ -988,11 +812,6 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
// Bitwise opcodes are validated in validate_bitwise.cpp.
// Logical opcodes are validated in validate_logicals.cpp.
// Derivative opcodes are validated in validate_derivatives.cpp.
- CASE(OpPhi)
- // OpBranch is validated in validate_cfg.cpp.
- // See tests in test/val/val_cfg_test.cpp.
- CASE(OpBranchConditional)
- CASE(OpReturnValue)
default:
return true;
}