aboutsummaryrefslogtreecommitdiff
path: root/source/val/validate_extensions.cpp
diff options
context:
space:
mode:
authorRyan Harrison <zoddicus@users.noreply.github.com>2018-11-28 10:49:05 -0500
committerGitHub <noreply@github.com>2018-11-28 10:49:05 -0500
commit3ee605d7ccb960345a454bad57e54238c66bfb05 (patch)
tree7557aaf2aa08282255a66d4e5cbfc290967d627c /source/val/validate_extensions.cpp
parent703305b1a5a2a06ea75510b954dea1dd2489fa46 (diff)
downloadSPIRV-Tools-3ee605d7ccb960345a454bad57e54238c66bfb05.tar.gz
Ensure that only whitelisted extensions are used in WebGPU (#2127)
Fixes #2058
Diffstat (limited to 'source/val/validate_extensions.cpp')
-rw-r--r--source/val/validate_extensions.cpp18
1 files changed, 18 insertions, 0 deletions
diff --git a/source/val/validate_extensions.cpp b/source/val/validate_extensions.cpp
index fe38f1f2..f264c8e7 100644
--- a/source/val/validate_extensions.cpp
+++ b/source/val/validate_extensions.cpp
@@ -21,6 +21,8 @@
#include <vector>
#include "source/diagnostic.h"
+#include "source/enum_string_mapping.h"
+#include "source/extensions.h"
#include "source/latest_version_glsl_std_450_header.h"
#include "source/latest_version_opencl_std_header.h"
#include "source/opcode.h"
@@ -42,6 +44,21 @@ uint32_t GetSizeTBitWidth(const ValidationState_t& _) {
} // anonymous namespace
+spv_result_t ValidateExtension(ValidationState_t& _, const Instruction* inst) {
+ if (spvIsWebGPUEnv(_.context()->target_env)) {
+ std::string extension = GetExtensionString(&(inst->c_inst()));
+
+ if (extension != ExtensionToString(kSPV_KHR_vulkan_memory_model)) {
+ return _.diag(SPV_ERROR_INVALID_DATA, inst)
+ << "For WebGPU, the only valid parameter to OpExtension is "
+ << "\"" << ExtensionToString(kSPV_KHR_vulkan_memory_model)
+ << "\".";
+ }
+ }
+
+ return SPV_SUCCESS;
+}
+
spv_result_t ValidateExtInstImport(ValidationState_t& _,
const Instruction* inst) {
if (spvIsWebGPUEnv(_.context()->target_env)) {
@@ -2001,6 +2018,7 @@ spv_result_t ValidateExtInst(ValidationState_t& _, const Instruction* inst) {
spv_result_t ExtensionPass(ValidationState_t& _, const Instruction* inst) {
const SpvOp opcode = inst->opcode();
+ if (opcode == SpvOpExtension) return ValidateExtension(_, inst);
if (opcode == SpvOpExtInstImport) return ValidateExtInstImport(_, inst);
if (opcode == SpvOpExtInst) return ValidateExtInst(_, inst);