// Copyright (c) 2018 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "source/opcode.h" #include "source/val/instruction.h" #include "source/val/validate.h" #include "source/val/validation_state.h" namespace spvtools { namespace val { namespace { // Returns true if |a| and |b| are instructions defining pointers that point to // types logically match and the decorations that apply to |b| are a subset // of the decorations that apply to |a|. bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b, ValidationState_t& _) { if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) { return false; } const auto& dec_a = _.id_decorations(a->id()); const auto& dec_b = _.id_decorations(b->id()); for (const auto& dec : dec_b) { if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) { return false; } } uint32_t a_type = a->GetOperandAs(2); uint32_t b_type = b->GetOperandAs(2); if (a_type == b_type) { return true; } Instruction* a_type_inst = _.FindDef(a_type); Instruction* b_type_inst = _.FindDef(b_type); return _.LogicallyMatch(a_type_inst, b_type_inst, true); } spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { const auto function_type_id = inst->GetOperandAs(3); const auto function_type = _.FindDef(function_type_id); if (!function_type || SpvOpTypeFunction != function_type->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunction Function Type '" << _.getIdName(function_type_id) << "' is not a function type."; } const auto return_id = function_type->GetOperandAs(1); if (return_id != inst->type_id()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunction Result Type '" << _.getIdName(inst->type_id()) << "' does not match the Function Type's return type '" << _.getIdName(return_id) << "'."; } const std::vector acceptable = { SpvOpGroupDecorate, SpvOpDecorate, SpvOpEnqueueKernel, SpvOpEntryPoint, SpvOpExecutionMode, SpvOpExecutionModeId, SpvOpFunctionCall, SpvOpGetKernelNDrangeSubGroupCount, SpvOpGetKernelNDrangeMaxSubGroupSize, SpvOpGetKernelWorkGroupSize, SpvOpGetKernelPreferredWorkGroupSizeMultiple, SpvOpGetKernelLocalSizeForSubgroupCount, SpvOpGetKernelMaxNumSubgroups, SpvOpName}; for (auto& pair : inst->uses()) { const auto* use = pair.first; if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == acceptable.end() && !use->IsNonSemantic() && !use->IsDebugInfo()) { return _.diag(SPV_ERROR_INVALID_ID, use) << "Invalid use of function result id " << _.getIdName(inst->id()) << "."; } } return SPV_SUCCESS; } spv_result_t ValidateFunctionParameter(ValidationState_t& _, const Instruction* inst) { // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. size_t param_index = 0; size_t inst_num = inst->LineNum() - 1; if (inst_num == 0) { return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "Function parameter cannot be the first instruction."; } auto func_inst = &_.ordered_instructions()[inst_num]; while (--inst_num) { func_inst = &_.ordered_instructions()[inst_num]; if (func_inst->opcode() == SpvOpFunction) { break; } else if (func_inst->opcode() == SpvOpFunctionParameter) { ++param_index; } } if (func_inst->opcode() != SpvOpFunction) { return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) << "Function parameter must be preceded by a function."; } const auto function_type_id = func_inst->GetOperandAs(3); const auto function_type = _.FindDef(function_type_id); if (!function_type) { return _.diag(SPV_ERROR_INVALID_ID, func_inst) << "Missing function type definition."; } if (param_index >= function_type->words().size() - 3) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Too many OpFunctionParameters for " << func_inst->id() << ": expected " << function_type->words().size() - 3 << " based on the function's type"; } const auto param_type = _.FindDef(function_type->GetOperandAs(param_index + 2)); if (!param_type || inst->type_id() != param_type->id()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunctionParameter Result Type '" << _.getIdName(inst->type_id()) << "' does not match the OpTypeFunction parameter " "type of the same index."; } // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased, // RestrictPointerEXT, or AliasedPointerEXT. auto param_nonarray_type_id = param_type->id(); while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) { param_nonarray_type_id = _.FindDef(param_nonarray_type_id)->GetOperandAs(1u); } if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) { auto param_nonarray_type = _.FindDef(param_nonarray_type_id); if (param_nonarray_type->GetOperandAs(1u) == SpvStorageClassPhysicalStorageBufferEXT) { // check for Aliased or Restrict const auto& decorations = _.id_decorations(inst->id()); bool foundAliased = std::any_of( decorations.begin(), decorations.end(), [](const Decoration& d) { return SpvDecorationAliased == d.dec_type(); }); bool foundRestrict = std::any_of( decorations.begin(), decorations.end(), [](const Decoration& d) { return SpvDecorationRestrict == d.dec_type(); }); if (!foundAliased && !foundRestrict) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunctionParameter " << inst->id() << ": expected Aliased or Restrict for PhysicalStorageBufferEXT " "pointer."; } if (foundAliased && foundRestrict) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunctionParameter " << inst->id() << ": can't specify both Aliased and Restrict for " "PhysicalStorageBufferEXT pointer."; } } else { const auto pointee_type_id = param_nonarray_type->GetOperandAs(2); const auto pointee_type = _.FindDef(pointee_type_id); if (SpvOpTypePointer == pointee_type->opcode() && pointee_type->GetOperandAs(1u) == SpvStorageClassPhysicalStorageBufferEXT) { // check for AliasedPointerEXT/RestrictPointerEXT const auto& decorations = _.id_decorations(inst->id()); bool foundAliased = std::any_of( decorations.begin(), decorations.end(), [](const Decoration& d) { return SpvDecorationAliasedPointerEXT == d.dec_type(); }); bool foundRestrict = std::any_of( decorations.begin(), decorations.end(), [](const Decoration& d) { return SpvDecorationRestrictPointerEXT == d.dec_type(); }); if (!foundAliased && !foundRestrict) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunctionParameter " << inst->id() << ": expected AliasedPointerEXT or RestrictPointerEXT for " "PhysicalStorageBufferEXT pointer."; } if (foundAliased && foundRestrict) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunctionParameter " << inst->id() << ": can't specify both AliasedPointerEXT and " "RestrictPointerEXT for PhysicalStorageBufferEXT pointer."; } } } } return SPV_SUCCESS; } spv_result_t ValidateFunctionCall(ValidationState_t& _, const Instruction* inst) { const auto function_id = inst->GetOperandAs(2); const auto function = _.FindDef(function_id); if (!function || SpvOpFunction != function->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunctionCall Function '" << _.getIdName(function_id) << "' is not a function."; } auto return_type = _.FindDef(function->type_id()); if (!return_type || return_type->id() != inst->type_id()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunctionCall Result Type '" << _.getIdName(inst->type_id()) << "'s type does not match Function '" << _.getIdName(return_type->id()) << "'s return type."; } const auto function_type_id = function->GetOperandAs(3); const auto function_type = _.FindDef(function_type_id); if (!function_type || function_type->opcode() != SpvOpTypeFunction) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Missing function type definition."; } const auto function_call_arg_count = inst->words().size() - 4; const auto function_param_count = function_type->words().size() - 3; if (function_param_count != function_call_arg_count) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunctionCall Function 's parameter count does not match " "the argument count."; } for (size_t argument_index = 3, param_index = 2; argument_index < inst->operands().size(); argument_index++, param_index++) { const auto argument_id = inst->GetOperandAs(argument_index); const auto argument = _.FindDef(argument_id); if (!argument) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Missing argument " << argument_index - 3 << " definition."; } const auto argument_type = _.FindDef(argument->type_id()); if (!argument_type) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Missing argument " << argument_index - 3 << " type definition."; } const auto parameter_type_id = function_type->GetOperandAs(param_index); const auto parameter_type = _.FindDef(parameter_type_id); if (!parameter_type || argument_type->id() != parameter_type->id()) { if (!_.options()->before_hlsl_legalization || !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpFunctionCall Argument '" << _.getIdName(argument_id) << "'s type does not match Function '" << _.getIdName(parameter_type_id) << "'s parameter type."; } } if (_.addressing_model() == SpvAddressingModelLogical) { if (parameter_type->opcode() == SpvOpTypePointer && !_.options()->relax_logical_pointer) { SpvStorageClass sc = parameter_type->GetOperandAs(1u); // Validate which storage classes can be pointer operands. switch (sc) { case SpvStorageClassUniformConstant: case SpvStorageClassFunction: case SpvStorageClassPrivate: case SpvStorageClassWorkgroup: case SpvStorageClassAtomicCounter: // These are always allowed. break; case SpvStorageClassStorageBuffer: if (!_.features().variable_pointers_storage_buffer) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "StorageBuffer pointer operand " << _.getIdName(argument_id) << " requires a variable pointers capability"; } break; default: return _.diag(SPV_ERROR_INVALID_ID, inst) << "Invalid storage class for pointer operand " << _.getIdName(argument_id); } // Validate memory object declaration requirements. if (argument->opcode() != SpvOpVariable && argument->opcode() != SpvOpFunctionParameter) { const bool ssbo_vptr = _.features().variable_pointers_storage_buffer && sc == SpvStorageClassStorageBuffer; const bool wg_vptr = _.features().variable_pointers && sc == SpvStorageClassWorkgroup; const bool uc_ptr = sc == SpvStorageClassUniformConstant; if (!ssbo_vptr && !wg_vptr && !uc_ptr) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Pointer operand " << _.getIdName(argument_id) << " must be a memory object declaration"; } } } } } return SPV_SUCCESS; } } // namespace spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { switch (inst->opcode()) { case SpvOpFunction: if (auto error = ValidateFunction(_, inst)) return error; break; case SpvOpFunctionParameter: if (auto error = ValidateFunctionParameter(_, inst)) return error; break; case SpvOpFunctionCall: if (auto error = ValidateFunctionCall(_, inst)) return error; break; default: break; } return SPV_SUCCESS; } } // namespace val } // namespace spvtools