aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Neto <dneto@google.com>2017-08-01 11:55:23 -0400
committerDavid Neto <dneto@google.com>2017-08-01 11:56:10 -0400
commit9d77a0cc68f9d90d327fda406a46060d4c1f86a1 (patch)
tree3636590b88594fbbdbc262f52b8d580f7981c29d
parent4d55e9f66f91a62d5493c866cfeec77891b4cf55 (diff)
parent7954740d542786acd071310a8978ffb3e2042b7c (diff)
downloadspirv-tools-ndk-r16-release.tar.gz
Merge remote-tracking branch 'aosp/upstream-master' into update-shadercndk-r16-beta1ndk-r16-release
Includes: 7954740 Opt: Delete names and decorations of dead instructions 9f6efc7 Opt: HasOnlySupportedRefs should consider OpCopyObject 4a539d7 Revert "Revert "Opt: LocalBlockElim: Add HasOnlySupportedRefs"" 1182415 Add extension whitelists to size-reduction passes. df96e24 Revert "Opt: LocalBlockElim: Add HasOnlySupportedRefs" 2d0f7fb Opt: LocalBlockElim: Add HasOnlySupportedRefs adb237f Fix handling of CopyObject in GetPtr and its call sites e9e4393 Fix Visual Studio size_t cast compiler warning fe24e03 LocalMultiStore: Always put varId for backedge on loop phi function. e2544dd DeadBranchElim: Improve algorithm to only remove blocks with no predecessors 06d4fd5 Minor code review feedback on AggressiveDCE 9de4e69 Add AggressiveDCEPass cc8bad3 Add LocalMultiStoreElim pass 52e247f DeadBranchElim: Add DeadBranchElimPass 35a0695 Include memory and semantics IDs when iterating over inbound IDs abc6f5a MARK-V decoder supports extended instructions 826d968 Update CHANGES to say we use GNUIntallDirs fd70a1d Define variable to skip installation 78338d5 Convert pattern stack from deque to vector, and share it e842c17 Added fixed width encoding to bit_stream 73e8dac Added compression tool tools/spirv-markv. Work in progress. 8d3882a Added log(n) move-to-front implementation 40a2829 Added Huffman codec to utils 65ea885 Travis CI: stop requiring sudo and use make instead of ninja d431b69 Don't do hash lookup twice in FindDef c14966b Move spv_instruction_t's into vector 1cd47d7 Reserve expected length of instructions vector fcd991f Move some temp vectors into parser state ad1d035 BlockMerge: Add BlockMergePass 0b0454c Update CHANGES 5fbbadc Add support for SPV AMD extensions 6136bf9 mem2reg: Add InsertExtractElimPass 760789f Transform multiple entry points 0c5722f mem2reg: Add LocalSingleStoreElimPass Test: ndk/checkbuild.py on Linux; unit tests on Windows Change-Id: Iaf0022decf13c2b60146ecd145b818eb0e021867
-rw-r--r--.travis.yml21
-rw-r--r--CHANGES24
-rw-r--r--CMakeLists.txt23
-rw-r--r--include/spirv-tools/libspirv.h3
-rw-r--r--include/spirv-tools/markv.h91
-rw-r--r--include/spirv-tools/optimizer.hpp133
-rw-r--r--source/CMakeLists.txt14
-rw-r--r--source/assembly_grammar.cpp4
-rw-r--r--source/assembly_grammar.h16
-rw-r--r--source/binary.cpp67
-rw-r--r--source/comp/CMakeLists.txt34
-rw-r--r--source/comp/markv_codec.cpp1556
-rw-r--r--source/ext_inst.cpp27
-rw-r--r--source/extinst.spv-amd-shader-ballot.grammar.json41
-rw-r--r--source/extinst.spv-amd-shader-explicit-vertex-parameter.grammar.json14
-rw-r--r--source/extinst.spv-amd-shader-trinary-minmax.grammar.json95
-rw-r--r--source/operand.cpp63
-rw-r--r--source/operand.h31
-rw-r--r--source/opt/CMakeLists.txt23
-rw-r--r--source/opt/aggressive_dead_code_elim_pass.cpp546
-rw-r--r--source/opt/aggressive_dead_code_elim_pass.h147
-rw-r--r--source/opt/basic_block.h17
-rw-r--r--source/opt/block_merge_pass.cpp189
-rw-r--r--source/opt/block_merge_pass.h82
-rw-r--r--source/opt/dead_branch_elim_pass.cpp418
-rw-r--r--source/opt/dead_branch_elim_pass.h159
-rw-r--r--source/opt/def_use_manager.cpp34
-rw-r--r--source/opt/def_use_manager.h6
-rw-r--r--source/opt/inline_pass.cpp2
-rw-r--r--source/opt/insert_extract_elim.cpp169
-rw-r--r--source/opt/insert_extract_elim.h87
-rw-r--r--source/opt/instruction.h24
-rw-r--r--source/opt/local_access_chain_convert_pass.cpp200
-rw-r--r--source/opt/local_access_chain_convert_pass.h29
-rw-r--r--source/opt/local_single_block_elim_pass.cpp195
-rw-r--r--source/opt/local_single_block_elim_pass.h37
-rw-r--r--source/opt/local_single_store_elim_pass.cpp585
-rw-r--r--source/opt/local_single_store_elim_pass.h248
-rw-r--r--source/opt/local_ssa_elim_pass.cpp825
-rw-r--r--source/opt/local_ssa_elim_pass.h268
-rw-r--r--source/opt/module.cpp9
-rw-r--r--source/opt/module.h16
-rw-r--r--source/opt/optimizer.cpp30
-rw-r--r--source/opt/passes.h6
-rw-r--r--source/text.cpp19
-rw-r--r--source/util/bit_stream.cpp72
-rw-r--r--source/util/bit_stream.h68
-rw-r--r--source/util/huffman_codec.h299
-rw-r--r--source/util/move_to_front.h649
-rw-r--r--source/val/validation_state.cpp20
-rw-r--r--source/val/validation_state.h2
-rw-r--r--source/validate.cpp4
-rw-r--r--test/CMakeLists.txt13
-rw-r--r--test/bit_stream.cpp148
-rw-r--r--test/comp/CMakeLists.txt23
-rw-r--r--test/comp/markv_codec_test.cpp433
-rw-r--r--test/huffman_codec.cpp220
-rw-r--r--test/move_to_front_test.cpp785
-rw-r--r--test/operand_pattern_test.cpp90
-rw-r--r--test/opt/CMakeLists.txt30
-rw-r--r--test/opt/aggressive_dead_code_elim_test.cpp689
-rw-r--r--test/opt/block_merge_test.cpp337
-rw-r--r--test/opt/dead_branch_elim_test.cpp905
-rw-r--r--test/opt/insert_extract_elim_test.cpp334
-rw-r--r--test/opt/instruction_test.cpp73
-rw-r--r--test/opt/local_single_block_elim.cpp154
-rw-r--r--test/opt/local_single_store_elim_test.cpp655
-rw-r--r--test/opt/local_ssa_elim_test.cpp1239
-rw-r--r--test/text_to_binary.extension_test.cpp106
-rw-r--r--test/val/val_extensions_test.cpp6
-rw-r--r--tools/CMakeLists.txt17
-rw-r--r--tools/comp/markv.cpp247
-rw-r--r--tools/emacs/CMakeLists.txt4
-rw-r--r--tools/lesspipe/CMakeLists.txt4
-rw-r--r--tools/opt/opt.cpp12
-rwxr-xr-xutils/generate_grammar_tables.py9
76 files changed, 13964 insertions, 310 deletions
diff --git a/.travis.yml b/.travis.yml
index 848ee47a..bf677c04 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -8,7 +8,7 @@ os:
# Use Ubuntu 14.04 LTS (Trusty) as the Linux testing environment.
dist: trusty
-sudo: required
+sudo: false
# Use the default Xcode environment for Xcode.
@@ -35,18 +35,14 @@ matrix:
cache:
apt: true
+git:
+ depth: 1
+
branches:
only:
- master
-addons:
- apt:
- packages:
- - ninja-build
-
before_install:
- # Install cmake & ninja on macOS.
- - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then brew install ninja; fi
- if [[ "$BUILD_NDK" == "ON" ]]; then
git clone --depth=1 https://github.com/urho3d/android-ndk.git $HOME/android-ndk;
export ANDROID_NDK=$HOME/android-ndk;
@@ -65,12 +61,13 @@ script:
-DANDROID_NATIVE_API_LEVEL=android-9
-DCMAKE_BUILD_TYPE=Release
-DANDROID_ABI="armeabi-v7a with NEON"
- -DSPIRV_SKIP_TESTS=ON
- -GNinja ..;
+ -DSPIRV_SKIP_TESTS=ON ..;
else
- cmake -GNinja -DCMAKE_BUILD_TYPE=${BUILD_TYPE} ..;
+ cmake -DCMAKE_BUILD_TYPE=${BUILD_TYPE} ..;
fi
- - ninja
+ # Due to limitations of the Travis platform, we cannot start too many concurrent jobs.
+ # Otherwise GCC will panic with an internal error, possibly because of memory issues.
+ - make -j4
- if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then
export NPROC=`nproc`;
else
diff --git a/CHANGES b/CHANGES
index 493b9758..417d62e7 100644
--- a/CHANGES
+++ b/CHANGES
@@ -3,15 +3,31 @@ Revision history for SPIRV-Tools
v2016.7-dev 2017-01-06
- Add SPIR-V 1.2
- OpenCL 2.2 support is now based on SPIR-V 1.2
- - Optimizer: Add inlining of all function calls in entry points.
- - Optimizer: Add flattening of decoration groups. Fixes #602
- - Optimizer: Add Id compaction (minimize Id bound). Fixes #624
+ - Support AMD extensions in assembler, disassembler:
+ SPV_AMD_gcn_shader
+ SPV_AMD_shader_ballot
+ SPV_AMD_shader_explicit_vertex_parameter
+ SPV_AMD_shader_trinary_minmax
+ SPV_AMD_gpu_shader_half_float
+ SPV_AMD_texture_gather_bias_lod
+ SPV_AMD_gpu_shader_int16
+ - Optimizer: Add support for:
+ - Inline all function calls in entry points.
+ - Flatten of decoration groups. Fixes #602
+ - Id compaction (minimize Id bound). Fixes #624
+ - Eliminate redundant composite insert followed by extract
+ - Simplify access chains to local variables
+ - Eliminate local variables with a single store, if possible
+ - Eliminate loads and stores in same block to local variables
- Assembler: Add option to preserve numeric ids. Fixes #625
- Add build target spirv-tools-vimsyntax to generate spvasm.vim, a SPIR-V
assembly syntax file for Vim.
- Version string: Allow overriding of wall clock timestamp with contents
of environment variable SOURCE_DATE_EPOCH.
- Validator implements relaxed rules for SPV_KHR_16bit_storage.
+ - CMake installation rules use GNUInstallDirs. For example, libraries
+ will be installed into a lib64 directory if that's the norm for the
+ current system.
- Fixes:
#500: Parameterize validator limit checks
#508: Support compilation under CYGWIN
@@ -25,6 +41,8 @@ v2016.7-dev 2017-01-06
binary vector when all passes succeeded without changes.
#629: The inline-entry-points-all optimization could generate invalidly
structured code when the inlined function had early returns.
+ #697: Optimizer's Instruction::ForEachInId method was skipping semantics-id
+ and scope-id.
v2016.6 2016-12-13
- Published the C++ interface for assembling, disassembling, validation, and
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7da83753..56896f91 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,6 +25,8 @@ project(spirv-tools)
enable_testing()
set(SPIRV_TOOLS "SPIRV-Tools")
+include(GNUInstallDirs)
+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
if("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux")
@@ -48,6 +50,11 @@ if ("${CMAKE_BUILD_TYPE}" STREQUAL "")
set(CMAKE_BUILD_TYPE "Debug")
endif()
+option(SKIP_SPIRV_TOOLS_INSTALL "Skip installation" ${SKIP_SPIRV_TOOLS_INSTALL})
+if(NOT ${SKIP_SPIRV_TOOLS_INSTALL})
+ set(ENABLE_SPIRV_TOOLS_INSTALL ON)
+endif()
+
option(SPIRV_WERROR "Enable error on warning" ON)
if(("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang"))
set(COMPILER_IS_LIKE_GNU TRUE)
@@ -165,13 +172,15 @@ add_subdirectory(tools)
add_subdirectory(test)
add_subdirectory(examples)
-install(
- FILES
- ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/libspirv.h
- ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/libspirv.hpp
- ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/optimizer.hpp
- DESTINATION
- include/spirv-tools/)
+if(ENABLE_SPIRV_TOOLS_INSTALL)
+ install(
+ FILES
+ ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/libspirv.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/libspirv.hpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/include/spirv-tools/optimizer.hpp
+ DESTINATION
+ ${CMAKE_INSTALL_INCLUDEDIR}/spirv-tools/)
+endif(ENABLE_SPIRV_TOOLS_INSTALL)
add_test(NAME spirv-tools-copyrights
COMMAND ${PYTHON_EXECUTABLE} utils/check_copyright.py
diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h
index 73e72cfe..b7bcc0f5 100644
--- a/include/spirv-tools/libspirv.h
+++ b/include/spirv-tools/libspirv.h
@@ -227,7 +227,10 @@ typedef enum spv_ext_inst_type_t {
SPV_EXT_INST_TYPE_NONE = 0,
SPV_EXT_INST_TYPE_GLSL_STD_450,
SPV_EXT_INST_TYPE_OPENCL_STD,
+ SPV_EXT_INST_TYPE_SPV_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER,
+ SPV_EXT_INST_TYPE_SPV_AMD_SHADER_TRINARY_MINMAX,
SPV_EXT_INST_TYPE_SPV_AMD_GCN_SHADER,
+ SPV_EXT_INST_TYPE_SPV_AMD_SHADER_BALLOT,
SPV_FORCE_32_BIT_ENUM(spv_ext_inst_type_t)
} spv_ext_inst_type_t;
diff --git a/include/spirv-tools/markv.h b/include/spirv-tools/markv.h
new file mode 100644
index 00000000..9941d435
--- /dev/null
+++ b/include/spirv-tools/markv.h
@@ -0,0 +1,91 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// MARK-V is a compression format for SPIR-V binaries. It strips away
+// non-essential information (such as result ids which can be regenerated) and
+// uses various bit reduction techniques to reduce the size of the binary.
+//
+// WIP: The MARK-V codec is in early stages of development. At the moment it can
+// only encode and decode some SPIR-V files, and only if exactly the same build of
+// the software is used (it doesn't write or handle version numbers yet).
+
+#ifndef SPIRV_TOOLS_MARKV_H_
+#define SPIRV_TOOLS_MARKV_H_
+
+#include "libspirv.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef struct spv_markv_binary_t {
+ uint8_t* data;
+ size_t length;
+} spv_markv_binary_t;
+
+typedef spv_markv_binary_t* spv_markv_binary;
+typedef const spv_markv_binary_t* const_spv_markv_binary;
+
+typedef struct spv_markv_encoder_options_t spv_markv_encoder_options_t;
+typedef spv_markv_encoder_options_t* spv_markv_encoder_options;
+typedef const spv_markv_encoder_options_t* spv_const_markv_encoder_options;
+
+typedef struct spv_markv_decoder_options_t spv_markv_decoder_options_t;
+typedef spv_markv_decoder_options_t* spv_markv_decoder_options;
+typedef const spv_markv_decoder_options_t* spv_const_markv_decoder_options;
+
+// Creates spv_markv_encoder_options with default options. Returns a valid
+// options object. The object remains valid until it is passed into
+// spvMarkvEncoderOptionsDestroy.
+spv_markv_encoder_options spvMarkvEncoderOptionsCreate();
+
+// Destroys the given spv_markv_encoder_options object.
+void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options);
+
+// Creates spv_markv_decoder_options with default options. Returns a valid
+// options object. The object remains valid until it is passed into
+// spvMarkvDecoderOptionsDestroy.
+spv_markv_decoder_options spvMarkvDecoderOptionsCreate();
+
+// Destroys the given spv_markv_decoder_options object.
+void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options);
+
+// Encodes the given SPIR-V binary to MARK-V binary.
+// If |comments| is not nullptr, it would contain a textual description of
+// how encoding was done (with snippets of disassembly and bit sequences).
+spv_result_t spvSpirvToMarkv(spv_const_context context,
+ const uint32_t* spirv_words,
+ size_t spirv_num_words,
+ spv_const_markv_encoder_options options,
+ spv_markv_binary* markv_binary,
+ spv_text* comments, spv_diagnostic* diagnostic);
+
+// Decodes a SPIR-V binary from the given MARK-V binary.
+// If |comments| is not nullptr, it would contain a textual description of
+// how decoding was done (with snippets of disassembly and bit sequences).
+spv_result_t spvMarkvToSpirv(spv_const_context context,
+ const uint8_t* markv_data,
+ size_t markv_size_bytes,
+ spv_const_markv_decoder_options options,
+ spv_binary* spirv_binary,
+ spv_text* comments, spv_diagnostic* diagnostic);
+
+// Destroys MARK-V binary created by spvSpirvToMarkv().
+void spvMarkvBinaryDestroy(spv_markv_binary binary);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // SPIRV_TOOLS_MARKV_H_
diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp
index 099a9445..b52de8fe 100644
--- a/include/spirv-tools/optimizer.hpp
+++ b/include/spirv-tools/optimizer.hpp
@@ -185,6 +185,22 @@ Optimizer::PassToken CreateUnifyConstantPass();
// OpSpecConstantOp.
Optimizer::PassToken CreateEliminateDeadConstantPass();
+// Creates a block merge pass.
+// This pass searches for blocks with a single Branch to a block with no
+// other predecessors and merges the blocks into a single block. Continue
+// blocks and Merge blocks are not candidates for the second block.
+//
+// The pass is most useful after Dead Branch Elimination, which can leave
+// such sequences of blocks. Merging them makes subsequent passes more
+// effective, such as single block local store-load elimination.
+//
+// While this pass reduces the number of occurrences of this sequence, at
+// this time it does not guarantee all such sequences are eliminated.
+//
+// Presence of phi instructions can inhibit this optimization. Handling
+// these is left for future improvements.
+Optimizer::PassToken CreateBlockMergePass();
+
// Creates an inline pass.
// An inline pass exhaustively inlines all function calls in all functions
// designated as an entry point. The intent is to enable, albeit through
@@ -211,9 +227,40 @@ Optimizer::PassToken CreateInlinePass();
//
// This pass is most effective if preceded by Inlining and
// LocalAccessChainConvert. This pass will reduce the work needed to be done
-// by LocalSingleStoreElim and LocalSSARewrite.
+// by LocalSingleStoreElim and LocalMultiStoreElim.
Optimizer::PassToken CreateLocalSingleBlockLoadStoreElimPass();
+// Create dead branch elimination pass.
+// For each entry point function, this pass will look for SelectionMerge
+// BranchConditionals with constant condition and convert to a Branch to
+// the indicated label. It will delete resulting dead blocks.
+//
+// This pass only works on shaders (guaranteed to have structured control
+// flow). Note that some such branches and blocks may be left to avoid
+// creating invalid control flow. Improving this is left to future work.
+//
+// This pass is most effective when preceded by passes which eliminate
+// local loads and stores, effectively propagating constant values where
+// possible.
+Optimizer::PassToken CreateDeadBranchElimPass();
+
+// Creates an SSA local variable load/store elimination pass.
+// For every entry point function, eliminate all loads and stores of function
+// scope variables only referenced with non-access-chain loads and stores.
+// Eliminate the variables as well.
+//
+// The presence of access chain references and function calls can inhibit
+// the above optimization.
+//
+// Only shader modules with logical addressing are currently processed.
+// Currently modules with any extensions enabled are not processed. This
+// is left for future work.
+//
+// This pass is most effective if preceded by Inlining and
+// LocalAccessChainConvert. LocalSingleStoreElim and LocalSingleBlockElim
+// will reduce the work that this pass has to do.
+Optimizer::PassToken CreateLocalMultiStoreElimPass();
+
// Creates a local access chain conversion pass.
// A local access chain conversion pass identifies all function scope
// variables which are accessed only with loads, stores and access chains
@@ -231,6 +278,90 @@ Optimizer::PassToken CreateLocalSingleBlockLoadStoreElimPass();
// possible.
Optimizer::PassToken CreateLocalAccessChainConvertPass();
+// Create aggressive dead code elimination pass
+// This pass eliminates unused code from functions. In addition,
+// it detects and eliminates code which may have spurious uses but which do
+// not contribute to the output of the function. The most common cause of
+// such code sequences is summations in loops whose result is no longer used
+// due to dead code elimination. This optimization has additional compile
+// time cost over standard dead code elimination.
+//
+// This pass only processes entry point functions. It also only processes
+// shaders with logical addressing. It currently will not process functions
+// with function calls. It currently only supports the GLSL.std.450 extended
+// instruction set. It currently does not support any extensions.
+//
+// This pass will be made more effective by first running passes that remove
+// dead control flow and inlines function calls.
+//
+// This pass can be especially useful after running Local Access Chain
+// Conversion, which tends to cause cycles of dead code to be left after
+// Store/Load elimination passes are completed. These cycles cannot be
+// eliminated with standard dead code elimination.
+Optimizer::PassToken CreateAggressiveDCEPass();
+
+// Creates a local single store elimination pass.
+// For each entry point function, this pass eliminates loads and stores for
+// function scope variable that are stored to only once, where possible. Only
+// whole variable loads and stores are eliminated; access-chain references are
+// not optimized. Replace all loads of such variables with the value that is
+// stored and eliminate any resulting dead code.
+//
+// Currently, the presence of access chains and function calls can inhibit this
+// pass, however the Inlining and LocalAccessChainConvert passes can make it
+// more effective. In additional, many non-load/store memory operations are
+// not supported and will prohibit optimization of a function. Support of
+// these operations are future work.
+//
+// This pass will reduce the work needed to be done by LocalSingleBlockElim
+// and LocalMultiStoreElim and can improve the effectiveness of other passes
+// such as DeadBranchElimination which depend on values for their analysis.
+Optimizer::PassToken CreateLocalSingleStoreElimPass();
+
+// Creates an insert/extract elimination pass.
+// This pass processes each entry point function in the module, searching for
+// extracts on a sequence of inserts. It further searches the sequence for an
+// insert with indices identical to the extract. If such an insert can be
+// found before hitting a conflicting insert, the extract's result id is
+// replaced with the id of the values from the insert.
+//
+// Besides removing extracts this pass enables subsequent dead code elimination
+// passes to delete the inserts. This pass performs best after access chains are
+// converted to inserts and extracts and local loads and stores are eliminated.
+Optimizer::PassToken CreateInsertExtractElimPass();
+
+// Create dead branch elimination pass.
+// For each entry point function, this pass will look for BranchConditionals
+// with constant condition and convert to a branch. The BranchConditional must
+// be preceeded by OpSelectionMerge. For all phi functions in merge block,
+// replace all uses with the id corresponding to the living predecessor.
+//
+// This pass is most effective when preceded by passes which eliminate
+// local loads and stores, effectively propagating constant values where
+// possible.
+Optimizer::PassToken CreateDeadBranchElimPass();
+
+// Create aggressive dead code elimination pass
+// This pass eliminates unused code from functions. In addition,
+// it detects and eliminates code which may have spurious uses but which do
+// not contribute to the output of the function. The most common cause of
+// such code sequences is summations in loops whose result is no longer used
+// due to dead code elimination. This optimization has additional compile
+// time cost over standard dead code elimination.
+//
+// This pass only processes entry point functions. It also only processes
+// shaders with logical addressing. It currently will not process functions
+// with function calls.
+//
+// This pass will be made more effective by first running passes that remove
+// dead control flow and inlines function calls.
+//
+// This pass can be especially useful after running Local Access Chain
+// Conversion, which tends to cause cycles of dead code to be left after
+// Store/Load elimination passes are completed. These cycles cannot be
+// eliminated with standard dead code elimination.
+Optimizer::PassToken CreateAggressiveDCEPass();
+
// Creates a compact ids pass.
// The pass remaps result ids to a compact and gapless range starting from %1.
Optimizer::PassToken CreateCompactIdsPass();
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index 8f8ca45e..6ea6ad26 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -115,7 +115,10 @@ spvtools_core_tables("1.2")
spvtools_enum_string_mapping("1.2")
spvtools_opencl_tables("1.0")
spvtools_glsl_tables("1.0")
+spvtools_vendor_tables("spv-amd-shader-explicit-vertex-parameter")
+spvtools_vendor_tables("spv-amd-shader-trinary-minmax")
spvtools_vendor_tables("spv-amd-gcn-shader")
+spvtools_vendor_tables("spv-amd-shader-ballot")
spvtools_vimsyntax("1.2" "1.0")
add_custom_target(spirv-tools-vimsyntax DEPENDS ${VIMSYNTAX_FILE})
@@ -185,6 +188,7 @@ add_custom_target(spirv-tools-build-version
DEPENDS ${SPIRV_TOOLS_BUILD_VERSION_INC})
set_property(TARGET spirv-tools-build-version PROPERTY FOLDER "SPIRV-Tools build")
+add_subdirectory(comp)
add_subdirectory(opt)
set(SPIRV_SOURCES
@@ -285,7 +289,9 @@ target_include_directories(${SPIRV_TOOLS}
)
set_property(TARGET ${SPIRV_TOOLS} PROPERTY FOLDER "SPIRV-Tools libraries")
-install(TARGETS ${SPIRV_TOOLS}
- RUNTIME DESTINATION bin
- LIBRARY DESTINATION lib
- ARCHIVE DESTINATION lib)
+if(ENABLE_SPIRV_TOOLS_INSTALL)
+ install(TARGETS ${SPIRV_TOOLS}
+ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
+endif(ENABLE_SPIRV_TOOLS_INSTALL)
diff --git a/source/assembly_grammar.cpp b/source/assembly_grammar.cpp
index 294c07a6..997c4c20 100644
--- a/source/assembly_grammar.cpp
+++ b/source/assembly_grammar.cpp
@@ -234,9 +234,9 @@ spv_result_t AssemblyGrammar::lookupExtInst(spv_ext_inst_type_t type,
return spvExtInstTableValueLookup(extInstTable_, type, firstWord, extInst);
}
-void AssemblyGrammar::prependOperandTypesForMask(
+void AssemblyGrammar::pushOperandTypesForMask(
const spv_operand_type_t type, const uint32_t mask,
spv_operand_pattern_t* pattern) const {
- spvPrependOperandTypesForMask(operandTable_, type, mask, pattern);
+ spvPushOperandTypesForMask(operandTable_, type, mask, pattern);
}
} // namespace libspirv
diff --git a/source/assembly_grammar.h b/source/assembly_grammar.h
index ac211369..cd89a1b1 100644
--- a/source/assembly_grammar.h
+++ b/source/assembly_grammar.h
@@ -95,17 +95,19 @@ class AssemblyGrammar {
spv_result_t lookupExtInst(spv_ext_inst_type_t type, uint32_t firstWord,
spv_ext_inst_desc* extInst) const;
- // Inserts the operands expected after the given typed mask onto the front
+ // Inserts the operands expected after the given typed mask onto the end
// of the given pattern.
//
- // Each set bit in the mask represents zero or more operand types that should
- // be prepended onto the pattern. Operands for a less significant bit always
- // appear before operands for a more significant bit.
+ // Each set bit in the mask represents zero or more operand types that
+ // should be appended onto the pattern. Operands for a less significant
+ // bit must always match before operands for a more significant bit, so
+ // the operands for a less significant bit must appear closer to the end
+ // of the pattern stack.
//
// If a set bit is unknown, then we assume it has no operands.
- void prependOperandTypesForMask(const spv_operand_type_t type,
- const uint32_t mask,
- spv_operand_pattern_t* pattern) const;
+ void pushOperandTypesForMask(const spv_operand_type_t type,
+ const uint32_t mask,
+ spv_operand_pattern_t* pattern) const;
private:
const spv_target_env target_env_;
diff --git a/source/binary.cpp b/source/binary.cpp
index a803def7..2ade2546 100644
--- a/source/binary.cpp
+++ b/source/binary.cpp
@@ -178,7 +178,14 @@ class Parser {
diagnostic(diagnostic_arg),
word_index(0),
endian(),
- requires_endian_conversion(false) {}
+ requires_endian_conversion(false) {
+
+ // Temporary storage for parser state within a single instruction.
+ // Most instructions require fewer than 25 words or operands.
+ operands.reserve(25);
+ endian_converted_words.reserve(25);
+ expected_operands.reserve(25);
+ }
State() : State(0, 0, nullptr) {}
const uint32_t* words; // Words in the binary SPIR-V module.
size_t num_words; // Number of words in the module.
@@ -198,6 +205,11 @@ class Parser {
// Maps an ExtInstImport id to the extended instruction type.
std::unordered_map<uint32_t, spv_ext_inst_type_t>
import_id_to_ext_inst_type;
+
+ // Used by parseOperand
+ std::vector<spv_parsed_operand_t> operands;
+ std::vector<uint32_t> endian_converted_words;
+ spv_operand_pattern_t expected_operands;
} _;
};
@@ -262,24 +274,15 @@ spv_result_t Parser::parseInstruction() {
const uint32_t first_word = peek();
- // TODO(dneto): If it's too expensive to construct the following "words"
- // and "operands" vectors for each instruction, each instruction, then make
- // them class data members instead, and clear them here.
-
// If the module's endianness is different from the host native endianness,
// then converted_words contains the endian-translated words in the
// instruction.
- std::vector<uint32_t> endian_converted_words = {first_word};
- if (_.requires_endian_conversion) {
- // Most instructions have fewer than 25 words.
- endian_converted_words.reserve(25);
- }
+ _.endian_converted_words.clear();
+ _.endian_converted_words.push_back(first_word);
// After a successful parse of the instruction, the inst.operands member
// will point to this vector's storage.
- std::vector<spv_parsed_operand_t> operands;
- // Most instructions have fewer than 25 logical operands.
- operands.reserve(25);
+ _.operands.clear();
assert(_.word_index < _.num_words);
// Decompose and check the first word.
@@ -305,13 +308,13 @@ spv_result_t Parser::parseInstruction() {
// has its own logical operands (such as the LocalSize operand for
// ExecutionMode), or for extended instructions that may have their
// own operands depending on the selected extended instruction.
- spv_operand_pattern_t expected_operands(
- opcode_desc->operandTypes,
- opcode_desc->operandTypes + opcode_desc->numTypes);
+ _.expected_operands.clear();
+ for (auto i = 0; i < opcode_desc->numTypes; i++)
+ _.expected_operands.push_back(opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
while (_.word_index < inst_offset + inst_word_count) {
const uint16_t inst_word_index = uint16_t(_.word_index - inst_offset);
- if (expected_operands.empty()) {
+ if (_.expected_operands.empty()) {
return diagnostic() << "Invalid instruction Op" << opcode_desc->name
<< " starting at word " << inst_offset
<< ": expected no more operands after "
@@ -320,17 +323,17 @@ spv_result_t Parser::parseInstruction() {
<< inst_word_count << ".";
}
- spv_operand_type_t type = spvTakeFirstMatchableOperand(&expected_operands);
+ spv_operand_type_t type = spvTakeFirstMatchableOperand(&_.expected_operands);
if (auto error =
- parseOperand(inst_offset, &inst, type, &endian_converted_words,
- &operands, &expected_operands)) {
+ parseOperand(inst_offset, &inst, type, &_.endian_converted_words,
+ &_.operands, &_.expected_operands)) {
return error;
}
}
- if (!expected_operands.empty() &&
- !spvOperandIsOptional(expected_operands.front())) {
+ if (!_.expected_operands.empty() &&
+ !spvOperandIsOptional(_.expected_operands.back())) {
return diagnostic() << "End of input reached while decoding Op"
<< opcode_desc->name << " starting at word "
<< inst_offset << ": expected more operands after "
@@ -351,15 +354,15 @@ spv_result_t Parser::parseInstruction() {
// performed, then the vector only contains the initial opcode/word-count
// word.
assert(!_.requires_endian_conversion ||
- (inst_word_count == endian_converted_words.size()));
- assert(_.requires_endian_conversion || (endian_converted_words.size() == 1));
+ (inst_word_count == _.endian_converted_words.size()));
+ assert(_.requires_endian_conversion || (_.endian_converted_words.size() == 1));
recordNumberType(inst_offset, &inst);
if (_.requires_endian_conversion) {
// We must wait until here to set this pointer, because the vector might
// have been be resized while we accumulated its elements.
- inst.words = endian_converted_words.data();
+ inst.words = _.endian_converted_words.data();
} else {
// If no conversion is required, then just point to the underlying binary.
// This saves time and space.
@@ -369,8 +372,8 @@ spv_result_t Parser::parseInstruction() {
// We must wait until here to set this pointer, because the vector might
// have been be resized while we accumulated its elements.
- inst.operands = operands.data();
- inst.num_operands = uint16_t(operands.size());
+ inst.operands = _.operands.data();
+ inst.num_operands = uint16_t(_.operands.size());
// Issue the callback. The callee should know that all the storage in inst
// is transient, and will disappear immediately afterward.
@@ -468,7 +471,7 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
spv_ext_inst_desc ext_inst;
if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst))
return diagnostic() << "Invalid extended instruction number: " << word;
- spvPrependOperandTypes(ext_inst->operandTypes, expected_operands);
+ spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
} break;
case SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER: {
@@ -488,7 +491,7 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
assert(opcode_entry->hasType);
assert(opcode_entry->hasResult);
assert(opcode_entry->numTypes >= 2);
- spvPrependOperandTypes(opcode_entry->operandTypes + 2, expected_operands);
+ spvPushOperandTypes(opcode_entry->operandTypes + 2, expected_operands);
} break;
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
@@ -622,7 +625,7 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
<< " operand: " << word;
}
// Prepare to accept operands to this operand, if needed.
- spvPrependOperandTypes(entry->operandTypes, expected_operands);
+ spvPushOperandTypes(entry->operandTypes, expected_operands);
} break;
case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
@@ -657,7 +660,7 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
<< mask;
}
remaining_word ^= mask;
- spvPrependOperandTypes(entry->operandTypes, expected_operands);
+ spvPushOperandTypes(entry->operandTypes, expected_operands);
}
}
if (word == 0) {
@@ -665,7 +668,7 @@ spv_result_t Parser::parseOperand(size_t inst_offset,
spv_operand_desc entry;
if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
// Prepare for its operands, if any.
- spvPrependOperandTypes(entry->operandTypes, expected_operands);
+ spvPushOperandTypes(entry->operandTypes, expected_operands);
}
}
} break;
diff --git a/source/comp/CMakeLists.txt b/source/comp/CMakeLists.txt
new file mode 100644
index 00000000..11def56a
--- /dev/null
+++ b/source/comp/CMakeLists.txt
@@ -0,0 +1,34 @@
+# Copyright (c) 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_library(SPIRV-Tools-comp markv_codec.cpp)
+
+spvtools_default_compile_options(SPIRV-Tools-comp)
+target_include_directories(SPIRV-Tools-comp
+ PUBLIC ${spirv-tools_SOURCE_DIR}/include
+ PUBLIC ${SPIRV_HEADER_INCLUDE_DIR}
+ PRIVATE ${spirv-tools_BINARY_DIR}
+)
+
+target_link_libraries(SPIRV-Tools-comp
+ PUBLIC ${SPIRV_TOOLS})
+
+set_property(TARGET SPIRV-Tools-comp PROPERTY FOLDER "SPIRV-Tools libraries")
+
+if(ENABLE_SPIRV_TOOLS_INSTALL)
+ install(TARGETS SPIRV-Tools-comp
+ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
+endif(ENABLE_SPIRV_TOOLS_INSTALL)
diff --git a/source/comp/markv_codec.cpp b/source/comp/markv_codec.cpp
new file mode 100644
index 00000000..2288dc3d
--- /dev/null
+++ b/source/comp/markv_codec.cpp
@@ -0,0 +1,1556 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Contains
+// - SPIR-V to MARK-V encoder
+// - MARK-V to SPIR-V decoder
+//
+// MARK-V is a compression format for SPIR-V binaries. It strips away
+// non-essential information (such as result ids which can be regenerated) and
+// uses various bit reduction techniques to reduce the size of the binary.
+//
+// MarkvModel is a flatbuffers object containing a set of rules defining how
+// compression/decompression is done (coding schemes, dictionaries).
+
+#include <algorithm>
+#include <cassert>
+#include <cstring>
+#include <functional>
+#include <iostream>
+#include <list>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+
+#include "binary.h"
+#include "diagnostic.h"
+#include "enum_string_mapping.h"
+#include "extensions.h"
+#include "ext_inst.h"
+#include "instruction.h"
+#include "opcode.h"
+#include "operand.h"
+#include "spirv-tools/libspirv.h"
+#include "spirv-tools/markv.h"
+#include "spirv_endian.h"
+#include "spirv_validator_options.h"
+#include "util/bit_stream.h"
+#include "util/parse_number.h"
+#include "validate.h"
+#include "val/instruction.h"
+#include "val/validation_state.h"
+
+using libspirv::Instruction;
+using libspirv::ValidationState_t;
+using spvtools::ValidateInstructionAndUpdateValidationState;
+using spvutils::BitReaderWord64;
+using spvutils::BitWriterWord64;
+
+struct spv_markv_encoder_options_t {
+};
+
+struct spv_markv_decoder_options_t {
+};
+
+namespace {
+
+const uint32_t kSpirvMagicNumber = SpvMagicNumber;
+const uint32_t kMarkvMagicNumber = 0x07230303;
+
+enum {
+ kMarkvFirstOpcode = 65536,
+ kMarkvOpNextInstructionEncodesResultId = 65536,
+};
+
+const size_t kCommentNumWhitespaces = 2;
+
+// TODO(atgoo@github.com): This is a placeholder for an autogenerated flatbuffer
+// containing MARK-V model for a specific dataset.
+class MarkvModel {
+ public:
+ size_t opcode_chunk_length() const { return 7; }
+ size_t num_operands_chunk_length() const { return 3; }
+ size_t id_index_chunk_length() const { return 3; }
+
+ size_t u16_chunk_length() const { return 4; }
+ size_t s16_chunk_length() const { return 4; }
+ size_t s16_block_exponent() const { return 6; }
+
+ size_t u32_chunk_length() const { return 8; }
+ size_t s32_chunk_length() const { return 8; }
+ size_t s32_block_exponent() const { return 10; }
+
+ size_t u64_chunk_length() const { return 8; }
+ size_t s64_chunk_length() const { return 8; }
+ size_t s64_block_exponent() const { return 10; }
+};
+
+const MarkvModel* GetDefaultModel() {
+ static MarkvModel model;
+ return &model;
+}
+
+// Returns chunk length used for variable length encoding of spirv operand
+// words. Returns zero if operand type corresponds to potentially multiple
+// words or a word which is not expected to profit from variable width encoding.
+// Chunk length is selected based on the size of expected value.
+// Most of these values will later be encoded with probability-based coding,
+// but variable width integer coding is a good quick solution.
+// TODO(atgoo@github.com): Put this in MarkvModel flatbuffer.
+size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) {
+ switch (type) {
+ case SPV_OPERAND_TYPE_TYPE_ID:
+ return 4;
+ case SPV_OPERAND_TYPE_RESULT_ID:
+ case SPV_OPERAND_TYPE_ID:
+ case SPV_OPERAND_TYPE_SCOPE_ID:
+ case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
+ return 8;
+ case SPV_OPERAND_TYPE_LITERAL_INTEGER:
+ case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER:
+ return 6;
+ case SPV_OPERAND_TYPE_CAPABILITY:
+ return 6;
+ case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
+ case SPV_OPERAND_TYPE_EXECUTION_MODEL:
+ return 3;
+ case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
+ case SPV_OPERAND_TYPE_MEMORY_MODEL:
+ return 2;
+ case SPV_OPERAND_TYPE_EXECUTION_MODE:
+ return 6;
+ case SPV_OPERAND_TYPE_STORAGE_CLASS:
+ return 4;
+ case SPV_OPERAND_TYPE_DIMENSIONALITY:
+ case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
+ return 3;
+ case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
+ return 2;
+ case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
+ return 6;
+ case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
+ case SPV_OPERAND_TYPE_LINKAGE_TYPE:
+ case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
+ case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
+ return 2;
+ case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
+ return 3;
+ case SPV_OPERAND_TYPE_DECORATION:
+ case SPV_OPERAND_TYPE_BUILT_IN:
+ return 6;
+ case SPV_OPERAND_TYPE_GROUP_OPERATION:
+ case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
+ case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
+ return 2;
+ case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
+ case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
+ case SPV_OPERAND_TYPE_LOOP_CONTROL:
+ case SPV_OPERAND_TYPE_IMAGE:
+ case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
+ case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
+ case SPV_OPERAND_TYPE_SELECTION_CONTROL:
+ return 4;
+ case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER:
+ return 6;
+ default:
+ return 0;
+ }
+ return 0;
+}
+
+// Returns true if the opcode has a fixed number of operands. May return a
+// false negative.
+bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) {
+ switch (opcode) {
+ // TODO(atgoo@github.com) This is not a complete list.
+ case SpvOpNop:
+ case SpvOpName:
+ case SpvOpUndef:
+ case SpvOpSizeOf:
+ case SpvOpLine:
+ case SpvOpNoLine:
+ case SpvOpDecorationGroup:
+ case SpvOpExtension:
+ case SpvOpExtInstImport:
+ case SpvOpMemoryModel:
+ case SpvOpCapability:
+ case SpvOpTypeVoid:
+ case SpvOpTypeBool:
+ case SpvOpTypeInt:
+ case SpvOpTypeFloat:
+ case SpvOpTypeVector:
+ case SpvOpTypeMatrix:
+ case SpvOpTypeSampler:
+ case SpvOpTypeSampledImage:
+ case SpvOpTypeArray:
+ case SpvOpTypePointer:
+ case SpvOpConstantTrue:
+ case SpvOpConstantFalse:
+ case SpvOpLabel:
+ case SpvOpBranch:
+ case SpvOpFunction:
+ case SpvOpFunctionParameter:
+ case SpvOpFunctionEnd:
+ case SpvOpBitcast:
+ case SpvOpCopyObject:
+ case SpvOpTranspose:
+ case SpvOpSNegate:
+ case SpvOpFNegate:
+ case SpvOpIAdd:
+ case SpvOpFAdd:
+ case SpvOpISub:
+ case SpvOpFSub:
+ case SpvOpIMul:
+ case SpvOpFMul:
+ case SpvOpUDiv:
+ case SpvOpSDiv:
+ case SpvOpFDiv:
+ case SpvOpUMod:
+ case SpvOpSRem:
+ case SpvOpSMod:
+ case SpvOpFRem:
+ case SpvOpFMod:
+ case SpvOpVectorTimesScalar:
+ case SpvOpMatrixTimesScalar:
+ case SpvOpVectorTimesMatrix:
+ case SpvOpMatrixTimesVector:
+ case SpvOpMatrixTimesMatrix:
+ case SpvOpOuterProduct:
+ case SpvOpDot:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+size_t GetNumBitsToNextByte(size_t bit_pos) {
+ return (8 - (bit_pos % 8)) % 8;
+}
+
+bool ShouldByteBreak(size_t bit_pos) {
+ const size_t num_bits_to_next_byte = GetNumBitsToNextByte(bit_pos);
+ return num_bits_to_next_byte > 0; // && num_bits_to_next_byte <= 2;
+}
+
+// Defines and returns current MARK-V version.
+uint32_t GetMarkvVersion() {
+ const uint32_t kVersionMajor = 1;
+ const uint32_t kVersionMinor = 0;
+ return kVersionMinor | (kVersionMajor << 16);
+}
+
+class CommentLogger {
+ public:
+ void AppendText(const std::string& str) {
+ Append(str);
+ use_delimiter_ = false;
+ }
+
+ void AppendTextNewLine(const std::string& str) {
+ Append(str);
+ Append("\n");
+ use_delimiter_ = false;
+ }
+
+ void AppendBitSequence(const std::string& str) {
+ if (use_delimiter_)
+ Append("-");
+ Append(str);
+ use_delimiter_ = true;
+ }
+
+ void AppendWhitespaces(size_t num) {
+ Append(std::string(num, ' '));
+ use_delimiter_ = false;
+ }
+
+ void NewLine() {
+ Append("\n");
+ use_delimiter_ = false;
+ }
+
+ std::string GetText() const {
+ return ss_.str();
+ }
+
+ private:
+ void Append(const std::string& str) {
+ ss_ << str;
+ // std::cerr << str;
+ }
+
+ std::stringstream ss_;
+
+ // If true a delimiter will be appended before the next bit sequence.
+ // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0.
+ bool use_delimiter_ = false;
+};
+
+// Creates spv_text object containing text from |str|.
+// The returned value is owned by the caller and needs to be destroyed with
+// spvTextDestroy.
+spv_text CreateSpvText(const std::string& str) {
+ spv_text out = new spv_text_t();
+ assert(out);
+ char* cstr = new char[str.length() + 1];
+ assert(cstr);
+ std::strncpy(cstr, str.c_str(), str.length());
+ cstr[str.length()] = '\0';
+ out->str = cstr;
+ out->length = str.length();
+ return out;
+}
+
+// Base class for MARK-V encoder and decoder. Contains common functionality
+// such as:
+// - Validator connection and validation state.
+// - SPIR-V grammar and helper functions.
+class MarkvCodecBase {
+ public:
+ virtual ~MarkvCodecBase() {
+ spvValidatorOptionsDestroy(validator_options_);
+ }
+
+ MarkvCodecBase() = delete;
+
+ void SetModel(const MarkvModel* model) {
+ model_ = model;
+ }
+
+ protected:
+ struct MarkvHeader {
+ MarkvHeader() {
+ magic_number = kMarkvMagicNumber;
+ markv_version = GetMarkvVersion();
+ markv_model = 0;
+ markv_length_in_bits = 0;
+ spirv_version = 0;
+ spirv_generator = 0;
+ }
+
+ uint32_t magic_number;
+ uint32_t markv_version;
+ // Magic number to identify or verify MarkvModel used for encoding.
+ uint32_t markv_model;
+ uint32_t markv_length_in_bits;
+ uint32_t spirv_version;
+ uint32_t spirv_generator;
+ };
+
+ explicit MarkvCodecBase(spv_const_context context,
+ spv_validator_options validator_options)
+ : validator_options_(validator_options),
+ vstate_(context, validator_options_), grammar_(context),
+ model_(GetDefaultModel()) {}
+
+ // Validates a single instruction and updates validation state of the module.
+ spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
+ return ValidateInstructionAndUpdateValidationState(&vstate_, &inst);
+ }
+
+ // Returns the current instruction (the one last processed by the validator).
+ const Instruction& GetCurrentInstruction() const {
+ return vstate_.ordered_instructions().back();
+ }
+
+ spv_validator_options validator_options_;
+ ValidationState_t vstate_;
+ const libspirv::AssemblyGrammar grammar_;
+ MarkvHeader header_;
+ const MarkvModel* model_;
+
+ // Move-to-front list of all ids.
+ // TODO(atgoo@github.com) Consider a better move-to-front implementation.
+ std::list<uint32_t> move_to_front_ids_;
+};
+
+// SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and
+// EncodeInstruction which can be used as callback by spvBinaryParse.
+// Encoded binary is written to an internally maintained bitstream.
+// After the last instruction is encoded, the resulting MARK-V binary can be
+// acquired by calling GetMarkvBinary().
+// The encoder uses SPIR-V validator to keep internal state, therefore
+// SPIR-V binary needs to be able to pass validator checks.
+// CreateCommentsLogger() can be used to enable the encoder to write comments
+// on how encoding was done, which can later be accessed with GetComments().
+class MarkvEncoder : public MarkvCodecBase {
+ public:
+ MarkvEncoder(spv_const_context context,
+ spv_const_markv_encoder_options options)
+ : MarkvCodecBase(context, GetValidatorOptions(options)),
+ options_(options) {
+ (void) options_;
+ }
+
+ // Writes data from SPIR-V header to MARK-V header.
+ spv_result_t EncodeHeader(
+ spv_endianness_t /* endian */, uint32_t /* magic */,
+ uint32_t version, uint32_t generator, uint32_t id_bound,
+ uint32_t /* schema */) {
+ vstate_.setIdBound(id_bound);
+ header_.spirv_version = version;
+ header_.spirv_generator = generator;
+ return SPV_SUCCESS;
+ }
+
+ // Encodes SPIR-V instruction to MARK-V and writes to bit stream.
+ // Operation can fail if the instruction fails to pass the validator or if
+ // the encoder stumbles on something unexpected.
+ spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
+
+ // Concatenates MARK-V header and the bit stream with encoded instructions
+ // into a single buffer and returns it as spv_markv_binary. The returned
+ // value is owned by the caller and needs to be destroyed with
+ // spvMarkvBinaryDestroy().
+ spv_markv_binary GetMarkvBinary() {
+ header_.markv_length_in_bits =
+ static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits());
+ const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes();
+
+ spv_markv_binary markv_binary = new spv_markv_binary_t();
+ markv_binary->data = new uint8_t[num_bytes];
+ markv_binary->length = num_bytes;
+ assert(writer_.GetData());
+ std::memcpy(markv_binary->data, &header_, sizeof(header_));
+ std::memcpy(markv_binary->data + sizeof(header_),
+ writer_.GetData(), writer_.GetDataSizeBytes());
+ return markv_binary;
+ }
+
+ // Creates an internal logger which writes comments on the encoding process.
+ // Output can later be accessed with GetComments().
+ void CreateCommentsLogger() {
+ logger_.reset(new CommentLogger());
+ writer_.SetCallback([this](const std::string& str){
+ logger_->AppendBitSequence(str);
+ });
+ }
+
+ // Optionally adds disassembly to the comments.
+ // Disassembly should contain all instructions in the module separated by
+ // \n, and no header.
+ void SetDisassembly(std::string&& disassembly) {
+ disassembly_.reset(new std::stringstream(std::move(disassembly)));
+ }
+
+ // Extracts the next instruction line from the disassembly and logs it.
+ void LogDisassemblyInstruction() {
+ if (logger_ && disassembly_) {
+ std::string line;
+ std::getline(*disassembly_, line, '\n');
+ logger_->AppendTextNewLine(line);
+ }
+ }
+
+ // Extracts the text from the comment logger.
+ std::string GetComments() const {
+ if (!logger_)
+ return "";
+ return logger_->GetText();
+ }
+
+ private:
+ // Creates and returns validator options. Return value owned by the caller.
+ static spv_validator_options GetValidatorOptions(
+ spv_const_markv_encoder_options) {
+ return spvValidatorOptionsCreate();
+ }
+
+ // Writes a single word to bit stream. |type| determines if the word is
+ // encoded and how.
+ void EncodeOperandWord(spv_operand_type_t type, uint32_t word) {
+ const size_t chunk_length =
+ GetOperandVariableWidthChunkLength(type);
+ if (chunk_length) {
+ writer_.WriteVariableWidthU32(word, chunk_length);
+ } else {
+ writer_.WriteUnencoded(word);
+ }
+ }
+
+ // Returns id index and updates move-to-front.
+ // Index is uint16 as SPIR-V module is guaranteed to have no more than 65535
+ // instructions.
+ uint16_t GetIdIndex(uint32_t id) {
+ if (all_known_ids_.count(id)) {
+ uint16_t index = 0;
+ for (auto it = move_to_front_ids_.begin();
+ it != move_to_front_ids_.end(); ++it) {
+ if (*it == id) {
+ if (index != 0) {
+ move_to_front_ids_.erase(it);
+ move_to_front_ids_.push_front(id);
+ }
+ return index;
+ }
+ ++index;
+ }
+ assert(0 && "Id not found in move_to_front_ids_");
+ return 0;
+ } else {
+ all_known_ids_.insert(id);
+ move_to_front_ids_.push_front(id);
+ return static_cast<uint16_t>(move_to_front_ids_.size() - 1);
+ }
+ }
+
+ void AddByteBreakIfAgreed() {
+ if (!ShouldByteBreak(writer_.GetNumBits()))
+ return;
+
+ if (logger_) {
+ logger_->AppendWhitespaces(kCommentNumWhitespaces);
+ logger_->AppendText("ByteBreak:");
+ }
+
+ writer_.WriteBits(0, GetNumBitsToNextByte(writer_.GetNumBits()));
+ }
+
+ // Encodes a literal number operand and writes it to the bit stream.
+ void EncodeLiteralNumber(const Instruction& instruction,
+ const spv_parsed_operand_t& operand);
+
+ spv_const_markv_encoder_options options_;
+
+ // Bit stream where encoded instructions are written.
+ BitWriterWord64 writer_;
+
+ // If not nullptr, encoder will write comments.
+ std::unique_ptr<CommentLogger> logger_;
+
+ // If not nullptr, disassembled instruction lines will be written to comments.
+ // Format: \n separated instruction lines, no header.
+ std::unique_ptr<std::stringstream> disassembly_;
+
+ // All ids which were previously encountered in the module.
+ std::unordered_set<uint32_t> all_known_ids_;
+};
+
+// Decodes MARK-V buffers written by MarkvEncoder.
+class MarkvDecoder : public MarkvCodecBase {
+ public:
+ MarkvDecoder(spv_const_context context,
+ const uint8_t* markv_data,
+ size_t markv_size_bytes,
+ spv_const_markv_decoder_options options)
+ : MarkvCodecBase(context, GetValidatorOptions(options)),
+ options_(options), reader_(markv_data, markv_size_bytes) {
+ (void) options_;
+ vstate_.setIdBound(1);
+ parsed_operands_.reserve(25);
+ }
+
+ // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
+ // Can be called only once. Fails if data of wrong format or ends prematurely,
+ // or if validation fails.
+ spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
+
+ private:
+ // Describes the format of a typed literal number.
+ struct NumberType {
+ spv_number_kind_t type;
+ uint32_t bit_width;
+ };
+
+ // Creates and returns validator options. Return value owned by the caller.
+ static spv_validator_options GetValidatorOptions(
+ spv_const_markv_decoder_options) {
+ return spvValidatorOptionsCreate();
+ }
+
+ // Reads a single word from bit stream. |type| determines if the word needs
+ // to be decoded and how. Returns false if read fails.
+ bool DecodeOperandWord(spv_operand_type_t type, uint32_t* word) {
+ const size_t chunk_length = GetOperandVariableWidthChunkLength(type);
+ if (chunk_length) {
+ return reader_.ReadVariableWidthU32(word, chunk_length);
+ } else {
+ return reader_.ReadUnencoded(word);
+ }
+ }
+
+ // Fetches the id from the move-to-front list and moves it to front.
+ uint32_t GetIdAndMoveToFront(uint16_t index) {
+ if (index >= move_to_front_ids_.size()) {
+ // Issue new id.
+ const uint32_t id = vstate_.getIdBound();
+ move_to_front_ids_.push_front(id);
+ vstate_.setIdBound(id + 1);
+ return id;
+ } else {
+ if (index == 0)
+ return move_to_front_ids_.front();
+
+ // Iterate to index.
+ auto it = move_to_front_ids_.begin();
+ for (size_t i = 0; i < index; ++i)
+ ++it;
+ const uint32_t id = *it;
+ move_to_front_ids_.erase(it);
+ move_to_front_ids_.push_front(id);
+ return id;
+ }
+ }
+
+ // Decodes id index and fetches the id from move-to-front list.
+ bool DecodeId(uint32_t* id) {
+ uint16_t index = 0;
+ if (!reader_.ReadVariableWidthU16(&index, model_->id_index_chunk_length()))
+ return false;
+
+ *id = GetIdAndMoveToFront(index);
+ return true;
+ }
+
+ bool ReadToByteBreakIfAgreed() {
+ if (!ShouldByteBreak(reader_.GetNumReadBits()))
+ return true;
+
+ uint64_t bits = 0;
+ if (!reader_.ReadBits(&bits,
+ GetNumBitsToNextByte(reader_.GetNumReadBits())))
+ return false;
+
+ if (bits != 0)
+ return false;
+
+ return true;
+ }
+
+ // Reads a literal number as it is described in |operand| from the bit stream,
+ // decodes and writes it to spirv_.
+ spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
+
+ // Reads instruction from bit stream, decodes and validates it.
+ // Decoded instruction is valid until the next call of DecodeInstruction().
+ spv_result_t DecodeInstruction(spv_parsed_instruction_t* inst);
+
+ // Reads an operand from the stream, decodes and validates it.
+ spv_result_t DecodeOperand(size_t instruction_offset, size_t operand_offset,
+ spv_parsed_instruction_t* inst,
+ const spv_operand_type_t type,
+ spv_operand_pattern_t* expected_operands,
+ bool read_result_id);
+
+ // Records the numeric type for an operand according to the type information
+ // associated with the given non-zero type Id. This can fail if the type Id
+ // is not a type Id, or if the type Id does not reference a scalar numeric
+ // type. On success, return SPV_SUCCESS and populates the num_words,
+ // number_kind, and number_bit_width fields of parsed_operand.
+ spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
+ uint32_t type_id);
+
+ // Records the number type for the given instruction, if that
+ // instruction generates a type. For types that aren't scalar numbers,
+ // record something with number kind SPV_NUMBER_NONE.
+ void RecordNumberType(const spv_parsed_instruction_t& inst);
+
+ spv_const_markv_decoder_options options_;
+
+ // Temporary sink where decoded SPIR-V words are written. Once it contains the
+ // entire module, the container is moved and returned.
+ std::vector<uint32_t> spirv_;
+
+ // Bit stream containing encoded data.
+ BitReaderWord64 reader_;
+
+ // Temporary storage for operands of the currently parsed instruction.
+ // Valid until next DecodeInstruction call.
+ std::vector<spv_parsed_operand_t> parsed_operands_;
+
+ // Maps a result ID to its type ID. By convention:
+ // - a result ID that is a type definition maps to itself.
+ // - a result ID without a type maps to 0. (E.g. for OpLabel)
+ std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
+ // Maps a type ID to its number type description.
+ std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
+ // Maps an ExtInstImport id to the extended instruction type.
+ std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
+};
+
+void MarkvEncoder::EncodeLiteralNumber(const Instruction& instruction,
+ const spv_parsed_operand_t& operand) {
+ if (operand.number_bit_width == 32) {
+ const uint32_t word = instruction.word(operand.offset);
+ if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
+ writer_.WriteVariableWidthU32(word, model_->u32_chunk_length());
+ } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
+ int32_t val = 0;
+ std::memcpy(&val, &word, 4);
+ writer_.WriteVariableWidthS32(val, model_->s32_chunk_length(),
+ model_->s32_block_exponent());
+ } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
+ writer_.WriteUnencoded(word);
+ } else {
+ assert(0);
+ }
+ } else if (operand.number_bit_width == 16) {
+ const uint16_t word =
+ static_cast<uint16_t>(instruction.word(operand.offset));
+ if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
+ writer_.WriteVariableWidthU16(word, model_->u16_chunk_length());
+ } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
+ int16_t val = 0;
+ std::memcpy(&val, &word, 2);
+ writer_.WriteVariableWidthS16(val, model_->s16_chunk_length(),
+ model_->s16_block_exponent());
+ } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
+ // TODO(atgoo@github.com) Write only 16 bits.
+ writer_.WriteUnencoded(word);
+ } else {
+ assert(0);
+ }
+ } else {
+ assert(operand.number_bit_width == 64);
+ const uint64_t word =
+ uint64_t(instruction.word(operand.offset)) |
+ (uint64_t(instruction.word(operand.offset + 1)) << 32);
+ if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
+ writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
+ } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
+ int64_t val = 0;
+ std::memcpy(&val, &word, 8);
+ writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
+ model_->s64_block_exponent());
+ } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
+ writer_.WriteUnencoded(word);
+ } else {
+ assert(0);
+ }
+ }
+}
+
+spv_result_t MarkvEncoder::EncodeInstruction(
+ const spv_parsed_instruction_t& inst) {
+ const spv_result_t validation_result = UpdateValidationState(inst);
+ if (validation_result != SPV_SUCCESS)
+ return validation_result;
+
+ bool result_id_was_forward_declared = false;
+ if (all_known_ids_.count(inst.result_id)) {
+ // Result id of the instruction was forward declared.
+ // Write a service opcode to signal this to the decoder.
+ writer_.WriteVariableWidthU32(kMarkvOpNextInstructionEncodesResultId,
+ model_->opcode_chunk_length());
+ result_id_was_forward_declared = true;
+ }
+
+ const Instruction& instruction = GetCurrentInstruction();
+ const auto& operands = instruction.operands();
+
+ LogDisassemblyInstruction();
+
+ // Write opcode.
+ writer_.WriteVariableWidthU32(inst.opcode, model_->opcode_chunk_length());
+
+ if (!OpcodeHasFixedNumberOfOperands(SpvOp(inst.opcode))) {
+ // If the opcode has a variable number of operands, encode the number of
+ // operands with the instruction.
+
+ if (logger_)
+ logger_->AppendWhitespaces(kCommentNumWhitespaces);
+
+ writer_.WriteVariableWidthU16(inst.num_operands,
+ model_->num_operands_chunk_length());
+ }
+
+ // Write operands.
+ for (const auto& operand : operands) {
+ if (operand.type == SPV_OPERAND_TYPE_RESULT_ID &&
+ !result_id_was_forward_declared) {
+ // Register the id, but don't encode it.
+ GetIdIndex(instruction.word(operand.offset));
+ continue;
+ }
+
+ if (logger_)
+ logger_->AppendWhitespaces(kCommentNumWhitespaces);
+
+ if (operand.type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) {
+ EncodeLiteralNumber(instruction, operand);
+ } else if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
+ const char* src =
+ reinterpret_cast<const char*>(&instruction.words()[operand.offset]);
+ const size_t length = spv_strnlen_s(src, operand.num_words * 4);
+ if (length == operand.num_words * 4)
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to find terminal character of literal string";
+ for (size_t i = 0; i < length + 1; ++i)
+ writer_.WriteUnencoded(src[i]);
+ } else if (spvIsIdType(operand.type)) {
+ const uint16_t id_index = GetIdIndex(instruction.word(operand.offset));
+ writer_.WriteVariableWidthU16(id_index, model_->id_index_chunk_length());
+ } else {
+ for (int i = 0; i < operand.num_words; ++i) {
+ const uint32_t word = instruction.word(operand.offset + i);
+ EncodeOperandWord(operand.type, word);
+ }
+ }
+ }
+
+ AddByteBreakIfAgreed();
+
+ if (logger_) {
+ logger_->NewLine();
+ logger_->NewLine();
+ }
+
+ return SPV_SUCCESS;
+}
+
+spv_result_t MarkvDecoder::DecodeLiteralNumber(
+ const spv_parsed_operand_t& operand) {
+ if (operand.number_bit_width == 32) {
+ uint32_t word = 0;
+ if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
+ if (!reader_.ReadVariableWidthU32(&word, model_->u32_chunk_length()))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal U32";
+ } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
+ int32_t val = 0;
+ if (!reader_.ReadVariableWidthS32(&val, model_->s32_chunk_length(),
+ model_->s32_block_exponent()))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal S32";
+ std::memcpy(&word, &val, 4);
+ } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
+ if (!reader_.ReadUnencoded(&word))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal F32";
+ } else {
+ assert(0);
+ }
+ spirv_.push_back(word);
+ } else if (operand.number_bit_width == 16) {
+ uint32_t word = 0;
+ if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
+ uint16_t val = 0;
+ if (!reader_.ReadVariableWidthU16(&val, model_->u16_chunk_length()))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal U16";
+ word = val;
+ } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
+ int16_t val = 0;
+ if (!reader_.ReadVariableWidthS16(&val, model_->s16_chunk_length(),
+ model_->s16_block_exponent()))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal S16";
+ // Int16 is stored as int32 in SPIR-V, not as bits.
+ int32_t val32 = val;
+ std::memcpy(&word, &val32, 4);
+ } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
+ uint16_t word16 = 0;
+ if (!reader_.ReadUnencoded(&word16))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal F16";
+ word = word16;
+ } else {
+ assert(0);
+ }
+ spirv_.push_back(word);
+ } else {
+ assert(operand.number_bit_width == 64);
+ uint64_t word = 0;
+ if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
+ if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal U64";
+ } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
+ int64_t val = 0;
+ if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
+ model_->s64_block_exponent()))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal S64";
+ std::memcpy(&word, &val, 8);
+ } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
+ if (!reader_.ReadUnencoded(&word))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal F64";
+ } else {
+ assert(0);
+ }
+ spirv_.push_back(static_cast<uint32_t>(word));
+ spirv_.push_back(static_cast<uint32_t>(word >> 32));
+ }
+ return SPV_SUCCESS;
+}
+
+spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
+ const bool header_read_success =
+ reader_.ReadUnencoded(&header_.magic_number) &&
+ reader_.ReadUnencoded(&header_.markv_version) &&
+ reader_.ReadUnencoded(&header_.markv_model) &&
+ reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
+ reader_.ReadUnencoded(&header_.spirv_version) &&
+ reader_.ReadUnencoded(&header_.spirv_generator);
+
+ if (!header_read_success)
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Unable to read MARK-V header";
+
+ assert(header_.magic_number == kMarkvMagicNumber);
+ assert(header_.markv_length_in_bits > 0);
+
+ if (header_.magic_number != kMarkvMagicNumber)
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "MARK-V binary has incorrect magic number";
+
+ // TODO(atgoo@github.com): Print version strings.
+ if (header_.markv_version != GetMarkvVersion())
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "MARK-V binary and the codec have different versions";
+
+ spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
+ spirv_.resize(5, 0);
+ spirv_[0] = kSpirvMagicNumber;
+ spirv_[1] = header_.spirv_version;
+ spirv_[2] = header_.spirv_generator;
+
+ while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
+ spv_parsed_instruction_t inst = {};
+ const spv_result_t decode_result = DecodeInstruction(&inst);
+ if (decode_result != SPV_SUCCESS)
+ return decode_result;
+
+ const spv_result_t validation_result = UpdateValidationState(inst);
+ if (validation_result != SPV_SUCCESS)
+ return validation_result;
+ }
+
+
+ if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
+ !reader_.OnlyZeroesLeft()) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "MARK-V binary has wrong stated bit length "
+ << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
+ }
+
+ // Decoding of the module is finished, validation state should have correct
+ // id bound.
+ spirv_[3] = vstate_.getIdBound();
+
+ *spirv_binary = std::move(spirv_);
+ return SPV_SUCCESS;
+}
+
+// TODO(atgoo@github.com): The implementation borrows heavily from
+// Parser::parseOperand.
+// Consider coupling them together in some way once MARK-V codec is more mature.
+// For now it's better to keep the code independent for experimentation
+// purposes.
+spv_result_t MarkvDecoder::DecodeOperand(
+ size_t instruction_offset, size_t operand_offset,
+ spv_parsed_instruction_t* inst, const spv_operand_type_t type,
+ spv_operand_pattern_t* expected_operands,
+ bool read_result_id) {
+ const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
+
+ spv_parsed_operand_t parsed_operand;
+ memset(&parsed_operand, 0, sizeof(parsed_operand));
+
+ assert((operand_offset >> 16) == 0);
+ parsed_operand.offset = static_cast<uint16_t>(operand_offset);
+ parsed_operand.type = type;
+
+ // Set default values, may be updated later.
+ parsed_operand.number_kind = SPV_NUMBER_NONE;
+ parsed_operand.number_bit_width = 0;
+
+ const size_t first_word_index = spirv_.size();
+
+ switch (type) {
+ case SPV_OPERAND_TYPE_TYPE_ID: {
+ if (!DecodeId(&inst->type_id)) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read type_id";
+ }
+
+ if (inst->type_id == 0)
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded type_id is 0";
+
+ spirv_.push_back(inst->type_id);
+ vstate_.setIdBound(std::max(vstate_.getIdBound(), inst->type_id + 1));
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_RESULT_ID: {
+ if (read_result_id) {
+ if (!DecodeId(&inst->result_id))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read result_id";
+ } else {
+ inst->result_id = vstate_.getIdBound();
+ vstate_.setIdBound(inst->result_id + 1);
+ move_to_front_ids_.push_front(inst->result_id);
+ }
+
+ spirv_.push_back(inst->result_id);
+
+ // Save the result ID to type ID mapping.
+ // In the grammar, type ID always appears before result ID.
+ // A regular value maps to its type. Some instructions (e.g. OpLabel)
+ // have no type Id, and will map to 0. The result Id for a
+ // type-generating instruction (e.g. OpTypeInt) maps to itself.
+ auto insertion_result = id_to_type_id_.emplace(
+ inst->result_id,
+ spvOpcodeGeneratesType(opcode) ? inst->result_id : inst->type_id);
+ if(!insertion_result.second) {
+ return vstate_.diag(SPV_ERROR_INVALID_ID)
+ << "Unexpected behavior: id->type_id pair was already registered";
+ }
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_ID:
+ case SPV_OPERAND_TYPE_OPTIONAL_ID:
+ case SPV_OPERAND_TYPE_SCOPE_ID:
+ case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
+ uint32_t id = 0;
+ if (!DecodeId(&id))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id";
+
+ if (id == 0)
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
+
+ spirv_.push_back(id);
+ vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1));
+
+ if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
+
+ parsed_operand.type = SPV_OPERAND_TYPE_ID;
+
+ if (opcode == SpvOpExtInst && parsed_operand.offset == 3) {
+ // The current word is the extended instruction set id.
+ // Set the extended instruction set type for the current instruction.
+ auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
+ if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
+ return vstate_.diag(SPV_ERROR_INVALID_ID)
+ << "OpExtInst set id " << id
+ << " does not reference an OpExtInstImport result Id";
+ }
+ inst->ext_inst_type = ext_inst_type_iter->second;
+ }
+ }
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
+ uint32_t word = 0;
+ if (!DecodeOperandWord(type, &word))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read enum";
+
+ spirv_.push_back(word);
+
+ assert(SpvOpExtInst == opcode);
+ assert(inst->ext_inst_type != SPV_EXT_INST_TYPE_NONE);
+ spv_ext_inst_desc ext_inst;
+ if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid extended instruction number: " << word;
+ spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_LITERAL_INTEGER:
+ case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
+ // These are regular single-word literal integer operands.
+ // Post-parsing validation should check the range of the parsed value.
+ parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
+ // It turns out they are always unsigned integers!
+ parsed_operand.number_kind = SPV_NUMBER_UNSIGNED_INT;
+ parsed_operand.number_bit_width = 32;
+
+ uint32_t word = 0;
+ if (!DecodeOperandWord(type, &word))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal integer";
+
+ spirv_.push_back(word);
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
+ case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER:
+ parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
+ if (opcode == SpvOpSwitch) {
+ // The literal operands have the same type as the value
+ // referenced by the selector Id.
+ const uint32_t selector_id = spirv_.at(instruction_offset + 1);
+ const auto type_id_iter = id_to_type_id_.find(selector_id);
+ if (type_id_iter == id_to_type_id_.end() ||
+ type_id_iter->second == 0) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid OpSwitch: selector id " << selector_id
+ << " has no type";
+ }
+ uint32_t type_id = type_id_iter->second;
+
+ if (selector_id == type_id) {
+ // Recall that by convention, a result ID that is a type definition
+ // maps to itself.
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid OpSwitch: selector id " << selector_id
+ << " is a type, not a value";
+ }
+ if (auto error = SetNumericTypeInfoForType(&parsed_operand, type_id))
+ return error;
+ if (parsed_operand.number_kind != SPV_NUMBER_UNSIGNED_INT &&
+ parsed_operand.number_kind != SPV_NUMBER_SIGNED_INT) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid OpSwitch: selector id " << selector_id
+ << " is not a scalar integer";
+ }
+ } else {
+ assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
+ // The literal number type is determined by the type Id for the
+ // constant.
+ assert(inst->type_id);
+ if (auto error =
+ SetNumericTypeInfoForType(&parsed_operand, inst->type_id))
+ return error;
+ }
+
+ if (auto error = DecodeLiteralNumber(parsed_operand))
+ return error;
+
+ break;
+
+ case SPV_OPERAND_TYPE_LITERAL_STRING:
+ case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
+ parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_STRING;
+ std::vector<char> str;
+ // The loop is expected to terminate once we encounter '\0' or exhaust
+ // the bit stream.
+ while (true) {
+ char ch = 0;
+ if (!reader_.ReadUnencoded(&ch))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal string";
+
+ str.push_back(ch);
+
+ if (ch == '\0')
+ break;
+ }
+
+ while (str.size() % 4 != 0)
+ str.push_back('\0');
+
+ spirv_.resize(spirv_.size() + str.size() / 4);
+ std::memcpy(&spirv_[first_word_index], str.data(), str.size());
+
+ if (SpvOpExtInstImport == opcode) {
+ // Record the extended instruction type for the ID for this import.
+ // There is only one string literal argument to OpExtInstImport,
+ // so it's sufficient to guard this just on the opcode.
+ const spv_ext_inst_type_t ext_inst_type =
+ spvExtInstImportTypeGet(str.data());
+ if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid extended instruction import '" << str.data() << "'";
+ }
+ // We must have parsed a valid result ID. It's a condition
+ // of the grammar, and we only accept non-zero result Ids.
+ assert(inst->result_id);
+ const bool inserted = import_id_to_ext_inst_type_.emplace(
+ inst->result_id, ext_inst_type).second;
+ (void)inserted;
+ assert(inserted);
+ }
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_CAPABILITY:
+ case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
+ case SPV_OPERAND_TYPE_EXECUTION_MODEL:
+ case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
+ case SPV_OPERAND_TYPE_MEMORY_MODEL:
+ case SPV_OPERAND_TYPE_EXECUTION_MODE:
+ case SPV_OPERAND_TYPE_STORAGE_CLASS:
+ case SPV_OPERAND_TYPE_DIMENSIONALITY:
+ case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
+ case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
+ case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
+ case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
+ case SPV_OPERAND_TYPE_LINKAGE_TYPE:
+ case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
+ case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
+ case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
+ case SPV_OPERAND_TYPE_DECORATION:
+ case SPV_OPERAND_TYPE_BUILT_IN:
+ case SPV_OPERAND_TYPE_GROUP_OPERATION:
+ case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
+ case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
+ // A single word that is a plain enum value.
+ uint32_t word = 0;
+ if (!DecodeOperandWord(type, &word))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read enum";
+
+ spirv_.push_back(word);
+
+ // Map an optional operand type to its corresponding concrete type.
+ if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
+ parsed_operand.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
+
+ spv_operand_desc entry;
+ if (grammar_.lookupOperand(type, word, &entry)) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid "
+ << spvOperandTypeStr(parsed_operand.type)
+ << " operand: " << word;
+ }
+
+ // Prepare to accept operands to this operand, if needed.
+ spvPushOperandTypes(entry->operandTypes, expected_operands);
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
+ case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
+ case SPV_OPERAND_TYPE_LOOP_CONTROL:
+ case SPV_OPERAND_TYPE_IMAGE:
+ case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
+ case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
+ case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
+ // This operand is a mask.
+ uint32_t word = 0;
+ if (!DecodeOperandWord(type, &word))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read " << spvOperandTypeStr(type)
+ << " for " << spvOpcodeString(SpvOp(inst->opcode));
+
+ spirv_.push_back(word);
+
+ // Map an optional operand type to its corresponding concrete type.
+ if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
+ parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
+ else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
+ parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
+
+ // Check validity of set mask bits. Also prepare for operands for those
+ // masks if they have any. To get operand order correct, scan from
+ // MSB to LSB since we can only prepend operands to a pattern.
+ // The only case in the grammar where you have more than one mask bit
+ // having an operand is for image operands. See SPIR-V 3.14 Image
+ // Operands.
+ uint32_t remaining_word = word;
+ for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
+ if (remaining_word & mask) {
+ spv_operand_desc entry;
+ if (grammar_.lookupOperand(type, mask, &entry)) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid " << spvOperandTypeStr(parsed_operand.type)
+ << " operand: " << word << " has invalid mask component "
+ << mask;
+ }
+ remaining_word ^= mask;
+ spvPushOperandTypes(entry->operandTypes, expected_operands);
+ }
+ }
+ if (word == 0) {
+ // An all-zeroes mask *might* also be valid.
+ spv_operand_desc entry;
+ if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
+ // Prepare for its operands, if any.
+ spvPushOperandTypes(entry->operandTypes, expected_operands);
+ }
+ }
+ break;
+ }
+ default:
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Internal error: Unhandled operand type: " << type;
+ }
+
+ parsed_operand.num_words = uint16_t(spirv_.size() - first_word_index);
+
+ assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(parsed_operand.type));
+ assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(parsed_operand.type));
+
+ parsed_operands_.push_back(parsed_operand);
+
+ return SPV_SUCCESS;
+}
+
+spv_result_t MarkvDecoder::DecodeInstruction(spv_parsed_instruction_t* inst) {
+ parsed_operands_.clear();
+ const size_t instruction_offset = spirv_.size();
+
+ bool read_result_id = false;
+
+ while (true) {
+ uint32_t word = 0;
+ if (!reader_.ReadVariableWidthU32(&word,
+ model_->opcode_chunk_length())) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read opcode of instruction";
+ }
+
+ if (word >= kMarkvFirstOpcode) {
+ if (word == kMarkvOpNextInstructionEncodesResultId) {
+ read_result_id = true;
+ } else {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Encountered unknown MARK-V opcode";
+ }
+ } else {
+ inst->opcode = static_cast<uint16_t>(word);
+ break;
+ }
+ }
+
+ const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
+
+ // Opcode/num_words placeholder, the word will be filled in later.
+ spirv_.push_back(0);
+
+ spv_opcode_desc opcode_desc;
+ if (grammar_.lookupOpcode(opcode, &opcode_desc)
+ != SPV_SUCCESS) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
+ }
+
+ spv_operand_pattern_t expected_operands;
+ expected_operands.reserve(opcode_desc->numTypes);
+ for (auto i = 0; i < opcode_desc->numTypes; i++)
+ expected_operands.push_back(opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
+
+ if (!OpcodeHasFixedNumberOfOperands(opcode)) {
+ if (!reader_.ReadVariableWidthU16(&inst->num_operands,
+ model_->num_operands_chunk_length()))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read num_operands of instruction";
+ } else {
+ inst->num_operands = static_cast<uint16_t>(expected_operands.size());
+ }
+
+ for (size_t operand_index = 0;
+ operand_index < static_cast<size_t>(inst->num_operands);
+ ++operand_index) {
+ assert(!expected_operands.empty());
+ const spv_operand_type_t type =
+ spvTakeFirstMatchableOperand(&expected_operands);
+
+ const size_t operand_offset = spirv_.size() - instruction_offset;
+
+ const spv_result_t decode_result =
+ DecodeOperand(instruction_offset, operand_offset, inst, type,
+ &expected_operands, read_result_id);
+
+ if (decode_result != SPV_SUCCESS)
+ return decode_result;
+ }
+
+ assert(inst->num_operands == parsed_operands_.size());
+
+ // Only valid while spirv_ and parsed_operands_ remain unchanged.
+ inst->words = &spirv_[instruction_offset];
+ inst->operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
+ inst->num_words = static_cast<uint16_t>(spirv_.size() - instruction_offset);
+ spirv_[instruction_offset] =
+ spvOpcodeMake(inst->num_words, SpvOp(inst->opcode));
+
+ assert(inst->num_words == std::accumulate(
+ parsed_operands_.begin(), parsed_operands_.end(), 1,
+ [](int num_words, const spv_parsed_operand_t& operand) {
+ return num_words += operand.num_words;
+ }) && "num_words in instruction doesn't correspond to the sum of num_words"
+ "in the operands");
+
+ RecordNumberType(*inst);
+
+ if (!ReadToByteBreakIfAgreed())
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read to byte break";
+
+ return SPV_SUCCESS;
+}
+
+spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
+ spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
+ assert(type_id != 0);
+ auto type_info_iter = type_id_to_number_type_info_.find(type_id);
+ if (type_info_iter == type_id_to_number_type_info_.end()) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Type Id " << type_id << " is not a type";
+ }
+
+ const NumberType& info = type_info_iter->second;
+ if (info.type == SPV_NUMBER_NONE) {
+ // This is a valid type, but for something other than a scalar number.
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Type Id " << type_id << " is not a scalar numeric type";
+ }
+
+ parsed_operand->number_kind = info.type;
+ parsed_operand->number_bit_width = info.bit_width;
+ // Round up the word count.
+ parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
+ return SPV_SUCCESS;
+}
+
+void MarkvDecoder::RecordNumberType(const spv_parsed_instruction_t& inst) {
+ const SpvOp opcode = static_cast<SpvOp>(inst.opcode);
+ if (spvOpcodeGeneratesType(opcode)) {
+ NumberType info = {SPV_NUMBER_NONE, 0};
+ if (SpvOpTypeInt == opcode) {
+ info.bit_width = inst.words[inst.operands[1].offset];
+ info.type = inst.words[inst.operands[2].offset] ?
+ SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
+ } else if (SpvOpTypeFloat == opcode) {
+ info.bit_width = inst.words[inst.operands[1].offset];
+ info.type = SPV_NUMBER_FLOATING;
+ }
+ // The *result* Id of a type generating instruction is the type Id.
+ type_id_to_number_type_info_[inst.result_id] = info;
+ }
+}
+
+spv_result_t EncodeHeader(
+ void* user_data, spv_endianness_t endian, uint32_t magic,
+ uint32_t version, uint32_t generator, uint32_t id_bound,
+ uint32_t schema) {
+ MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
+ return encoder->EncodeHeader(
+ endian, magic, version, generator, id_bound, schema);
+}
+
+spv_result_t EncodeInstruction(
+ void* user_data, const spv_parsed_instruction_t* inst) {
+ MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
+ return encoder->EncodeInstruction(*inst);
+}
+
+} // namespace
+
+spv_result_t spvSpirvToMarkv(spv_const_context context,
+ const uint32_t* spirv_words,
+ const size_t spirv_num_words,
+ spv_const_markv_encoder_options options,
+ spv_markv_binary* markv_binary,
+ spv_text* comments, spv_diagnostic* diagnostic) {
+ spv_context_t hijack_context = *context;
+ if (diagnostic) {
+ *diagnostic = nullptr;
+ libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
+ }
+
+ spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words};
+
+ spv_endianness_t endian;
+ spv_position_t position = {};
+ if (spvBinaryEndianness(&spirv_binary, &endian)) {
+ return libspirv::DiagnosticStream(position, hijack_context.consumer,
+ SPV_ERROR_INVALID_BINARY)
+ << "Invalid SPIR-V magic number.";
+ }
+
+ spv_header_t header;
+ if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) {
+ return libspirv::DiagnosticStream(position, hijack_context.consumer,
+ SPV_ERROR_INVALID_BINARY)
+ << "Invalid SPIR-V header.";
+ }
+
+ MarkvEncoder encoder(&hijack_context, options);
+
+ if (comments) {
+ encoder.CreateCommentsLogger();
+
+ spv_text text = nullptr;
+ if (spvBinaryToText(&hijack_context, spirv_words, spirv_num_words,
+ SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, nullptr)
+ != SPV_SUCCESS) {
+ return libspirv::DiagnosticStream(position, hijack_context.consumer,
+ SPV_ERROR_INVALID_BINARY)
+ << "Failed to disassemble SPIR-V binary.";
+ }
+ assert(text);
+ encoder.SetDisassembly(std::string(text->str, text->length));
+ spvTextDestroy(text);
+ }
+
+ if (spvBinaryParse(
+ &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader,
+ EncodeInstruction, diagnostic) != SPV_SUCCESS) {
+ return libspirv::DiagnosticStream(position, hijack_context.consumer,
+ SPV_ERROR_INVALID_BINARY)
+ << "Unable to encode to MARK-V.";
+ }
+
+ if (comments)
+ *comments = CreateSpvText(encoder.GetComments());
+
+ *markv_binary = encoder.GetMarkvBinary();
+ return SPV_SUCCESS;
+}
+
+spv_result_t spvMarkvToSpirv(spv_const_context context,
+ const uint8_t* markv_data,
+ size_t markv_size_bytes,
+ spv_const_markv_decoder_options options,
+ spv_binary* spirv_binary,
+ spv_text* /* comments */, spv_diagnostic* diagnostic) {
+ spv_position_t position = {};
+ spv_context_t hijack_context = *context;
+ if (diagnostic) {
+ *diagnostic = nullptr;
+ libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic);
+ }
+
+ MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options);
+
+ std::vector<uint32_t> words;
+
+ if (decoder.DecodeModule(&words) != SPV_SUCCESS) {
+ return libspirv::DiagnosticStream(position, hijack_context.consumer,
+ SPV_ERROR_INVALID_BINARY)
+ << "Unable to decode MARK-V.";
+ }
+
+ assert(!words.empty());
+
+ *spirv_binary = new spv_binary_t();
+ (*spirv_binary)->code = new uint32_t[words.size()];
+ (*spirv_binary)->wordCount = words.size();
+ std::memcpy((*spirv_binary)->code, words.data(), 4 * words.size());
+
+ return SPV_SUCCESS;
+}
+
+void spvMarkvBinaryDestroy(spv_markv_binary binary) {
+ if (!binary) return;
+ delete[] binary->data;
+ delete binary;
+}
+
+spv_markv_encoder_options spvMarkvEncoderOptionsCreate() {
+ return new spv_markv_encoder_options_t;
+}
+
+void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) {
+ delete options;
+}
+
+spv_markv_decoder_options spvMarkvDecoderOptionsCreate() {
+ return new spv_markv_decoder_options_t;
+}
+
+void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) {
+ delete options;
+}
diff --git a/source/ext_inst.cpp b/source/ext_inst.cpp
index 3f2b6ce6..12930f19 100644
--- a/source/ext_inst.cpp
+++ b/source/ext_inst.cpp
@@ -31,10 +31,22 @@ static const spv_ext_inst_desc_t openclEntries_1_0[] = {
#include "opencl.std.insts-1.0.inc"
};
+static const spv_ext_inst_desc_t spv_amd_shader_explicit_vertex_parameter_entries[] = {
+#include "spv-amd-shader-explicit-vertex-parameter.insts.inc"
+};
+
+static const spv_ext_inst_desc_t spv_amd_shader_trinary_minmax_entries[] = {
+#include "spv-amd-shader-trinary-minmax.insts.inc"
+};
+
static const spv_ext_inst_desc_t spv_amd_gcn_shader_entries[] = {
#include "spv-amd-gcn-shader.insts.inc"
};
+static const spv_ext_inst_desc_t spv_amd_shader_ballot_entries[] = {
+#include "spv-amd-shader-ballot.insts.inc"
+};
+
spv_result_t spvExtInstTableGet(spv_ext_inst_table* pExtInstTable,
spv_target_env env) {
if (!pExtInstTable) return SPV_ERROR_INVALID_POINTER;
@@ -44,8 +56,14 @@ spv_result_t spvExtInstTableGet(spv_ext_inst_table* pExtInstTable,
glslStd450Entries_1_0},
{SPV_EXT_INST_TYPE_OPENCL_STD, ARRAY_SIZE(openclEntries_1_0),
openclEntries_1_0},
+ {SPV_EXT_INST_TYPE_SPV_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER,
+ ARRAY_SIZE(spv_amd_shader_explicit_vertex_parameter_entries), spv_amd_shader_explicit_vertex_parameter_entries},
+ {SPV_EXT_INST_TYPE_SPV_AMD_SHADER_TRINARY_MINMAX,
+ ARRAY_SIZE(spv_amd_shader_trinary_minmax_entries), spv_amd_shader_trinary_minmax_entries},
{SPV_EXT_INST_TYPE_SPV_AMD_GCN_SHADER,
ARRAY_SIZE(spv_amd_gcn_shader_entries), spv_amd_gcn_shader_entries},
+ {SPV_EXT_INST_TYPE_SPV_AMD_SHADER_BALLOT,
+ ARRAY_SIZE(spv_amd_shader_ballot_entries), spv_amd_shader_ballot_entries},
};
static const spv_ext_inst_table_t table_1_0 = {ARRAY_SIZE(groups_1_0),
@@ -81,9 +99,18 @@ spv_ext_inst_type_t spvExtInstImportTypeGet(const char* name) {
if (!strcmp("OpenCL.std", name)) {
return SPV_EXT_INST_TYPE_OPENCL_STD;
}
+ if (!strcmp("SPV_AMD_shader_explicit_vertex_parameter", name)) {
+ return SPV_EXT_INST_TYPE_SPV_AMD_SHADER_EXPLICIT_VERTEX_PARAMETER;
+ }
+ if (!strcmp("SPV_AMD_shader_trinary_minmax", name)) {
+ return SPV_EXT_INST_TYPE_SPV_AMD_SHADER_TRINARY_MINMAX;
+ }
if (!strcmp("SPV_AMD_gcn_shader", name)) {
return SPV_EXT_INST_TYPE_SPV_AMD_GCN_SHADER;
}
+ if (!strcmp("SPV_AMD_shader_ballot", name)) {
+ return SPV_EXT_INST_TYPE_SPV_AMD_SHADER_BALLOT;
+ }
return SPV_EXT_INST_TYPE_NONE;
}
diff --git a/source/extinst.spv-amd-shader-ballot.grammar.json b/source/extinst.spv-amd-shader-ballot.grammar.json
new file mode 100644
index 00000000..62a470ee
--- /dev/null
+++ b/source/extinst.spv-amd-shader-ballot.grammar.json
@@ -0,0 +1,41 @@
+{
+ "revision" : 5,
+ "instructions" : [
+ {
+ "opname" : "SwizzleInvocationsAMD",
+ "opcode" : 1,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'data'" },
+ { "kind" : "IdRef", "name" : "'offset'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_ballot" ]
+ },
+ {
+ "opname" : "SwizzleInvocationsMaskedAMD",
+ "opcode" : 2,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'data'" },
+ { "kind" : "IdRef", "name" : "'mask'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_ballot" ]
+ },
+ {
+ "opname" : "WriteInvocationAMD",
+ "opcode" : 3,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'inputValue'" },
+ { "kind" : "IdRef", "name" : "'writeValue'" },
+ { "kind" : "IdRef", "name" : "'invocationIndex'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_ballot" ]
+ },
+ {
+ "opname" : "MbcntAMD",
+ "opcode" : 4,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'mask'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_ballot" ]
+ }
+ ]
+}
diff --git a/source/extinst.spv-amd-shader-explicit-vertex-parameter.grammar.json b/source/extinst.spv-amd-shader-explicit-vertex-parameter.grammar.json
new file mode 100644
index 00000000..e156b1b6
--- /dev/null
+++ b/source/extinst.spv-amd-shader-explicit-vertex-parameter.grammar.json
@@ -0,0 +1,14 @@
+{
+ "revision" : 4,
+ "instructions" : [
+ {
+ "opname" : "InterpolateAtVertexAMD",
+ "opcode" : 1,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'interpolant'" },
+ { "kind" : "IdRef", "name" : "'vertexIdx'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_explicit_vertex_parameter" ]
+ }
+ ]
+}
diff --git a/source/extinst.spv-amd-shader-trinary-minmax.grammar.json b/source/extinst.spv-amd-shader-trinary-minmax.grammar.json
new file mode 100644
index 00000000..c681976f
--- /dev/null
+++ b/source/extinst.spv-amd-shader-trinary-minmax.grammar.json
@@ -0,0 +1,95 @@
+{
+ "revision" : 4,
+ "instructions" : [
+ {
+ "opname" : "FMin3AMD",
+ "opcode" : 1,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'x'" },
+ { "kind" : "IdRef", "name" : "'y'" },
+ { "kind" : "IdRef", "name" : "'z'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_trinary_minmax" ]
+ },
+ {
+ "opname" : "UMin3AMD",
+ "opcode" : 2,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'x'" },
+ { "kind" : "IdRef", "name" : "'y'" },
+ { "kind" : "IdRef", "name" : "'z'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_trinary_minmax" ]
+ },
+ {
+ "opname" : "SMin3AMD",
+ "opcode" : 3,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'x'" },
+ { "kind" : "IdRef", "name" : "'y'" },
+ { "kind" : "IdRef", "name" : "'z'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_trinary_minmax" ]
+ },
+ {
+ "opname" : "FMax3AMD",
+ "opcode" : 4,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'x'" },
+ { "kind" : "IdRef", "name" : "'y'" },
+ { "kind" : "IdRef", "name" : "'z'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_trinary_minmax" ]
+ },
+ {
+ "opname" : "UMax3AMD",
+ "opcode" : 5,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'x'" },
+ { "kind" : "IdRef", "name" : "'y'" },
+ { "kind" : "IdRef", "name" : "'z'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_trinary_minmax" ]
+ },
+ {
+ "opname" : "SMax3AMD",
+ "opcode" : 6,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'x'" },
+ { "kind" : "IdRef", "name" : "'y'" },
+ { "kind" : "IdRef", "name" : "'z'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_trinary_minmax" ]
+ },
+ {
+ "opname" : "FMid3AMD",
+ "opcode" : 7,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'x'" },
+ { "kind" : "IdRef", "name" : "'y'" },
+ { "kind" : "IdRef", "name" : "'z'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_trinary_minmax" ]
+ },
+ {
+ "opname" : "UMid3AMD",
+ "opcode" : 8,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'x'" },
+ { "kind" : "IdRef", "name" : "'y'" },
+ { "kind" : "IdRef", "name" : "'z'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_trinary_minmax" ]
+ },
+ {
+ "opname" : "SMid3AMD",
+ "opcode" : 9,
+ "operands" : [
+ { "kind" : "IdRef", "name" : "'x'" },
+ { "kind" : "IdRef", "name" : "'y'" },
+ { "kind" : "IdRef", "name" : "'z'" }
+ ],
+ "extensions" : [ "SPV_AMD_shader_trinary_minmax" ]
+ }
+ ]
+}
diff --git a/source/operand.cpp b/source/operand.cpp
index cf234e77..29da94d1 100644
--- a/source/operand.cpp
+++ b/source/operand.cpp
@@ -16,6 +16,7 @@
#include <assert.h>
#include <string.h>
+#include <algorithm>
#include "macro.h"
@@ -218,26 +219,28 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
return "unknown";
}
-void spvPrependOperandTypes(const spv_operand_type_t* types,
- spv_operand_pattern_t* pattern) {
+void spvPushOperandTypes(const spv_operand_type_t* types,
+ spv_operand_pattern_t* pattern) {
const spv_operand_type_t* endTypes;
for (endTypes = types; *endTypes != SPV_OPERAND_TYPE_NONE; ++endTypes)
;
- pattern->insert(pattern->begin(), types, endTypes);
+ while (endTypes-- != types) {
+ pattern->push_back(*endTypes);
+ }
}
-void spvPrependOperandTypesForMask(const spv_operand_table operandTable,
- const spv_operand_type_t type,
- const uint32_t mask,
- spv_operand_pattern_t* pattern) {
- // Scan from highest bits to lowest bits because we will prepend in LIFO
- // fashion, and we need the operands for lower order bits to appear first.
+void spvPushOperandTypesForMask(const spv_operand_table operandTable,
+ const spv_operand_type_t type,
+ const uint32_t mask,
+ spv_operand_pattern_t* pattern) {
+ // Scan from highest bits to lowest bits because we will append in LIFO
+ // fashion, and we need the operands for lower order bits to be consumed first
for (uint32_t candidate_bit = (1u << 31u); candidate_bit; candidate_bit >>= 1) {
if (candidate_bit & mask) {
spv_operand_desc entry = nullptr;
if (SPV_SUCCESS == spvOperandTableValueLookup(operandTable, type,
candidate_bit, &entry)) {
- spvPrependOperandTypes(entry->operandTypes, pattern);
+ spvPushOperandTypes(entry->operandTypes, pattern);
}
}
}
@@ -262,24 +265,25 @@ bool spvExpandOperandSequenceOnce(spv_operand_type_t type,
spv_operand_pattern_t* pattern) {
switch (type) {
case SPV_OPERAND_TYPE_VARIABLE_ID:
- pattern->insert(pattern->begin(), {SPV_OPERAND_TYPE_OPTIONAL_ID, type});
+ pattern->push_back(type);
+ pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_ID);
return true;
case SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER:
- pattern->insert(pattern->begin(),
- {SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER, type});
+ pattern->push_back(type);
+ pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER);
return true;
case SPV_OPERAND_TYPE_VARIABLE_LITERAL_INTEGER_ID:
// Represents Zero or more (Literal number, Id) pairs,
// where the literal number must be a scalar integer.
- pattern->insert(pattern->begin(),
- {SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER,
- SPV_OPERAND_TYPE_ID, type});
+ pattern->push_back(type);
+ pattern->push_back(SPV_OPERAND_TYPE_ID);
+ pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER);
return true;
case SPV_OPERAND_TYPE_VARIABLE_ID_LITERAL_INTEGER:
// Represents Zero or more (Id, Literal number) pairs.
- pattern->insert(pattern->begin(),
- {SPV_OPERAND_TYPE_OPTIONAL_ID,
- SPV_OPERAND_TYPE_LITERAL_INTEGER, type});
+ pattern->push_back(type);
+ pattern->push_back(SPV_OPERAND_TYPE_LITERAL_INTEGER);
+ pattern->push_back(SPV_OPERAND_TYPE_OPTIONAL_ID);
return true;
default:
break;
@@ -292,25 +296,24 @@ spv_operand_type_t spvTakeFirstMatchableOperand(
assert(!pattern->empty());
spv_operand_type_t result;
do {
- result = pattern->front();
- pattern->pop_front();
+ result = pattern->back();
+ pattern->pop_back();
} while (spvExpandOperandSequenceOnce(result, pattern));
return result;
}
spv_operand_pattern_t spvAlternatePatternFollowingImmediate(
const spv_operand_pattern_t& pattern) {
- spv_operand_pattern_t alternatePattern;
- for (const auto& operand : pattern) {
- if (operand == SPV_OPERAND_TYPE_RESULT_ID) {
- alternatePattern.push_back(operand);
- alternatePattern.push_back(SPV_OPERAND_TYPE_OPTIONAL_CIV);
- return alternatePattern;
- }
- alternatePattern.push_back(SPV_OPERAND_TYPE_OPTIONAL_CIV);
+
+ auto it = std::find(pattern.crbegin(), pattern.crend(), SPV_OPERAND_TYPE_RESULT_ID);
+ if (it != pattern.crend()) {
+ spv_operand_pattern_t alternatePattern(it - pattern.crbegin() + 2, SPV_OPERAND_TYPE_OPTIONAL_CIV);
+ alternatePattern[1] = SPV_OPERAND_TYPE_RESULT_ID;
+ return alternatePattern;
}
+
// No result-id found, so just expect CIVs.
- return {SPV_OPERAND_TYPE_OPTIONAL_CIV};
+ return{ SPV_OPERAND_TYPE_OPTIONAL_CIV };
}
bool spvIsIdType(spv_operand_type_t type) {
diff --git a/source/operand.h b/source/operand.h
index 5d77a347..fa7c6f2e 100644
--- a/source/operand.h
+++ b/source/operand.h
@@ -26,8 +26,13 @@
// next on the input.
//
// As we parse an instruction in text or binary form from left to right,
-// we pull and push from the front of the pattern.
-using spv_operand_pattern_t = std::deque<spv_operand_type_t>;
+// we pop and push at the end of the pattern vector. Symbols later in the
+// pattern vector are matched against the input before symbols earlier in the
+// pattern vector are matched.
+
+// Using a vector in this way reduces memory traffic, which is good for
+// performance.
+using spv_operand_pattern_t = std::vector<spv_operand_type_t>;
// Finds the named operand in the table. The type parameter specifies the
// operand's group. A handle of the operand table entry for this operand will
@@ -62,24 +67,24 @@ bool spvOperandIsOptional(spv_operand_type_t type);
// operand.
bool spvOperandIsVariable(spv_operand_type_t type);
-// Inserts a list of operand types into the front of the given pattern.
+// Append a list of operand types to the end of the pattern vector.
// The types parameter specifies the source array of types, ending with
// SPV_OPERAND_TYPE_NONE.
-void spvPrependOperandTypes(const spv_operand_type_t* types,
- spv_operand_pattern_t* pattern);
+void spvPushOperandTypes(const spv_operand_type_t* types,
+ spv_operand_pattern_t* pattern);
-// Inserts the operands expected after the given typed mask onto the
-// front of the given pattern.
+// Appends the operands expected after the given typed mask onto the
+// end of the given pattern.
//
// Each set bit in the mask represents zero or more operand types that should
-// be prepended onto the pattern. Operands for a less significant bit always
-// appear before operands for a more significant bit.
+// be appended onto the pattern. Operands for a less significant bit always
+// appear after operands for a more significant bit.
//
// If a set bit is unknown, then we assume it has no operands.
-void spvPrependOperandTypesForMask(const spv_operand_table operand_table,
- const spv_operand_type_t mask_type,
- const uint32_t mask,
- spv_operand_pattern_t* pattern);
+void spvPushOperandTypesForMask(const spv_operand_table operand_table,
+ const spv_operand_type_t mask_type,
+ const uint32_t mask,
+ spv_operand_pattern_t* pattern);
// Expands an operand type representing zero or more logical operands,
// exactly once.
diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt
index 64fdc025..e7fd25f0 100644
--- a/source/opt/CMakeLists.txt
+++ b/source/opt/CMakeLists.txt
@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
add_library(SPIRV-Tools-opt
+ aggressive_dead_code_elim_pass.h
basic_block.h
+ block_merge_pass.h
build_module.h
compact_ids_pass.h
constants.h
+ dead_branch_elim_pass.h
def_use_manager.h
eliminate_dead_constant_pass.h
flatten_decoration_pass.h
@@ -23,10 +26,13 @@ add_library(SPIRV-Tools-opt
fold_spec_constant_op_and_composite_pass.h
freeze_spec_constant_value_pass.h
inline_pass.h
+ insert_extract_elim.h
instruction.h
ir_loader.h
local_access_chain_convert_pass.h
local_single_block_elim_pass.h
+ local_single_store_elim_pass.h
+ local_ssa_elim_pass.h
log.h
module.h
null_pass.h
@@ -40,20 +46,26 @@ add_library(SPIRV-Tools-opt
type_manager.h
unify_const_pass.h
+ aggressive_dead_code_elim_pass.cpp
basic_block.cpp
+ block_merge_pass.cpp
build_module.cpp
compact_ids_pass.cpp
def_use_manager.cpp
+ dead_branch_elim_pass.cpp
eliminate_dead_constant_pass.cpp
flatten_decoration_pass.cpp
function.cpp
fold_spec_constant_op_and_composite_pass.cpp
freeze_spec_constant_value_pass.cpp
inline_pass.cpp
+ insert_extract_elim.cpp
instruction.cpp
ir_loader.cpp
local_access_chain_convert_pass.cpp
local_single_block_elim_pass.cpp
+ local_single_store_elim_pass.cpp
+ local_ssa_elim_pass.cpp
module.cpp
set_spec_constant_default_value_pass.cpp
optimizer.cpp
@@ -76,7 +88,10 @@ target_link_libraries(SPIRV-Tools-opt
set_property(TARGET SPIRV-Tools-opt PROPERTY FOLDER "SPIRV-Tools libraries")
-install(TARGETS SPIRV-Tools-opt
- RUNTIME DESTINATION bin
- LIBRARY DESTINATION lib
- ARCHIVE DESTINATION lib)
+if(ENABLE_SPIRV_TOOLS_INSTALL)
+ install(TARGETS SPIRV-Tools-opt
+ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
+endif(ENABLE_SPIRV_TOOLS_INSTALL)
+
diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp
new file mode 100644
index 00000000..0b1fd3b9
--- /dev/null
+++ b/source/opt/aggressive_dead_code_elim_pass.cpp
@@ -0,0 +1,546 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "aggressive_dead_code_elim_pass.h"
+
+#include "iterator.h"
+#include "spirv/1.0/GLSL.std.450.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+
+const uint32_t kEntryPointFunctionIdInIdx = 1;
+const uint32_t kStorePtrIdInIdx = 0;
+const uint32_t kLoadPtrIdInIdx = 0;
+const uint32_t kAccessChainPtrIdInIdx = 0;
+const uint32_t kTypePointerStorageClassInIdx = 0;
+const uint32_t kCopyObjectOperandInIdx = 0;
+const uint32_t kExtInstSetIdInIndx = 0;
+const uint32_t kExtInstInstructionInIndx = 1;
+
+} // namespace anonymous
+
+bool AggressiveDCEPass::IsNonPtrAccessChain(const SpvOp opcode) const {
+ return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
+}
+
+ir::Instruction* AggressiveDCEPass::GetPtr(
+ ir::Instruction* ip, uint32_t* varId) {
+ const SpvOp op = ip->opcode();
+ assert(op == SpvOpStore || op == SpvOpLoad);
+ *varId = ip->GetSingleWordInOperand(
+ op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx);
+ ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
+ while (ptrInst->opcode() == SpvOpCopyObject) {
+ *varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
+ ptrInst = def_use_mgr_->GetDef(*varId);
+ }
+ ir::Instruction* varInst = ptrInst;
+ while (varInst->opcode() != SpvOpVariable) {
+ if (IsNonPtrAccessChain(varInst->opcode())) {
+ *varId = varInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
+ }
+ else {
+ assert(varInst->opcode() == SpvOpCopyObject);
+ *varId = varInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
+ }
+ varInst = def_use_mgr_->GetDef(*varId);
+ }
+ return ptrInst;
+}
+
+bool AggressiveDCEPass::IsLocalVar(uint32_t varId) {
+ const ir::Instruction* varInst = def_use_mgr_->GetDef(varId);
+ assert(varInst->opcode() == SpvOpVariable);
+ const uint32_t varTypeId = varInst->type_id();
+ const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
+ return varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) ==
+ SpvStorageClassFunction;
+}
+
+void AggressiveDCEPass::AddStores(uint32_t ptrId) {
+ const analysis::UseList* uses = def_use_mgr_->GetUses(ptrId);
+ if (uses == nullptr)
+ return;
+ for (const auto u : *uses) {
+ const SpvOp op = u.inst->opcode();
+ switch (op) {
+ case SpvOpAccessChain:
+ case SpvOpInBoundsAccessChain:
+ case SpvOpCopyObject: {
+ AddStores(u.inst->result_id());
+ } break;
+ case SpvOpLoad:
+ break;
+ // Assume it stores eg frexp, modf
+ case SpvOpStore:
+ default: {
+ if (live_insts_.find(u.inst) == live_insts_.end())
+ worklist_.push(u.inst);
+ } break;
+ }
+ }
+}
+
+bool AggressiveDCEPass::IsCombinator(uint32_t op) const {
+ return combinator_ops_shader_.find(op) != combinator_ops_shader_.end();
+}
+
+bool AggressiveDCEPass::IsCombinatorExt(ir::Instruction* inst) const {
+ assert(inst->opcode() == SpvOpExtInst);
+ if (inst->GetSingleWordInOperand(kExtInstSetIdInIndx) == glsl_std_450_id_) {
+ uint32_t op = inst->GetSingleWordInOperand(kExtInstInstructionInIndx);
+ return combinator_ops_glsl_std_450_.find(op) !=
+ combinator_ops_glsl_std_450_.end();
+ }
+ else
+ return false;
+}
+
+bool AggressiveDCEPass::AllExtensionsSupported() const {
+ // If any extension not in whitelist, return false
+ for (auto& ei : module_->extensions()) {
+ const char* extName = reinterpret_cast<const char*>(
+ &ei.GetInOperand(0).words[0]);
+ if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
+ return false;
+ }
+ return true;
+}
+
+void AggressiveDCEPass::KillInstIfTargetDead(ir::Instruction* inst) {
+ const uint32_t tId = inst->GetSingleWordInOperand(0);
+ const ir::Instruction* tInst = def_use_mgr_->GetDef(tId);
+ if (dead_insts_.find(tInst) != dead_insts_.end())
+ def_use_mgr_->KillInst(inst);
+}
+
+bool AggressiveDCEPass::AggressiveDCE(ir::Function* func) {
+ bool modified = false;
+ // Add all control flow and instructions with external side effects
+ // to worklist
+ // TODO(greg-lunarg): Handle Frexp, Modf more optimally
+ // TODO(greg-lunarg): Handle FunctionCall more optimally
+ // TODO(greg-lunarg): Handle CopyMemory more optimally
+ for (auto& blk : *func) {
+ for (auto& inst : blk) {
+ uint32_t op = inst.opcode();
+ switch (op) {
+ case SpvOpStore: {
+ uint32_t varId;
+ (void) GetPtr(&inst, &varId);
+ // non-function-scope stores
+ if (!IsLocalVar(varId)) {
+ worklist_.push(&inst);
+ }
+ } break;
+ case SpvOpExtInst: {
+ // eg. GLSL frexp, modf
+ if (!IsCombinatorExt(&inst))
+ worklist_.push(&inst);
+ } break;
+ case SpvOpCopyMemory:
+ case SpvOpFunctionCall: {
+ return false;
+ } break;
+ default: {
+ // eg. control flow, function call, atomics
+ if (!IsCombinator(op))
+ worklist_.push(&inst);
+ } break;
+ }
+ }
+ }
+ // Add OpGroupDecorates to worklist because they are a pain to remove
+ // ids from.
+ // TODO(greg-lunarg): Handle dead ids in OpGroupDecorate
+ for (auto& ai : module_->annotations()) {
+ if (ai.opcode() == SpvOpGroupDecorate)
+ worklist_.push(&ai);
+ }
+ // Perform closure on live instruction set.
+ while (!worklist_.empty()) {
+ ir::Instruction* liveInst = worklist_.front();
+ live_insts_.insert(liveInst);
+ // Add all operand instructions if not already live
+ liveInst->ForEachInId([this](const uint32_t* iid) {
+ ir::Instruction* inInst = def_use_mgr_->GetDef(*iid);
+ if (live_insts_.find(inInst) == live_insts_.end())
+ worklist_.push(inInst);
+ });
+ // If local load, add all variable's stores if variable not already live
+ if (liveInst->opcode() == SpvOpLoad) {
+ uint32_t varId;
+ (void) GetPtr(liveInst, &varId);
+ if (IsLocalVar(varId)) {
+ if (live_local_vars_.find(varId) == live_local_vars_.end()) {
+ AddStores(varId);
+ live_local_vars_.insert(varId);
+ }
+ }
+ }
+ worklist_.pop();
+ }
+ // Mark all non-live instructions dead
+ for (auto& blk : *func) {
+ for (auto& inst : blk) {
+ if (live_insts_.find(&inst) != live_insts_.end())
+ continue;
+ dead_insts_.insert(&inst);
+ }
+ }
+ // Remove debug and annotation statements referencing dead instructions.
+ // This must be done before killing the instructions, otherwise there are
+ // dead objects in the def/use database.
+ for (auto& di : module_->debugs()) {
+ if (di.opcode() != SpvOpName)
+ continue;
+ KillInstIfTargetDead(&di);
+ modified = true;
+ }
+ for (auto& ai : module_->annotations()) {
+ if (ai.opcode() != SpvOpDecorate && ai.opcode() != SpvOpDecorateId)
+ continue;
+ KillInstIfTargetDead(&ai);
+ modified = true;
+ }
+ // Kill dead instructions
+ for (auto& blk : *func) {
+ for (auto& inst : blk) {
+ if (dead_insts_.find(&inst) == dead_insts_.end())
+ continue;
+ def_use_mgr_->KillInst(&inst);
+ modified = true;
+ }
+ }
+ return modified;
+}
+
+void AggressiveDCEPass::Initialize(ir::Module* module) {
+ module_ = module;
+
+ // Initialize id-to-function map
+ id2function_.clear();
+ for (auto& fn : *module_)
+ id2function_[fn.result_id()] = &fn;
+
+ // Clear collections
+ worklist_ = std::queue<ir::Instruction*>{};
+ live_insts_.clear();
+ live_local_vars_.clear();
+ dead_insts_.clear();
+ combinator_ops_shader_.clear();
+ combinator_ops_glsl_std_450_.clear();
+
+ // TODO(greg-lunarg): Reuse def/use from previous passes
+ def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
+
+ // Initialize extensions whitelist
+ InitExtensions();
+}
+
+Pass::Status AggressiveDCEPass::ProcessImpl() {
+ // Current functionality assumes shader capability
+ // TODO(greg-lunarg): Handle additional capabilities
+ if (!module_->HasCapability(SpvCapabilityShader))
+ return Status::SuccessWithoutChange;
+ // Current functionality assumes logical addressing only
+ // TODO(greg-lunarg): Handle non-logical addressing
+ if (module_->HasCapability(SpvCapabilityAddresses))
+ return Status::SuccessWithoutChange;
+ // If any extensions in the module are not explicitly supported,
+ // return unmodified.
+ if (!AllExtensionsSupported())
+ return Status::SuccessWithoutChange;
+ // Initialize combinator whitelists
+ InitCombinatorSets();
+ // Process all entry point functions
+ bool modified = false;
+ for (auto& e : module_->entry_points()) {
+ ir::Function* fn =
+ id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
+ modified = AggressiveDCE(fn) || modified;
+ }
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+AggressiveDCEPass::AggressiveDCEPass()
+ : module_(nullptr), def_use_mgr_(nullptr) {}
+
+Pass::Status AggressiveDCEPass::Process(ir::Module* module) {
+ Initialize(module);
+ return ProcessImpl();
+}
+
+void AggressiveDCEPass::InitCombinatorSets() {
+ combinator_ops_shader_ = {
+ SpvOpNop,
+ SpvOpUndef,
+ SpvOpVariable,
+ SpvOpImageTexelPointer,
+ SpvOpLoad,
+ SpvOpAccessChain,
+ SpvOpInBoundsAccessChain,
+ SpvOpArrayLength,
+ SpvOpVectorExtractDynamic,
+ SpvOpVectorInsertDynamic,
+ SpvOpVectorShuffle,
+ SpvOpCompositeConstruct,
+ SpvOpCompositeExtract,
+ SpvOpCompositeInsert,
+ SpvOpCopyObject,
+ SpvOpTranspose,
+ SpvOpSampledImage,
+ SpvOpImageSampleImplicitLod,
+ SpvOpImageSampleExplicitLod,
+ SpvOpImageSampleDrefImplicitLod,
+ SpvOpImageSampleDrefExplicitLod,
+ SpvOpImageSampleProjImplicitLod,
+ SpvOpImageSampleProjExplicitLod,
+ SpvOpImageSampleProjDrefImplicitLod,
+ SpvOpImageSampleProjDrefExplicitLod,
+ SpvOpImageFetch,
+ SpvOpImageGather,
+ SpvOpImageDrefGather,
+ SpvOpImageRead,
+ SpvOpImage,
+ SpvOpConvertFToU,
+ SpvOpConvertFToS,
+ SpvOpConvertSToF,
+ SpvOpConvertUToF,
+ SpvOpUConvert,
+ SpvOpSConvert,
+ SpvOpFConvert,
+ SpvOpQuantizeToF16,
+ SpvOpBitcast,
+ SpvOpSNegate,
+ SpvOpFNegate,
+ SpvOpIAdd,
+ SpvOpFAdd,
+ SpvOpISub,
+ SpvOpFSub,
+ SpvOpIMul,
+ SpvOpFMul,
+ SpvOpUDiv,
+ SpvOpSDiv,
+ SpvOpFDiv,
+ SpvOpUMod,
+ SpvOpSRem,
+ SpvOpSMod,
+ SpvOpFRem,
+ SpvOpFMod,
+ SpvOpVectorTimesScalar,
+ SpvOpMatrixTimesScalar,
+ SpvOpVectorTimesMatrix,
+ SpvOpMatrixTimesVector,
+ SpvOpMatrixTimesMatrix,
+ SpvOpOuterProduct,
+ SpvOpDot,
+ SpvOpIAddCarry,
+ SpvOpISubBorrow,
+ SpvOpUMulExtended,
+ SpvOpSMulExtended,
+ SpvOpAny,
+ SpvOpAll,
+ SpvOpIsNan,
+ SpvOpIsInf,
+ SpvOpLogicalEqual,
+ SpvOpLogicalNotEqual,
+ SpvOpLogicalOr,
+ SpvOpLogicalAnd,
+ SpvOpLogicalNot,
+ SpvOpSelect,
+ SpvOpIEqual,
+ SpvOpINotEqual,
+ SpvOpUGreaterThan,
+ SpvOpSGreaterThan,
+ SpvOpUGreaterThanEqual,
+ SpvOpSGreaterThanEqual,
+ SpvOpULessThan,
+ SpvOpSLessThan,
+ SpvOpULessThanEqual,
+ SpvOpSLessThanEqual,
+ SpvOpFOrdEqual,
+ SpvOpFUnordEqual,
+ SpvOpFOrdNotEqual,
+ SpvOpFUnordNotEqual,
+ SpvOpFOrdLessThan,
+ SpvOpFUnordLessThan,
+ SpvOpFOrdGreaterThan,
+ SpvOpFUnordGreaterThan,
+ SpvOpFOrdLessThanEqual,
+ SpvOpFUnordLessThanEqual,
+ SpvOpFOrdGreaterThanEqual,
+ SpvOpFUnordGreaterThanEqual,
+ SpvOpShiftRightLogical,
+ SpvOpShiftRightArithmetic,
+ SpvOpShiftLeftLogical,
+ SpvOpBitwiseOr,
+ SpvOpBitwiseXor,
+ SpvOpBitwiseAnd,
+ SpvOpNot,
+ SpvOpBitFieldInsert,
+ SpvOpBitFieldSExtract,
+ SpvOpBitFieldUExtract,
+ SpvOpBitReverse,
+ SpvOpBitCount,
+ SpvOpDPdx,
+ SpvOpDPdy,
+ SpvOpFwidth,
+ SpvOpDPdxFine,
+ SpvOpDPdyFine,
+ SpvOpFwidthFine,
+ SpvOpDPdxCoarse,
+ SpvOpDPdyCoarse,
+ SpvOpFwidthCoarse,
+ SpvOpPhi,
+ SpvOpImageSparseSampleImplicitLod,
+ SpvOpImageSparseSampleExplicitLod,
+ SpvOpImageSparseSampleDrefImplicitLod,
+ SpvOpImageSparseSampleDrefExplicitLod,
+ SpvOpImageSparseSampleProjImplicitLod,
+ SpvOpImageSparseSampleProjExplicitLod,
+ SpvOpImageSparseSampleProjDrefImplicitLod,
+ SpvOpImageSparseSampleProjDrefExplicitLod,
+ SpvOpImageSparseFetch,
+ SpvOpImageSparseGather,
+ SpvOpImageSparseDrefGather,
+ SpvOpImageSparseTexelsResident,
+ SpvOpImageSparseRead,
+ SpvOpSizeOf
+ // TODO(dneto): Add instructions enabled by ImageQuery
+ };
+
+ // Find supported extension instruction set ids
+ glsl_std_450_id_ = module_->GetExtInstImportId("GLSL.std.450");
+
+ combinator_ops_glsl_std_450_ = {
+ GLSLstd450Round,
+ GLSLstd450RoundEven,
+ GLSLstd450Trunc,
+ GLSLstd450FAbs,
+ GLSLstd450SAbs,
+ GLSLstd450FSign,
+ GLSLstd450SSign,
+ GLSLstd450Floor,
+ GLSLstd450Ceil,
+ GLSLstd450Fract,
+ GLSLstd450Radians,
+ GLSLstd450Degrees,
+ GLSLstd450Sin,
+ GLSLstd450Cos,
+ GLSLstd450Tan,
+ GLSLstd450Asin,
+ GLSLstd450Acos,
+ GLSLstd450Atan,
+ GLSLstd450Sinh,
+ GLSLstd450Cosh,
+ GLSLstd450Tanh,
+ GLSLstd450Asinh,
+ GLSLstd450Acosh,
+ GLSLstd450Atanh,
+ GLSLstd450Atan2,
+ GLSLstd450Pow,
+ GLSLstd450Exp,
+ GLSLstd450Log,
+ GLSLstd450Exp2,
+ GLSLstd450Log2,
+ GLSLstd450Sqrt,
+ GLSLstd450InverseSqrt,
+ GLSLstd450Determinant,
+ GLSLstd450MatrixInverse,
+ GLSLstd450ModfStruct,
+ GLSLstd450FMin,
+ GLSLstd450UMin,
+ GLSLstd450SMin,
+ GLSLstd450FMax,
+ GLSLstd450UMax,
+ GLSLstd450SMax,
+ GLSLstd450FClamp,
+ GLSLstd450UClamp,
+ GLSLstd450SClamp,
+ GLSLstd450FMix,
+ GLSLstd450IMix,
+ GLSLstd450Step,
+ GLSLstd450SmoothStep,
+ GLSLstd450Fma,
+ GLSLstd450FrexpStruct,
+ GLSLstd450Ldexp,
+ GLSLstd450PackSnorm4x8,
+ GLSLstd450PackUnorm4x8,
+ GLSLstd450PackSnorm2x16,
+ GLSLstd450PackUnorm2x16,
+ GLSLstd450PackHalf2x16,
+ GLSLstd450PackDouble2x32,
+ GLSLstd450UnpackSnorm2x16,
+ GLSLstd450UnpackUnorm2x16,
+ GLSLstd450UnpackHalf2x16,
+ GLSLstd450UnpackSnorm4x8,
+ GLSLstd450UnpackUnorm4x8,
+ GLSLstd450UnpackDouble2x32,
+ GLSLstd450Length,
+ GLSLstd450Distance,
+ GLSLstd450Cross,
+ GLSLstd450Normalize,
+ GLSLstd450FaceForward,
+ GLSLstd450Reflect,
+ GLSLstd450Refract,
+ GLSLstd450FindILsb,
+ GLSLstd450FindSMsb,
+ GLSLstd450FindUMsb,
+ GLSLstd450InterpolateAtCentroid,
+ GLSLstd450InterpolateAtSample,
+ GLSLstd450InterpolateAtOffset,
+ GLSLstd450NMin,
+ GLSLstd450NMax,
+ GLSLstd450NClamp
+ };
+}
+
+void AggressiveDCEPass::InitExtensions() {
+ extensions_whitelist_.clear();
+ extensions_whitelist_.insert({
+ "SPV_AMD_shader_explicit_vertex_parameter",
+ "SPV_AMD_shader_trinary_minmax",
+ "SPV_AMD_gcn_shader",
+ "SPV_KHR_shader_ballot",
+ "SPV_AMD_shader_ballot",
+ "SPV_AMD_gpu_shader_half_float",
+ "SPV_KHR_shader_draw_parameters",
+ "SPV_KHR_subgroup_vote",
+ "SPV_KHR_16bit_storage",
+ "SPV_KHR_device_group",
+ "SPV_KHR_multiview",
+ "SPV_NVX_multiview_per_view_attributes",
+ "SPV_NV_viewport_array2",
+ "SPV_NV_stereo_view_rendering",
+ "SPV_NV_sample_mask_override_coverage",
+ "SPV_NV_geometry_shader_passthrough",
+ "SPV_AMD_texture_gather_bias_lod",
+ "SPV_KHR_storage_buffer_storage_class",
+ // SPV_KHR_variable_pointers
+ // Currently do not support extended pointer expressions
+ "SPV_AMD_gpu_shader_int16",
+ "SPV_KHR_post_depth_coverage",
+ "SPV_KHR_shader_atomic_counter_ops",
+ });
+}
+
+} // namespace opt
+} // namespace spvtools
+
diff --git a/source/opt/aggressive_dead_code_elim_pass.h b/source/opt/aggressive_dead_code_elim_pass.h
new file mode 100644
index 00000000..b386c85c
--- /dev/null
+++ b/source/opt/aggressive_dead_code_elim_pass.h
@@ -0,0 +1,147 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_AGGRESSIVE_DCE_PASS_H_
+#define LIBSPIRV_OPT_AGGRESSIVE_DCE_PASS_H_
+
+#include <algorithm>
+#include <map>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "basic_block.h"
+#include "def_use_manager.h"
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class AggressiveDCEPass : public Pass {
+
+ using cbb_ptr = const ir::BasicBlock*;
+
+ public:
+ using GetBlocksFunction =
+ std::function<std::vector<ir::BasicBlock*>*(const ir::BasicBlock*)>;
+
+ AggressiveDCEPass();
+ const char* name() const override { return "aggressive-dce"; }
+ Status Process(ir::Module*) override;
+
+ private:
+ // Returns true if |opcode| is a non-ptr access chain op
+ bool IsNonPtrAccessChain(const SpvOp opcode) const;
+
+ // Given a load or store |ip|, return the pointer instruction.
+ // Also return the base variable's id in |varId|.
+ ir::Instruction* GetPtr(ir::Instruction* ip, uint32_t* varId);
+
+ // Add all store instruction which use |ptrId|, directly or indirectly,
+ // to the live instruction worklist.
+ void AddStores(uint32_t ptrId);
+
+ // Return true if variable with |varId| is function scope
+ bool IsLocalVar(uint32_t varId);
+
+ // Initialize combinator data structures
+ void InitCombinatorSets();
+
+ // Return true if core operator |op| has no side-effects. Currently returns
+ // true only for shader capability operations.
+ // TODO(greg-lunarg): Add kernel and other operators
+ bool IsCombinator(uint32_t op) const;
+
+ // Return true if OpExtInst |inst| has no side-effects. Currently returns
+ // true only for std.GLSL.450 extensions
+ // TODO(greg-lunarg): Add support for other extensions
+ bool IsCombinatorExt(ir::Instruction* inst) const;
+
+ // Initialize extensions whitelist
+ void InitExtensions();
+
+ // Return true if all extensions in this module are supported by this pass.
+ bool AllExtensionsSupported() const;
+
+ // Kill debug or annotation |inst| if target operand is dead.
+ void KillInstIfTargetDead(ir::Instruction* inst);
+
+ // For function |func|, mark all Stores to non-function-scope variables
+ // and block terminating instructions as live. Recursively mark the values
+ // they use. When complete, delete any non-live instructions. Return true
+ // if the function has been modified.
+ //
+ // Note: This function does not delete useless control structures. All
+ // existing control structures will remain. This can leave not-insignificant
+ // sequences of ultimately useless code.
+ // TODO(): Remove useless control constructs.
+ bool AggressiveDCE(ir::Function* func);
+
+ void Initialize(ir::Module* module);
+ Pass::Status ProcessImpl();
+
+ // Module this pass is processing
+ ir::Module* module_;
+
+ // Def-Uses for the module we are processing
+ std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
+
+ // Map from function's result id to function
+ std::unordered_map<uint32_t, ir::Function*> id2function_;
+
+ // Live Instruction Worklist. An instruction is added to this list
+ // if it might have a side effect, either directly or indirectly.
+ // If we don't know, then add it to this list. Instructions are
+ // removed from this list as the algorithm traces side effects,
+ // building up the live instructions set |live_insts_|.
+ std::queue<ir::Instruction*> worklist_;
+
+ // Live Instructions
+ std::unordered_set<const ir::Instruction*> live_insts_;
+
+ // Live Local Variables
+ std::unordered_set<uint32_t> live_local_vars_;
+
+ // Dead instructions. Use for debug cleanup.
+ std::unordered_set<const ir::Instruction*> dead_insts_;
+
+ // Opcodes of shader capability core executable instructions
+ // without side-effect. This is a whitelist of operators
+ // that can safely be left unmarked as live at the beginning of
+ // aggressive DCE.
+ std::unordered_set<uint32_t> combinator_ops_shader_;
+
+ // Opcodes of GLSL_std_450 extension executable instructions
+ // without side-effect. This is a whitelist of operators
+ // that can safely be left unmarked as live at the beginning of
+ // aggressive DCE.
+ std::unordered_set<uint32_t> combinator_ops_glsl_std_450_;
+
+ // Extensions supported by this pass.
+ std::unordered_set<std::string> extensions_whitelist_;
+
+ // Set id for glsl_std_450 extension instructions
+ uint32_t glsl_std_450_id_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_AGGRESSIVE_DCE_PASS_H_
+
diff --git a/source/opt/basic_block.h b/source/opt/basic_block.h
index 73249051..b30abe48 100644
--- a/source/opt/basic_block.h
+++ b/source/opt/basic_block.h
@@ -42,10 +42,15 @@ class BasicBlock {
// Sets the enclosing function for this basic block.
void SetParent(Function* function) { function_ = function; }
+
// Appends an instruction to this basic block.
inline void AddInstruction(std::unique_ptr<Instruction> i);
+
+ // Appends all of block's instructions (except label) to this block
+ inline void AddInstructions(BasicBlock* bp);
+
// The label starting this basic block.
- Instruction& Label() { return *label_; }
+ Instruction* GetLabelInst() { return label_.get(); }
// Returns the id of the label at the top of this block
inline uint32_t id() const { return label_->result_id(); }
@@ -59,6 +64,11 @@ class BasicBlock {
return const_iterator(&insts_, insts_.cend());
}
+ iterator tail() {
+ assert(!insts_.empty());
+ return iterator(&insts_, std::prev(insts_.end()));
+ }
+
// Runs the given function |f| on each instruction in this basic block, and
// optionally on the debug line instructions that might precede them.
inline void ForEachInst(const std::function<void(Instruction*)>& f,
@@ -91,6 +101,11 @@ inline void BasicBlock::AddInstruction(std::unique_ptr<Instruction> i) {
insts_.emplace_back(std::move(i));
}
+inline void BasicBlock::AddInstructions(BasicBlock* bp) {
+ auto bEnd = end();
+ (void) bEnd.InsertBefore(&bp->insts_);
+}
+
inline void BasicBlock::ForEachInst(const std::function<void(Instruction*)>& f,
bool run_on_debug_line_insts) {
if (label_) label_->ForEachInst(f, run_on_debug_line_insts);
diff --git a/source/opt/block_merge_pass.cpp b/source/opt/block_merge_pass.cpp
new file mode 100644
index 00000000..ad9de47e
--- /dev/null
+++ b/source/opt/block_merge_pass.cpp
@@ -0,0 +1,189 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "block_merge_pass.h"
+
+#include "iterator.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+
+const int kEntryPointFunctionIdInIdx = 1;
+
+} // anonymous namespace
+
+bool BlockMergePass::IsLoopHeader(ir::BasicBlock* block_ptr) {
+ auto iItr = block_ptr->tail();
+ if (iItr == block_ptr->begin())
+ return false;
+ --iItr;
+ return iItr->opcode() == SpvOpLoopMerge;
+}
+
+bool BlockMergePass::HasMultipleRefs(uint32_t labId) {
+ const analysis::UseList* uses = def_use_mgr_->GetUses(labId);
+ int rcnt = 0;
+ for (const auto u : *uses) {
+ // Don't count OpName
+ if (u.inst->opcode() == SpvOpName)
+ continue;
+ if (rcnt == 1)
+ return true;
+ ++rcnt;
+ }
+ return false;
+}
+
+void BlockMergePass::KillInstAndName(ir::Instruction* inst) {
+ const uint32_t id = inst->result_id();
+ if (id != 0) {
+ analysis::UseList* uses = def_use_mgr_->GetUses(id);
+ if (uses != nullptr)
+ for (auto u : *uses)
+ if (u.inst->opcode() == SpvOpName) {
+ def_use_mgr_->KillInst(u.inst);
+ break;
+ }
+ }
+ def_use_mgr_->KillInst(inst);
+}
+
+bool BlockMergePass::MergeBlocks(ir::Function* func) {
+ bool modified = false;
+ for (auto bi = func->begin(); bi != func->end(); ) {
+ // Do not merge loop header blocks, at least for now.
+ if (IsLoopHeader(&*bi)) {
+ ++bi;
+ continue;
+ }
+ // Find block with single successor which has no other predecessors.
+ // Continue and Merge blocks are currently ruled out as second blocks.
+ // Happily any such candidate blocks will have >1 uses due to their
+ // LoopMerge instruction.
+ // TODO(): Deal with phi instructions that reference the
+ // second block. Happily, these references currently inhibit
+ // the merge.
+ auto ii = bi->end();
+ --ii;
+ ir::Instruction* br = &*ii;
+ if (br->opcode() != SpvOpBranch) {
+ ++bi;
+ continue;
+ }
+ const uint32_t labId = br->GetSingleWordInOperand(0);
+ if (HasMultipleRefs(labId)) {
+ ++bi;
+ continue;
+ }
+ // Merge blocks
+ def_use_mgr_->KillInst(br);
+ auto sbi = bi;
+ for (; sbi != func->end(); ++sbi)
+ if (sbi->id() == labId)
+ break;
+ // If bi is sbi's only predecessor, it dominates sbi and thus
+ // sbi must follow bi in func's ordering.
+ assert(sbi != func->end());
+ bi->AddInstructions(&*sbi);
+ KillInstAndName(sbi->GetLabelInst());
+ (void) sbi.Erase();
+ // reprocess block
+ modified = true;
+ }
+ return modified;
+}
+
+void BlockMergePass::Initialize(ir::Module* module) {
+
+ module_ = module;
+
+ // Initialize function and block maps
+ id2function_.clear();
+ for (auto& fn : *module_)
+ id2function_[fn.result_id()] = &fn;
+
+ // TODO(greg-lunarg): Reuse def/use from previous passes
+ def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
+
+ // Initialize extension whitelist
+ InitExtensions();
+};
+
+bool BlockMergePass::AllExtensionsSupported() const {
+ // If any extension not in whitelist, return false
+ for (auto& ei : module_->extensions()) {
+ const char* extName = reinterpret_cast<const char*>(
+ &ei.GetInOperand(0).words[0]);
+ if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
+ return false;
+ }
+ return true;
+}
+
+Pass::Status BlockMergePass::ProcessImpl() {
+ // Do not process if any disallowed extensions are enabled
+ if (!AllExtensionsSupported())
+ return Status::SuccessWithoutChange;
+ bool modified = false;
+ for (auto& e : module_->entry_points()) {
+ ir::Function* fn =
+ id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
+ modified = MergeBlocks(fn) || modified;
+ }
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+BlockMergePass::BlockMergePass()
+ : module_(nullptr), def_use_mgr_(nullptr) {}
+
+Pass::Status BlockMergePass::Process(ir::Module* module) {
+ Initialize(module);
+ return ProcessImpl();
+}
+
+void BlockMergePass::InitExtensions() {
+ extensions_whitelist_.clear();
+ extensions_whitelist_.insert({
+ "SPV_AMD_shader_explicit_vertex_parameter",
+ "SPV_AMD_shader_trinary_minmax",
+ "SPV_AMD_gcn_shader",
+ "SPV_KHR_shader_ballot",
+ "SPV_AMD_shader_ballot",
+ "SPV_AMD_gpu_shader_half_float",
+ "SPV_KHR_shader_draw_parameters",
+ "SPV_KHR_subgroup_vote",
+ "SPV_KHR_16bit_storage",
+ "SPV_KHR_device_group",
+ "SPV_KHR_multiview",
+ "SPV_NVX_multiview_per_view_attributes",
+ "SPV_NV_viewport_array2",
+ "SPV_NV_stereo_view_rendering",
+ "SPV_NV_sample_mask_override_coverage",
+ "SPV_NV_geometry_shader_passthrough",
+ "SPV_AMD_texture_gather_bias_lod",
+ "SPV_KHR_storage_buffer_storage_class",
+ "SPV_KHR_variable_pointers",
+ "SPV_AMD_gpu_shader_int16",
+ "SPV_KHR_post_depth_coverage",
+ "SPV_KHR_shader_atomic_counter_ops",
+ });
+}
+
+} // namespace opt
+} // namespace spvtools
+
diff --git a/source/opt/block_merge_pass.h b/source/opt/block_merge_pass.h
new file mode 100644
index 00000000..73adf0ef
--- /dev/null
+++ b/source/opt/block_merge_pass.h
@@ -0,0 +1,82 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_BLOCK_MERGE_PASS_H_
+#define LIBSPIRV_OPT_BLOCK_MERGE_PASS_H_
+
+#include <algorithm>
+#include <map>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "basic_block.h"
+#include "def_use_manager.h"
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class BlockMergePass : public Pass {
+ public:
+ BlockMergePass();
+ const char* name() const override { return "sroa"; }
+ Status Process(ir::Module*) override;
+
+ private:
+ // Return true if |block_ptr| is loop header block
+ bool IsLoopHeader(ir::BasicBlock* block_ptr);
+
+ // Return true if |labId| has multiple refs. Do not count OpName.
+ bool HasMultipleRefs(uint32_t labId);
+
+ // Kill any OpName instruction referencing |inst|, then kill |inst|.
+ void KillInstAndName(ir::Instruction* inst);
+
+ // Search |func| for blocks which have a single Branch to a block
+ // with no other predecessors. Merge these blocks into a single block.
+ bool MergeBlocks(ir::Function* func);
+
+ // Initialize extensions whitelist
+ void InitExtensions();
+
+ // Return true if all extensions in this module are allowed by this pass.
+ bool AllExtensionsSupported() const;
+
+ void Initialize(ir::Module* module);
+ Pass::Status ProcessImpl();
+
+ // Module this pass is processing
+ ir::Module* module_;
+
+ // Def-Uses for the module we are processing
+ std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
+
+ // Map from function's result id to function
+ std::unordered_map<uint32_t, ir::Function*> id2function_;
+
+ // Extensions supported by this pass.
+ std::unordered_set<std::string> extensions_whitelist_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_BLOCK_MERGE_PASS_H_
+
diff --git a/source/opt/dead_branch_elim_pass.cpp b/source/opt/dead_branch_elim_pass.cpp
new file mode 100644
index 00000000..0c57307b
--- /dev/null
+++ b/source/opt/dead_branch_elim_pass.cpp
@@ -0,0 +1,418 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "dead_branch_elim_pass.h"
+
+#include "cfa.h"
+#include "iterator.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+
+const uint32_t kEntryPointFunctionIdInIdx = 1;
+const uint32_t kBranchCondConditionalIdInIdx = 0;
+const uint32_t kBranchCondTrueLabIdInIdx = 1;
+const uint32_t kBranchCondFalseLabIdInIdx = 2;
+const uint32_t kSelectionMergeMergeBlockIdInIdx = 0;
+const uint32_t kPhiVal0IdInIdx = 0;
+const uint32_t kPhiLab0IdInIdx = 1;
+const uint32_t kPhiVal1IdInIdx = 2;
+const uint32_t kLoopMergeMergeBlockIdInIdx = 0;
+const uint32_t kLoopMergeContinueBlockIdInIdx = 1;
+
+} // anonymous namespace
+
+uint32_t DeadBranchElimPass::MergeBlockIdIfAny(
+ const ir::BasicBlock& blk, uint32_t* cbid) const {
+ auto merge_ii = blk.cend();
+ --merge_ii;
+ uint32_t mbid = 0;
+ *cbid = 0;
+ if (merge_ii != blk.cbegin()) {
+ --merge_ii;
+ if (merge_ii->opcode() == SpvOpLoopMerge) {
+ mbid = merge_ii->GetSingleWordInOperand(kLoopMergeMergeBlockIdInIdx);
+ *cbid = merge_ii->GetSingleWordInOperand(kLoopMergeContinueBlockIdInIdx);
+ }
+ else if (merge_ii->opcode() == SpvOpSelectionMerge) {
+ mbid = merge_ii->GetSingleWordInOperand(
+ kSelectionMergeMergeBlockIdInIdx);
+ }
+ }
+ return mbid;
+}
+
+void DeadBranchElimPass::ComputeStructuredSuccessors(ir::Function* func) {
+ // If header, make merge block first successor. If a loop header, make
+ // the second successor the continue target.
+ for (auto& blk : *func) {
+ uint32_t cbid;
+ uint32_t mbid = MergeBlockIdIfAny(blk, &cbid);
+ if (mbid != 0) {
+ block2structured_succs_[&blk].push_back(id2block_[mbid]);
+ if (cbid != 0)
+ block2structured_succs_[&blk].push_back(id2block_[cbid]);
+ }
+ // add true successors
+ blk.ForEachSuccessorLabel([&blk, this](uint32_t sbid) {
+ block2structured_succs_[&blk].push_back(id2block_[sbid]);
+ });
+ }
+}
+
+void DeadBranchElimPass::ComputeStructuredOrder(
+ ir::Function* func, std::list<ir::BasicBlock*>* order) {
+ // Compute structured successors and do DFS
+ ComputeStructuredSuccessors(func);
+ auto ignore_block = [](cbb_ptr) {};
+ auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
+ auto get_structured_successors = [this](const ir::BasicBlock* block) {
+ return &(block2structured_succs_[block]); };
+ // TODO(greg-lunarg): Get rid of const_cast by moving const
+ // out of the cfa.h prototypes and into the invoking code.
+ auto post_order = [&](cbb_ptr b) {
+ order->push_front(const_cast<ir::BasicBlock*>(b)); };
+
+ spvtools::CFA<ir::BasicBlock>::DepthFirstTraversal(
+ &*func->begin(), get_structured_successors, ignore_block, post_order,
+ ignore_edge);
+}
+
+void DeadBranchElimPass::GetConstCondition(
+ uint32_t condId, bool* condVal, bool* condIsConst) {
+ ir::Instruction* cInst = def_use_mgr_->GetDef(condId);
+ switch (cInst->opcode()) {
+ case SpvOpConstantFalse: {
+ *condVal = false;
+ *condIsConst = true;
+ } break;
+ case SpvOpConstantTrue: {
+ *condVal = true;
+ *condIsConst = true;
+ } break;
+ case SpvOpLogicalNot: {
+ bool negVal;
+ (void)GetConstCondition(cInst->GetSingleWordInOperand(0),
+ &negVal, condIsConst);
+ if (*condIsConst)
+ *condVal = !negVal;
+ } break;
+ default: {
+ *condIsConst = false;
+ } break;
+ }
+}
+
+void DeadBranchElimPass::AddBranch(uint32_t labelId, ir::BasicBlock* bp) {
+ std::unique_ptr<ir::Instruction> newBranch(
+ new ir::Instruction(SpvOpBranch, 0, 0,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}}));
+ def_use_mgr_->AnalyzeInstDefUse(&*newBranch);
+ bp->AddInstruction(std::move(newBranch));
+}
+
+void DeadBranchElimPass::AddSelectionMerge(uint32_t labelId,
+ ir::BasicBlock* bp) {
+ std::unique_ptr<ir::Instruction> newMerge(
+ new ir::Instruction(SpvOpSelectionMerge, 0, 0,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {0}}}));
+ def_use_mgr_->AnalyzeInstDefUse(&*newMerge);
+ bp->AddInstruction(std::move(newMerge));
+}
+
+void DeadBranchElimPass::AddBranchConditional(uint32_t condId,
+ uint32_t trueLabId, uint32_t falseLabId, ir::BasicBlock* bp) {
+ std::unique_ptr<ir::Instruction> newBranchCond(
+ new ir::Instruction(SpvOpBranchConditional, 0, 0,
+ {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {condId}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {trueLabId}},
+ {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {falseLabId}}}));
+ def_use_mgr_->AnalyzeInstDefUse(&*newBranchCond);
+ bp->AddInstruction(std::move(newBranchCond));
+}
+
+void DeadBranchElimPass::KillNamesAndDecorates(uint32_t id) {
+ // TODO(greg-lunarg): Remove id from any OpGroupDecorate and
+ // kill if no other operands.
+ if (named_or_decorated_ids_.find(id) == named_or_decorated_ids_.end())
+ return;
+ analysis::UseList* uses = def_use_mgr_->GetUses(id);
+ if (uses == nullptr)
+ return;
+ std::list<ir::Instruction*> killList;
+ for (auto u : *uses) {
+ const SpvOp op = u.inst->opcode();
+ if (op == SpvOpName || IsDecorate(op))
+ killList.push_back(u.inst);
+ }
+ for (auto kip : killList)
+ def_use_mgr_->KillInst(kip);
+}
+
+void DeadBranchElimPass::KillNamesAndDecorates(ir::Instruction* inst) {
+ const uint32_t rId = inst->result_id();
+ if (rId == 0)
+ return;
+ KillNamesAndDecorates(rId);
+}
+
+void DeadBranchElimPass::KillAllInsts(ir::BasicBlock* bp) {
+ bp->ForEachInst([this](ir::Instruction* ip) {
+ KillNamesAndDecorates(ip);
+ def_use_mgr_->KillInst(ip);
+ });
+}
+
+bool DeadBranchElimPass::GetConstConditionalSelectionBranch(ir::BasicBlock* bp,
+ ir::Instruction** branchInst, ir::Instruction** mergeInst,
+ uint32_t *condId, bool *condVal) {
+ auto ii = bp->end();
+ --ii;
+ *branchInst = &*ii;
+ if ((*branchInst)->opcode() != SpvOpBranchConditional)
+ return false;
+ if (ii == bp->begin())
+ return false;
+ --ii;
+ *mergeInst = &*ii;
+ if ((*mergeInst)->opcode() != SpvOpSelectionMerge)
+ return false;
+ bool condIsConst;
+ *condId = (*branchInst)->GetSingleWordInOperand(
+ kBranchCondConditionalIdInIdx);
+ (void) GetConstCondition(*condId, condVal, &condIsConst);
+ return condIsConst;
+}
+
+bool DeadBranchElimPass::HasNonPhiRef(uint32_t labelId) {
+ analysis::UseList* uses = def_use_mgr_->GetUses(labelId);
+ if (uses == nullptr)
+ return false;
+ for (auto u : *uses)
+ if (u.inst->opcode() != SpvOpPhi)
+ return true;
+ return false;
+}
+
+bool DeadBranchElimPass::EliminateDeadBranches(ir::Function* func) {
+ // Traverse blocks in structured order
+ std::list<ir::BasicBlock*> structuredOrder;
+ ComputeStructuredOrder(func, &structuredOrder);
+ std::unordered_set<ir::BasicBlock*> elimBlocks;
+ bool modified = false;
+ for (auto bi = structuredOrder.begin(); bi != structuredOrder.end(); ++bi) {
+ // Skip blocks that are already in the elimination set
+ if (elimBlocks.find(*bi) != elimBlocks.end())
+ continue;
+ // Skip blocks that don't have constant conditional branch preceded
+ // by OpSelectionMerge
+ ir::Instruction* br;
+ ir::Instruction* mergeInst;
+ uint32_t condId;
+ bool condVal;
+ if (!GetConstConditionalSelectionBranch(*bi, &br, &mergeInst, &condId,
+ &condVal))
+ continue;
+
+ // Replace conditional branch with unconditional branch
+ const uint32_t trueLabId =
+ br->GetSingleWordInOperand(kBranchCondTrueLabIdInIdx);
+ const uint32_t falseLabId =
+ br->GetSingleWordInOperand(kBranchCondFalseLabIdInIdx);
+ const uint32_t mergeLabId =
+ mergeInst->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx);
+ const uint32_t liveLabId = condVal == true ? trueLabId : falseLabId;
+ const uint32_t deadLabId = condVal == true ? falseLabId : trueLabId;
+ AddBranch(liveLabId, *bi);
+ def_use_mgr_->KillInst(br);
+ def_use_mgr_->KillInst(mergeInst);
+
+ // Iterate to merge block adding dead blocks to elimination set
+ auto dbi = bi;
+ ++dbi;
+ uint32_t dLabId = (*dbi)->id();
+ while (dLabId != mergeLabId) {
+ if (!HasNonPhiRef(dLabId)) {
+ // Kill use/def for all instructions and mark block for elimination
+ KillAllInsts(*dbi);
+ elimBlocks.insert(*dbi);
+ }
+ ++dbi;
+ dLabId = (*dbi)->id();
+ }
+
+ // Process phi instructions in merge block.
+ // elimBlocks are now blocks which cannot precede merge block. Also,
+ // if eliminated branch is to merge label, remember the conditional block
+ // also cannot precede merge block.
+ uint32_t deadCondLabId = 0;
+ if (deadLabId == mergeLabId)
+ deadCondLabId = (*bi)->id();
+ (*dbi)->ForEachPhiInst([&elimBlocks, &deadCondLabId, this](
+ ir::Instruction* phiInst) {
+ const uint32_t phiLabId0 =
+ phiInst->GetSingleWordInOperand(kPhiLab0IdInIdx);
+ const bool useFirst =
+ elimBlocks.find(id2block_[phiLabId0]) == elimBlocks.end() &&
+ phiLabId0 != deadCondLabId;
+ const uint32_t phiValIdx =
+ useFirst ? kPhiVal0IdInIdx : kPhiVal1IdInIdx;
+ const uint32_t replId = phiInst->GetSingleWordInOperand(phiValIdx);
+ const uint32_t phiId = phiInst->result_id();
+ KillNamesAndDecorates(phiId);
+ (void)def_use_mgr_->ReplaceAllUsesWith(phiId, replId);
+ def_use_mgr_->KillInst(phiInst);
+ });
+
+ // If merge block has no predecessors, replace the new branch with
+ // an OpSelectionMerge/OpBranchConditional using the original constant condition
+ // and the mergeblock as the false branch. This is done so the merge block
+ // is not orphaned, which could cause invalid control flow in certain cases.
+ // TODO(greg-lunarg): Do this only in cases where invalid code is caused.
+ if (!HasNonPhiRef(mergeLabId)) {
+ auto eii = (*bi)->end();
+ --eii;
+ ir::Instruction* nbr = &*eii;
+ AddSelectionMerge(mergeLabId, *bi);
+ if (condVal == true)
+ AddBranchConditional(condId, liveLabId, mergeLabId, *bi);
+ else
+ AddBranchConditional(condId, mergeLabId, liveLabId, *bi);
+ def_use_mgr_->KillInst(nbr);
+ }
+ modified = true;
+ }
+
+ // Erase dead blocks
+ for (auto ebi = func->begin(); ebi != func->end(); )
+ if (elimBlocks.find(&*ebi) != elimBlocks.end())
+ ebi = ebi.Erase();
+ else
+ ++ebi;
+ return modified;
+}
+
+void DeadBranchElimPass::Initialize(ir::Module* module) {
+
+ module_ = module;
+
+ // Initialize function and block maps
+ id2function_.clear();
+ id2block_.clear();
+ block2structured_succs_.clear();
+ for (auto& fn : *module_) {
+ // Initialize function and block maps.
+ id2function_[fn.result_id()] = &fn;
+ for (auto& blk : fn) {
+ id2block_[blk.id()] = &blk;
+ }
+ }
+
+ // TODO(greg-lunarg): Reuse def/use from previous passes
+ def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
+
+ // Initialize extension whitelist
+ InitExtensions();
+};
+
+void DeadBranchElimPass::FindNamedOrDecoratedIds() {
+ for (auto& di : module_->debugs())
+ if (di.opcode() == SpvOpName)
+ named_or_decorated_ids_.insert(di.GetSingleWordInOperand(0));
+ for (auto& ai : module_->annotations())
+ if (ai.opcode() == SpvOpDecorate || ai.opcode() == SpvOpDecorateId)
+ named_or_decorated_ids_.insert(ai.GetSingleWordInOperand(0));
+}
+
+bool DeadBranchElimPass::AllExtensionsSupported() const {
+ // If any extension not in whitelist, return false
+ for (auto& ei : module_->extensions()) {
+ const char* extName = reinterpret_cast<const char*>(
+ &ei.GetInOperand(0).words[0]);
+ if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
+ return false;
+ }
+ return true;
+}
+
+Pass::Status DeadBranchElimPass::ProcessImpl() {
+ // Current functionality assumes structured control flow.
+ // TODO(greg-lunarg): Handle non-structured control-flow.
+ if (!module_->HasCapability(SpvCapabilityShader))
+ return Status::SuccessWithoutChange;
+ // Do not process if module contains OpGroupDecorate. Additional
+ // support required in KillNamesAndDecorates().
+ // TODO(greg-lunarg): Add support for OpGroupDecorate
+ for (auto& ai : module_->annotations())
+ if (ai.opcode() == SpvOpGroupDecorate)
+ return Status::SuccessWithoutChange;
+ // Do not process if any disallowed extensions are enabled
+ if (!AllExtensionsSupported())
+ return Status::SuccessWithoutChange;
+ // Collect all named and decorated ids
+ FindNamedOrDecoratedIds();
+ // Process all entry point functions
+ bool modified = false;
+ for (const auto& e : module_->entry_points()) {
+ ir::Function* fn =
+ id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
+ modified = EliminateDeadBranches(fn) || modified;
+ }
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+DeadBranchElimPass::DeadBranchElimPass()
+ : module_(nullptr), def_use_mgr_(nullptr) {}
+
+Pass::Status DeadBranchElimPass::Process(ir::Module* module) {
+ Initialize(module);
+ return ProcessImpl();
+}
+
+void DeadBranchElimPass::InitExtensions() {
+ extensions_whitelist_.clear();
+ extensions_whitelist_.insert({
+ "SPV_AMD_shader_explicit_vertex_parameter",
+ "SPV_AMD_shader_trinary_minmax",
+ "SPV_AMD_gcn_shader",
+ "SPV_KHR_shader_ballot",
+ "SPV_AMD_shader_ballot",
+ "SPV_AMD_gpu_shader_half_float",
+ "SPV_KHR_shader_draw_parameters",
+ "SPV_KHR_subgroup_vote",
+ "SPV_KHR_16bit_storage",
+ "SPV_KHR_device_group",
+ "SPV_KHR_multiview",
+ "SPV_NVX_multiview_per_view_attributes",
+ "SPV_NV_viewport_array2",
+ "SPV_NV_stereo_view_rendering",
+ "SPV_NV_sample_mask_override_coverage",
+ "SPV_NV_geometry_shader_passthrough",
+ "SPV_AMD_texture_gather_bias_lod",
+ "SPV_KHR_storage_buffer_storage_class",
+ "SPV_KHR_variable_pointers",
+ "SPV_AMD_gpu_shader_int16",
+ "SPV_KHR_post_depth_coverage",
+ "SPV_KHR_shader_atomic_counter_ops",
+ });
+}
+
+} // namespace opt
+} // namespace spvtools
+
diff --git a/source/opt/dead_branch_elim_pass.h b/source/opt/dead_branch_elim_pass.h
new file mode 100644
index 00000000..912149ec
--- /dev/null
+++ b/source/opt/dead_branch_elim_pass.h
@@ -0,0 +1,159 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_DEAD_BRANCH_ELIM_PASS_H_
+#define LIBSPIRV_OPT_DEAD_BRANCH_ELIM_PASS_H_
+
+
+#include <algorithm>
+#include <map>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "basic_block.h"
+#include "def_use_manager.h"
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class DeadBranchElimPass : public Pass {
+
+ using cbb_ptr = const ir::BasicBlock*;
+
+ public:
+ using GetBlocksFunction =
+ std::function<std::vector<ir::BasicBlock*>*(const ir::BasicBlock*)>;
+
+ DeadBranchElimPass();
+ const char* name() const override { return "dead-branch-elim"; }
+ Status Process(ir::Module*) override;
+
+ private:
+ // Returns the id of the merge block declared by a merge instruction in
+ // this block |blk|, if any. If none, returns zero. If loop merge, returns
+ // the continue target id in |cbid|. Otherwise sets to zero.
+ uint32_t MergeBlockIdIfAny(const ir::BasicBlock& blk, uint32_t* cbid) const;
+
+ // Compute structured successors for function |func|.
+ // A block's structured successors are the blocks it branches to
+ // together with its declared merge block if it has one.
+ // When order matters, the merge block always appears first and if
+ // a loop merge block, the continue target always appears second.
+ // This assures correct depth first search in the presence of early
+ // returns and kills. If the successor vector contains duplicates
+ // of the merge and continue blocks, they are safely ignored by DFS.
+ void ComputeStructuredSuccessors(ir::Function* func);
+
+ // Compute structured block order |order| for function |func|. This order
+ // has the property that dominators are before all blocks they dominate and
+ // merge blocks are after all blocks that are in the control constructs of
+ // their header.
+ void ComputeStructuredOrder(
+ ir::Function* func, std::list<ir::BasicBlock*>* order);
+
+ // If |condId| is boolean constant, return value in |condVal| and
+ // |condIsConst| as true, otherwise return |condIsConst| as false.
+ void GetConstCondition(uint32_t condId, bool* condVal, bool* condIsConst);
+
+ // Add branch to |labelId| to end of block |bp|.
+ void AddBranch(uint32_t labelId, ir::BasicBlock* bp);
+
+ // Add selection merge of |labelId| to end of block |bp|.
+ void AddSelectionMerge(uint32_t labelId, ir::BasicBlock* bp);
+
+ // Add conditional branch of |condId|, |trueLabId| and |falseLabId| to end
+ // of block |bp|.
+ void AddBranchConditional(uint32_t condId, uint32_t trueLabId,
+ uint32_t falseLabId, ir::BasicBlock* bp);
+
+ // Kill all instructions in block |bp|.
+ void KillAllInsts(ir::BasicBlock* bp);
+
+ // If block |bp| contains a constant conditional branch preceded by an
+ // OpSelectionMerge, return true and return branch and merge instructions
+ // in |branchInst| and |mergeInst| and the boolean constant in |condVal|.
+ bool GetConstConditionalSelectionBranch(ir::BasicBlock* bp,
+ ir::Instruction** branchInst, ir::Instruction** mergeInst,
+ uint32_t *condId, bool *condVal);
+
+ // Return true if |labelId| has any non-phi references
+ bool HasNonPhiRef(uint32_t labelId);
+
+ // Return true if |op| is supported decorate.
+ inline bool IsDecorate(uint32_t op) const {
+ return (op == SpvOpDecorate || op == SpvOpDecorateId);
+ }
+
+ // Kill all name and decorate ops using |inst|
+ void KillNamesAndDecorates(ir::Instruction* inst);
+
+ // Kill all name and decorate ops using |id|
+ void KillNamesAndDecorates(uint32_t id);
+
+ // Collect all named or decorated ids in module
+ void FindNamedOrDecoratedIds();
+
+ // For function |func|, look for BranchConditionals with constant condition
+ // and convert to a Branch to the indicated label. Delete resulting dead
+ // blocks. Assumes only structured control flow in shader. Note some such
+ // branches and blocks may be left to avoid creating invalid control flow.
+ // TODO(greg-lunarg): Remove remaining constant conditional branches and
+ // dead blocks.
+ bool EliminateDeadBranches(ir::Function* func);
+
+ // Initialize extensions whitelist
+ void InitExtensions();
+
+ // Return true if all extensions in this module are allowed by this pass.
+ bool AllExtensionsSupported() const;
+
+ void Initialize(ir::Module* module);
+ Pass::Status ProcessImpl();
+
+ // Module this pass is processing
+ ir::Module* module_;
+
+ // Def-Uses for the module we are processing
+ std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
+
+ // Map from function's result id to function
+ std::unordered_map<uint32_t, ir::Function*> id2function_;
+
+ // Map from block's label id to block.
+ std::unordered_map<uint32_t, ir::BasicBlock*> id2block_;
+
+ // Map from block to its structured successor blocks. See
+ // ComputeStructuredSuccessors() for definition.
+ std::unordered_map<const ir::BasicBlock*, std::vector<ir::BasicBlock*>>
+ block2structured_succs_;
+
+ // named or decorated ids
+ std::unordered_set<uint32_t> named_or_decorated_ids_;
+
+ // Extensions supported by this pass.
+ std::unordered_set<std::string> extensions_whitelist_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_DEAD_BRANCH_ELIM_PASS_H_
+
diff --git a/source/opt/def_use_manager.cpp b/source/opt/def_use_manager.cpp
index 3d8fd50e..a144acd4 100644
--- a/source/opt/def_use_manager.cpp
+++ b/source/opt/def_use_manager.cpp
@@ -21,7 +21,7 @@ namespace spvtools {
namespace opt {
namespace analysis {
-void DefUseManager::AnalyzeInstDefUse(ir::Instruction* inst) {
+void DefUseManager::AnalyzeInstDef(ir::Instruction* inst) {
const uint32_t def_id = inst->result_id();
if (def_id != 0) {
auto iter = id_to_def_.find(def_id);
@@ -31,10 +31,13 @@ void DefUseManager::AnalyzeInstDefUse(ir::Instruction* inst) {
ClearInst(iter->second);
}
id_to_def_[def_id] = inst;
- } else {
+ }
+ else {
ClearInst(inst);
}
+}
+void DefUseManager::AnalyzeInstUse(ir::Instruction* inst) {
// Create entry for the given instruction. Note that the instruction may
// not have any in-operands. In such cases, we still need a entry for those
// instructions so this manager knows it has seen the instruction later.
@@ -43,21 +46,26 @@ void DefUseManager::AnalyzeInstDefUse(ir::Instruction* inst) {
for (uint32_t i = 0; i < inst->NumOperands(); ++i) {
switch (inst->GetOperand(i).type) {
// For any id type but result id type
- case SPV_OPERAND_TYPE_ID:
- case SPV_OPERAND_TYPE_TYPE_ID:
- case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
- case SPV_OPERAND_TYPE_SCOPE_ID: {
- uint32_t use_id = inst->GetSingleWordOperand(i);
- // use_id is used by the instruction generating def_id.
- id_to_uses_[use_id].push_back({inst, i});
- inst_to_used_ids_[inst].push_back(use_id);
- } break;
- default:
- break;
+ case SPV_OPERAND_TYPE_ID:
+ case SPV_OPERAND_TYPE_TYPE_ID:
+ case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
+ case SPV_OPERAND_TYPE_SCOPE_ID: {
+ uint32_t use_id = inst->GetSingleWordOperand(i);
+ // use_id is used by the instruction generating def_id.
+ id_to_uses_[use_id].push_back({ inst, i });
+ inst_to_used_ids_[inst].push_back(use_id);
+ } break;
+ default:
+ break;
}
}
}
+void DefUseManager::AnalyzeInstDefUse(ir::Instruction* inst) {
+ AnalyzeInstDef(inst);
+ AnalyzeInstUse(inst);
+}
+
ir::Instruction* DefUseManager::GetDef(uint32_t id) {
auto iter = id_to_def_.find(id);
if (iter == id_to_def_.end()) return nullptr;
diff --git a/source/opt/def_use_manager.h b/source/opt/def_use_manager.h
index cd779d53..a639caba 100644
--- a/source/opt/def_use_manager.h
+++ b/source/opt/def_use_manager.h
@@ -59,6 +59,12 @@ class DefUseManager {
DefUseManager& operator=(const DefUseManager&) = delete;
DefUseManager& operator=(DefUseManager&&) = delete;
+ // Analyzes the defs in the given |inst|.
+ void AnalyzeInstDef(ir::Instruction* inst);
+
+ // Analyzes the uses in the given |inst|.
+ void AnalyzeInstUse(ir::Instruction* inst);
+
// Analyzes the defs and uses in the given |inst|.
void AnalyzeInstDefUse(ir::Instruction* inst);
diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp
index de55688d..de7b98cb 100644
--- a/source/opt/inline_pass.cpp
+++ b/source/opt/inline_pass.cpp
@@ -626,7 +626,7 @@ Pass::Status InlinePass::ProcessImpl() {
for (auto& e : module_->entry_points()) {
ir::Function* fn =
id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
- modified = modified || Inline(fn);
+ modified = Inline(fn) || modified;
}
FinalizeNextId(module_);
diff --git a/source/opt/insert_extract_elim.cpp b/source/opt/insert_extract_elim.cpp
new file mode 100644
index 00000000..cae3bab0
--- /dev/null
+++ b/source/opt/insert_extract_elim.cpp
@@ -0,0 +1,169 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "insert_extract_elim.h"
+
+#include "iterator.h"
+
+static const int kSpvEntryPointFunctionId = 1;
+static const int kSpvExtractCompositeId = 0;
+static const int kSpvInsertObjectId = 0;
+static const int kSpvInsertCompositeId = 1;
+
+namespace spvtools {
+namespace opt {
+
+bool InsertExtractElimPass::ExtInsMatch(const ir::Instruction* extInst,
+ const ir::Instruction* insInst) const {
+ if (extInst->NumInOperands() != insInst->NumInOperands() - 1)
+ return false;
+ uint32_t numIdx = extInst->NumInOperands() - 1;
+ for (uint32_t i = 0; i < numIdx; ++i)
+ if (extInst->GetSingleWordInOperand(i + 1) !=
+ insInst->GetSingleWordInOperand(i + 2))
+ return false;
+ return true;
+}
+
+bool InsertExtractElimPass::ExtInsConflict(const ir::Instruction* extInst,
+ const ir::Instruction* insInst) const {
+ if (extInst->NumInOperands() == insInst->NumInOperands() - 1)
+ return false;
+ uint32_t extNumIdx = extInst->NumInOperands() - 1;
+ uint32_t insNumIdx = insInst->NumInOperands() - 2;
+ uint32_t numIdx = std::min(extNumIdx, insNumIdx);
+ for (uint32_t i = 0; i < numIdx; ++i)
+ if (extInst->GetSingleWordInOperand(i + 1) !=
+ insInst->GetSingleWordInOperand(i + 2))
+ return false;
+ return true;
+}
+
+bool InsertExtractElimPass::EliminateInsertExtract(ir::Function* func) {
+ bool modified = false;
+ for (auto bi = func->begin(); bi != func->end(); ++bi) {
+ for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
+ switch (ii->opcode()) {
+ case SpvOpCompositeExtract: {
+ uint32_t cid = ii->GetSingleWordInOperand(kSpvExtractCompositeId);
+ ir::Instruction* cinst = def_use_mgr_->GetDef(cid);
+ uint32_t replId = 0;
+ while (cinst->opcode() == SpvOpCompositeInsert) {
+ if (ExtInsConflict(&*ii, cinst))
+ break;
+ if (ExtInsMatch(&*ii, cinst)) {
+ replId = cinst->GetSingleWordInOperand(kSpvInsertObjectId);
+ break;
+ }
+ cid = cinst->GetSingleWordInOperand(kSpvInsertCompositeId);
+ cinst = def_use_mgr_->GetDef(cid);
+ }
+ if (replId != 0) {
+ const uint32_t extId = ii->result_id();
+ (void)def_use_mgr_->ReplaceAllUsesWith(extId, replId);
+ def_use_mgr_->KillInst(&*ii);
+ modified = true;
+ }
+ } break;
+ default:
+ break;
+ }
+ }
+ }
+ return modified;
+}
+
+void InsertExtractElimPass::Initialize(ir::Module* module) {
+
+ module_ = module;
+
+ // Initialize function and block maps
+ id2function_.clear();
+ for (auto& fn : *module_)
+ id2function_[fn.result_id()] = &fn;
+
+ // Do def/use on whole module
+ def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
+
+ // Initialize extension whitelist
+ InitExtensions();
+};
+
+bool InsertExtractElimPass::AllExtensionsSupported() const {
+ // If any extension not in whitelist, return false
+ for (auto& ei : module_->extensions()) {
+ const char* extName = reinterpret_cast<const char*>(
+ &ei.GetInOperand(0).words[0]);
+ if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
+ return false;
+ }
+ return true;
+}
+
+Pass::Status InsertExtractElimPass::ProcessImpl() {
+ // Do not process if any disallowed extensions are enabled
+ if (!AllExtensionsSupported())
+ return Status::SuccessWithoutChange;
+ // Process all entry point functions.
+ bool modified = false;
+ for (auto& e : module_->entry_points()) {
+ ir::Function* fn =
+ id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
+ modified = EliminateInsertExtract(fn) || modified;
+ }
+
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+InsertExtractElimPass::InsertExtractElimPass()
+ : module_(nullptr), def_use_mgr_(nullptr) {}
+
+Pass::Status InsertExtractElimPass::Process(ir::Module* module) {
+ Initialize(module);
+ return ProcessImpl();
+}
+
+void InsertExtractElimPass::InitExtensions() {
+ extensions_whitelist_.clear();
+ extensions_whitelist_.insert({
+ "SPV_AMD_shader_explicit_vertex_parameter",
+ "SPV_AMD_shader_trinary_minmax",
+ "SPV_AMD_gcn_shader",
+ "SPV_KHR_shader_ballot",
+ "SPV_AMD_shader_ballot",
+ "SPV_AMD_gpu_shader_half_float",
+ "SPV_KHR_shader_draw_parameters",
+ "SPV_KHR_subgroup_vote",
+ "SPV_KHR_16bit_storage",
+ "SPV_KHR_device_group",
+ "SPV_KHR_multiview",
+ "SPV_NVX_multiview_per_view_attributes",
+ "SPV_NV_viewport_array2",
+ "SPV_NV_stereo_view_rendering",
+ "SPV_NV_sample_mask_override_coverage",
+ "SPV_NV_geometry_shader_passthrough",
+ "SPV_AMD_texture_gather_bias_lod",
+ "SPV_KHR_storage_buffer_storage_class",
+ "SPV_KHR_variable_pointers",
+ "SPV_AMD_gpu_shader_int16",
+ "SPV_KHR_post_depth_coverage",
+ "SPV_KHR_shader_atomic_counter_ops",
+ });
+}
+
+} // namespace opt
+} // namespace spvtools
+
diff --git a/source/opt/insert_extract_elim.h b/source/opt/insert_extract_elim.h
new file mode 100644
index 00000000..440dcd63
--- /dev/null
+++ b/source/opt/insert_extract_elim.h
@@ -0,0 +1,87 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_INSERT_EXTRACT_ELIM_PASS_H_
+#define LIBSPIRV_OPT_INSERT_EXTRACT_ELIM_PASS_H_
+
+
+#include <algorithm>
+#include <map>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "basic_block.h"
+#include "def_use_manager.h"
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class InsertExtractElimPass : public Pass {
+ public:
+ InsertExtractElimPass();
+ const char* name() const override { return "insert_extract_elim"; }
+ Status Process(ir::Module*) override;
+
+ private:
+ // Return true if indices of extract |extInst| and insert |insInst| match
+ bool ExtInsMatch(
+ const ir::Instruction* extInst, const ir::Instruction* insInst) const;
+
+ // Return true if indices of extract |extInst| and insert |insInst| conflict,
+ // specifically, if the insert changes bits specified by the extract, but
+ // changes either more bits or less bits than the extract specifies,
+ // meaning the exact value being inserted cannot be used to replace
+ // the extract.
+ bool ExtInsConflict(
+ const ir::Instruction* extInst, const ir::Instruction* insInst) const;
+
+ // Look for OpExtract on sequence of OpInserts in |func|. If there is an
+ // insert with identical indices, replace the extract with the value
+ // that is inserted if possible. Specifically, replace if there is no
+ // intervening insert which conflicts.
+ bool EliminateInsertExtract(ir::Function* func);
+
+ // Initialize extensions whitelist
+ void InitExtensions();
+
+ // Return true if all extensions in this module are allowed by this pass.
+ bool AllExtensionsSupported() const;
+
+ void Initialize(ir::Module* module);
+ Pass::Status ProcessImpl();
+
+ // Module this pass is processing
+ ir::Module* module_;
+
+ // Def-Uses for the module we are processing
+ std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
+
+ // Map from function's result id to function
+ std::unordered_map<uint32_t, ir::Function*> id2function_;
+
+ // Extensions supported by this pass.
+ std::unordered_set<std::string> extensions_whitelist_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_INSERT_EXTRACT_ELIM_PASS_H_
+
diff --git a/source/opt/instruction.h b/source/opt/instruction.h
index 0143f7c9..89c9da0b 100644
--- a/source/opt/instruction.h
+++ b/source/opt/instruction.h
@@ -247,14 +247,30 @@ inline void Instruction::ForEachInst(
}
inline void Instruction::ForEachInId(const std::function<void(uint32_t*)>& f) {
- for (auto& opnd : operands_)
- if (opnd.type == SPV_OPERAND_TYPE_ID) f(&opnd.words[0]);
+ for (auto& opnd : operands_) {
+ switch (opnd.type) {
+ case SPV_OPERAND_TYPE_RESULT_ID:
+ case SPV_OPERAND_TYPE_TYPE_ID:
+ break;
+ default:
+ if (spvIsIdType(opnd.type)) f(&opnd.words[0]);
+ break;
+ }
+ }
}
inline void Instruction::ForEachInId(
const std::function<void(const uint32_t*)>& f) const {
- for (const auto& opnd : operands_)
- if (opnd.type == SPV_OPERAND_TYPE_ID) f(&opnd.words[0]);
+ for (const auto& opnd : operands_) {
+ switch (opnd.type) {
+ case SPV_OPERAND_TYPE_RESULT_ID:
+ case SPV_OPERAND_TYPE_TYPE_ID:
+ break;
+ default:
+ if (spvIsIdType(opnd.type)) f(&opnd.words[0]);
+ break;
+ }
+ }
}
inline bool Instruction::HasLabels() const {
diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp
index 187494ff..1093bd8c 100644
--- a/source/opt/local_access_chain_convert_pass.cpp
+++ b/source/opt/local_access_chain_convert_pass.cpp
@@ -14,22 +14,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iterator.h"
#include "local_access_chain_convert_pass.h"
-static const int kSpvEntryPointFunctionId = 1;
-static const int kSpvStorePtrId = 0;
-static const int kSpvStoreValId = 1;
-static const int kSpvLoadPtrId = 0;
-static const int kSpvAccessChainPtrId = 0;
-static const int kSpvTypePointerStorageClass = 0;
-static const int kSpvTypePointerTypeId = 1;
-static const int kSpvConstantValue = 0;
-static const int kSpvTypeIntWidth = 0;
+#include "iterator.h"
namespace spvtools {
namespace opt {
+namespace {
+
+const uint32_t kEntryPointFunctionIdInIdx = 1;
+const uint32_t kStorePtrIdInIdx = 0;
+const uint32_t kStoreValIdInIdx = 1;
+const uint32_t kLoadPtrIdInIdx = 0;
+const uint32_t kAccessChainPtrIdInIdx = 0;
+const uint32_t kTypePointerStorageClassInIdx = 0;
+const uint32_t kTypePointerTypeIdInIdx = 1;
+const uint32_t kConstantValueInIdx = 0;
+const uint32_t kTypeIntWidthInIdx = 0;
+const uint32_t kCopyObjectOperandInIdx = 0;
+
+} // anonymous namespace
+
bool LocalAccessChainConvertPass::IsNonPtrAccessChain(
const SpvOp opcode) const {
return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
@@ -68,14 +74,27 @@ bool LocalAccessChainConvertPass::IsTargetType(
}
ir::Instruction* LocalAccessChainConvertPass::GetPtr(
- ir::Instruction* ip,
- uint32_t* varId) {
- const uint32_t ptrId = ip->GetSingleWordInOperand(
- ip->opcode() == SpvOpStore ? kSpvStorePtrId : kSpvLoadPtrId);
- ir::Instruction* ptrInst = def_use_mgr_->GetDef(ptrId);
- *varId = IsNonPtrAccessChain(ptrInst->opcode()) ?
- ptrInst->GetSingleWordInOperand(kSpvAccessChainPtrId) :
- ptrId;
+ ir::Instruction* ip, uint32_t* varId) {
+ const SpvOp op = ip->opcode();
+ assert(op == SpvOpStore || op == SpvOpLoad);
+ *varId = ip->GetSingleWordInOperand(
+ op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx);
+ ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
+ while (ptrInst->opcode() == SpvOpCopyObject) {
+ *varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
+ ptrInst = def_use_mgr_->GetDef(*varId);
+ }
+ ir::Instruction* varInst = ptrInst;
+ while (varInst->opcode() != SpvOpVariable) {
+ if (IsNonPtrAccessChain(varInst->opcode())) {
+ *varId = varInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
+ }
+ else {
+ assert(varInst->opcode() == SpvOpCopyObject);
+ *varId = varInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
+ }
+ varInst = def_use_mgr_->GetDef(*varId);
+ }
return ptrInst;
}
@@ -89,13 +108,13 @@ bool LocalAccessChainConvertPass::IsTargetVar(uint32_t varId) {
return false;;
const uint32_t varTypeId = varInst->type_id();
const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
- if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
+ if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
SpvStorageClassFunction) {
seen_non_target_vars_.insert(varId);
return false;
}
const uint32_t varPteTypeId =
- varTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
+ varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
ir::Instruction* varPteTypeInst = def_use_mgr_->GetDef(varPteTypeId);
if (!IsTargetType(varPteTypeInst)) {
seen_non_target_vars_.insert(varId);
@@ -105,12 +124,27 @@ bool LocalAccessChainConvertPass::IsTargetVar(uint32_t varId) {
return true;
}
+bool LocalAccessChainConvertPass::HasOnlyNamesAndDecorates(uint32_t id) const {
+ analysis::UseList* uses = def_use_mgr_->GetUses(id);
+ if (uses == nullptr)
+ return true;
+ if (named_or_decorated_ids_.find(id) == named_or_decorated_ids_.end())
+ return false;
+ for (auto u : *uses) {
+ const SpvOp op = u.inst->opcode();
+ if (op != SpvOpName && !IsDecorate(op))
+ return false;
+ }
+ return true;
+}
+
void LocalAccessChainConvertPass::DeleteIfUseless(ir::Instruction* inst) {
const uint32_t resId = inst->result_id();
assert(resId != 0);
- analysis::UseList* uses = def_use_mgr_->GetUses(resId);
- if (uses == nullptr)
+ if (HasOnlyNamesAndDecorates(resId)) {
+ KillNamesAndDecorates(resId);
def_use_mgr_->KillInst(inst);
+ }
}
void LocalAccessChainConvertPass::ReplaceAndDeleteLoad(
@@ -118,6 +152,7 @@ void LocalAccessChainConvertPass::ReplaceAndDeleteLoad(
uint32_t replId,
ir::Instruction* ptrInst) {
const uint32_t loadId = loadInst->result_id();
+ KillNamesAndDecorates(loadId);
(void) def_use_mgr_->ReplaceAllUsesWith(loadId, replId);
// remove load instruction
def_use_mgr_->KillInst(loadInst);
@@ -127,11 +162,36 @@ void LocalAccessChainConvertPass::ReplaceAndDeleteLoad(
}
}
+void LocalAccessChainConvertPass::KillNamesAndDecorates(uint32_t id) {
+ // TODO(greg-lunarg): Remove id from any OpGroupDecorate and
+ // kill if no other operands.
+ if (named_or_decorated_ids_.find(id) == named_or_decorated_ids_.end())
+ return;
+ analysis::UseList* uses = def_use_mgr_->GetUses(id);
+ if (uses == nullptr)
+ return;
+ std::list<ir::Instruction*> killList;
+ for (auto u : *uses) {
+ const SpvOp op = u.inst->opcode();
+ if (op == SpvOpName || IsDecorate(op))
+ killList.push_back(u.inst);
+ }
+ for (auto kip : killList)
+ def_use_mgr_->KillInst(kip);
+}
+
+void LocalAccessChainConvertPass::KillNamesAndDecorates(ir::Instruction* inst) {
+ const uint32_t rId = inst->result_id();
+ if (rId == 0)
+ return;
+ KillNamesAndDecorates(rId);
+}
+
uint32_t LocalAccessChainConvertPass::GetPointeeTypeId(
const ir::Instruction* ptrInst) const {
const uint32_t ptrTypeId = ptrInst->type_id();
const ir::Instruction* ptrTypeInst = def_use_mgr_->GetDef(ptrTypeId);
- return ptrTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
+ return ptrTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
}
void LocalAccessChainConvertPass::BuildAndAppendInst(
@@ -152,7 +212,7 @@ uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
uint32_t* varPteTypeId,
std::vector<std::unique_ptr<ir::Instruction>>* newInsts) {
const uint32_t ldResultId = TakeNextId();
- *varId = ptrInst->GetSingleWordInOperand(kSpvAccessChainPtrId);
+ *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
const ir::Instruction* varInst = def_use_mgr_->GetDef(*varId);
assert(varInst->opcode() == SpvOpVariable);
*varPteTypeId = GetPointeeTypeId(varInst);
@@ -168,7 +228,7 @@ void LocalAccessChainConvertPass::AppendConstantOperands(
ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t *iid) {
if (iidIdx > 0) {
const ir::Instruction* cInst = def_use_mgr_->GetDef(*iid);
- uint32_t val = cInst->GetSingleWordInOperand(kSpvConstantValue);
+ uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx);
in_opnds->push_back(
{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
}
@@ -246,13 +306,23 @@ void LocalAccessChainConvertPass::FindTargetVars(ir::Function* func) {
case SpvOpLoad: {
uint32_t varId;
ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
- // For now, only convert non-ptr access chains
- if (!IsNonPtrAccessChain(ptrInst->opcode()))
+ if (!IsTargetVar(varId))
+ break;
+ // Rule out variables with non-non-ptr access chain refs
+ const SpvOp op = ptrInst->opcode();
+ if (!IsNonPtrAccessChain(op) && op != SpvOpVariable) {
+ seen_non_target_vars_.insert(varId);
+ seen_target_vars_.erase(varId);
break;
- // For now, only convert non-nested access chains
+ }
+ // Rule out variables with nested access chains
// TODO(): Convert nested access chains
- if (!IsTargetVar(varId))
+ if (IsNonPtrAccessChain(op) &&
+ ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx) != varId) {
+ seen_non_target_vars_.insert(varId);
+ seen_target_vars_.erase(varId);
break;
+ }
// Rule out variables accessed with non-constant indices
if (!IsConstantIndexAccessChain(ptrInst)) {
seen_non_target_vars_.insert(varId);
@@ -299,7 +369,7 @@ bool LocalAccessChainConvertPass::ConvertLocalAccessChains(ir::Function* func) {
if (!IsTargetVar(varId))
break;
std::vector<std::unique_ptr<ir::Instruction>> newInsts;
- uint32_t valId = ii->GetSingleWordInOperand(kSpvStoreValId);
+ uint32_t valId = ii->GetSingleWordInOperand(kStoreValIdInIdx);
GenAccessChainStoreReplacement(ptrInst, valId, &newInsts);
def_use_mgr_->KillInst(&*ii);
DeleteIfUseless(ptrInst);
@@ -334,21 +404,56 @@ void LocalAccessChainConvertPass::Initialize(ir::Module* module) {
// Initialize next unused Id.
next_id_ = module->id_bound();
+
+ // Initialize extension whitelist
+ InitExtensions();
};
+void LocalAccessChainConvertPass::FindNamedOrDecoratedIds() {
+ for (auto& di : module_->debugs())
+ if (di.opcode() == SpvOpName)
+ named_or_decorated_ids_.insert(di.GetSingleWordInOperand(0));
+ for (auto& ai : module_->annotations())
+ if (ai.opcode() == SpvOpDecorate || ai.opcode() == SpvOpDecorateId)
+ named_or_decorated_ids_.insert(ai.GetSingleWordInOperand(0));
+}
+
+bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
+ // If any extension not in whitelist, return false
+ for (auto& ei : module_->extensions()) {
+ const char* extName = reinterpret_cast<const char*>(
+ &ei.GetInOperand(0).words[0]);
+ if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
+ return false;
+ }
+ return true;
+}
+
Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
// If non-32-bit integer type in module, terminate processing
// TODO(): Handle non-32-bit integer constants in access chains
for (const ir::Instruction& inst : module_->types_values())
if (inst.opcode() == SpvOpTypeInt &&
- inst.GetSingleWordInOperand(kSpvTypeIntWidth) != 32)
+ inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
return Status::SuccessWithoutChange;
+
+ // Do not process if module contains OpGroupDecorate. Additional
+ // support required in KillNamesAndDecorates().
+ // TODO(greg-lunarg): Add support for OpGroupDecorate
+ for (auto& ai : module_->annotations())
+ if (ai.opcode() == SpvOpGroupDecorate)
+ return Status::SuccessWithoutChange;
+ // Do not process if any disallowed extensions are enabled
+ if (!AllExtensionsSupported())
+ return Status::SuccessWithoutChange;
+ // Collect all named and decorated ids
+ FindNamedOrDecoratedIds();
// Process all entry point functions.
bool modified = false;
for (auto& e : module_->entry_points()) {
ir::Function* fn =
- id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
- modified = modified || ConvertLocalAccessChains(fn);
+ id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
+ modified = ConvertLocalAccessChains(fn) || modified;
}
FinalizeNextId(module_);
@@ -364,6 +469,35 @@ Pass::Status LocalAccessChainConvertPass::Process(ir::Module* module) {
return ProcessImpl();
}
+void LocalAccessChainConvertPass::InitExtensions() {
+ extensions_whitelist_.clear();
+ extensions_whitelist_.insert({
+ "SPV_AMD_shader_explicit_vertex_parameter",
+ "SPV_AMD_shader_trinary_minmax",
+ "SPV_AMD_gcn_shader",
+ "SPV_KHR_shader_ballot",
+ "SPV_AMD_shader_ballot",
+ "SPV_AMD_gpu_shader_half_float",
+ "SPV_KHR_shader_draw_parameters",
+ "SPV_KHR_subgroup_vote",
+ "SPV_KHR_16bit_storage",
+ "SPV_KHR_device_group",
+ "SPV_KHR_multiview",
+ "SPV_NVX_multiview_per_view_attributes",
+ "SPV_NV_viewport_array2",
+ "SPV_NV_stereo_view_rendering",
+ "SPV_NV_sample_mask_override_coverage",
+ "SPV_NV_geometry_shader_passthrough",
+ "SPV_AMD_texture_gather_bias_lod",
+ "SPV_KHR_storage_buffer_storage_class",
+ // SPV_KHR_variable_pointers
+ // Currently do not support extended pointer expressions
+ "SPV_AMD_gpu_shader_int16",
+ "SPV_KHR_post_depth_coverage",
+ "SPV_KHR_shader_atomic_counter_ops",
+ });
+}
+
} // namespace opt
} // namespace spvtools
diff --git a/source/opt/local_access_chain_convert_pass.h b/source/opt/local_access_chain_convert_pass.h
index 3a2d6054..c56fc190 100644
--- a/source/opt/local_access_chain_convert_pass.h
+++ b/source/opt/local_access_chain_convert_pass.h
@@ -72,6 +72,23 @@ class LocalAccessChainConvertPass : public Pass {
// variables.
bool IsTargetVar(uint32_t varId);
+ // Return true if |op| is supported decorate.
+ inline bool IsDecorate(uint32_t op) const {
+ return (op == SpvOpDecorate || op == SpvOpDecorateId);
+ }
+
+ // Return true if all uses of |id| are only name or decorate ops.
+ bool HasOnlyNamesAndDecorates(uint32_t id) const;
+
+ // Kill all name and decorate ops using |inst|
+ void KillNamesAndDecorates(ir::Instruction* inst);
+
+ // Kill all name and decorate ops using |id|
+ void KillNamesAndDecorates(uint32_t id);
+
+ // Collect all named or decorated ids in module
+ void FindNamedOrDecoratedIds();
+
// Delete |inst| if it has no uses. Assumes |inst| has a non-zero resultId.
void DeleteIfUseless(ir::Instruction* inst);
@@ -128,6 +145,12 @@ class LocalAccessChainConvertPass : public Pass {
// converted.
bool ConvertLocalAccessChains(ir::Function* func);
+ // Initialize extensions whitelist
+ void InitExtensions();
+
+ // Return true if all extensions in this module are allowed by this pass.
+ bool AllExtensionsSupported() const;
+
// Save next available id into |module|.
inline void FinalizeNextId(ir::Module* module) {
module->SetIdBound(next_id_);
@@ -156,6 +179,12 @@ class LocalAccessChainConvertPass : public Pass {
// Cache of verified non-target vars
std::unordered_set<uint32_t> seen_non_target_vars_;
+ // named or decorated ids
+ std::unordered_set<uint32_t> named_or_decorated_ids_;
+
+ // Extensions supported by this pass.
+ std::unordered_set<std::string> extensions_whitelist_;
+
// Next unused ID
uint32_t next_id_;
};
diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp
index b18b08d0..54c19ca3 100644
--- a/source/opt/local_single_block_elim_pass.cpp
+++ b/source/opt/local_single_block_elim_pass.cpp
@@ -14,20 +14,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "iterator.h"
#include "local_single_block_elim_pass.h"
-static const int kSpvEntryPointFunctionId = 1;
-static const int kSpvStorePtrId = 0;
-static const int kSpvStoreValId = 1;
-static const int kSpvLoadPtrId = 0;
-static const int kSpvAccessChainPtrId = 0;
-static const int kSpvTypePointerStorageClass = 0;
-static const int kSpvTypePointerTypeId = 1;
+#include "iterator.h"
namespace spvtools {
namespace opt {
+namespace {
+
+const uint32_t kEntryPointFunctionIdInIdx = 1;
+const uint32_t kStorePtrIdInIdx = 0;
+const uint32_t kStoreValIdInIdx = 1;
+const uint32_t kLoadPtrIdInIdx = 0;
+const uint32_t kAccessChainPtrIdInIdx = 0;
+const uint32_t kTypePointerStorageClassInIdx = 0;
+const uint32_t kTypePointerTypeIdInIdx = 1;
+const uint32_t kCopyObjectOperandInIdx = 0;
+
+} // anonymous namespace
+
bool LocalSingleBlockLoadStoreElimPass::IsNonPtrAccessChain(
const SpvOp opcode) const {
return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
@@ -67,12 +73,24 @@ bool LocalSingleBlockLoadStoreElimPass::IsTargetType(
ir::Instruction* LocalSingleBlockLoadStoreElimPass::GetPtr(
ir::Instruction* ip, uint32_t* varId) {
+ const SpvOp op = ip->opcode();
+ assert(op == SpvOpStore || op == SpvOpLoad);
*varId = ip->GetSingleWordInOperand(
- ip->opcode() == SpvOpStore ? kSpvStorePtrId : kSpvLoadPtrId);
+ op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx);
ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
+ while (ptrInst->opcode() == SpvOpCopyObject) {
+ *varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
+ ptrInst = def_use_mgr_->GetDef(*varId);
+ }
ir::Instruction* varInst = ptrInst;
- while (IsNonPtrAccessChain(varInst->opcode())) {
- *varId = varInst->GetSingleWordInOperand(kSpvAccessChainPtrId);
+ while (varInst->opcode() != SpvOpVariable) {
+ if (IsNonPtrAccessChain(varInst->opcode())) {
+ *varId = varInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
+ }
+ else {
+ assert(varInst->opcode() == SpvOpCopyObject);
+ *varId = varInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
+ }
varInst = def_use_mgr_->GetDef(*varId);
}
return ptrInst;
@@ -87,13 +105,13 @@ bool LocalSingleBlockLoadStoreElimPass::IsTargetVar(uint32_t varId) {
assert(varInst->opcode() == SpvOpVariable);
const uint32_t varTypeId = varInst->type_id();
const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
- if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
+ if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
SpvStorageClassFunction) {
seen_non_target_vars_.insert(varId);
return false;
}
const uint32_t varPteTypeId =
- varTypeInst->GetSingleWordInOperand(kSpvTypePointerTypeId);
+ varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
ir::Instruction* varPteTypeInst = def_use_mgr_->GetDef(varPteTypeId);
if (!IsTargetType(varPteTypeInst)) {
seen_non_target_vars_.insert(varId);
@@ -106,6 +124,7 @@ bool LocalSingleBlockLoadStoreElimPass::IsTargetVar(uint32_t varId) {
void LocalSingleBlockLoadStoreElimPass::ReplaceAndDeleteLoad(
ir::Instruction* loadInst, uint32_t replId) {
const uint32_t loadId = loadInst->result_id();
+ KillNamesAndDecorates(loadId);
(void) def_use_mgr_->ReplaceAllUsesWith(loadId, replId);
// TODO(greg-lunarg): Consider moving DCE into separate pass
DCEInst(loadInst);
@@ -137,7 +156,7 @@ bool LocalSingleBlockLoadStoreElimPass::IsLiveVar(uint32_t varId) const {
assert(varInst->opcode() == SpvOpVariable);
const uint32_t varTypeId = varInst->type_id();
const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
- if (varTypeInst->GetSingleWordInOperand(kSpvTypePointerStorageClass) !=
+ if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
SpvStorageClassFunction)
return true;
// test if variable is loaded from
@@ -165,6 +184,47 @@ void LocalSingleBlockLoadStoreElimPass::AddStores(
}
}
+bool LocalSingleBlockLoadStoreElimPass::HasOnlyNamesAndDecorates(
+ uint32_t id) const {
+ analysis::UseList* uses = def_use_mgr_->GetUses(id);
+ if (uses == nullptr)
+ return true;
+ if (named_or_decorated_ids_.find(id) == named_or_decorated_ids_.end())
+ return false;
+ for (auto u : *uses) {
+ const SpvOp op = u.inst->opcode();
+ if (op != SpvOpName && !IsDecorate(op))
+ return false;
+ }
+ return true;
+}
+
+void LocalSingleBlockLoadStoreElimPass::KillNamesAndDecorates(uint32_t id) {
+ // TODO(greg-lunarg): Remove id from any OpGroupDecorate and
+ // kill if no other operands.
+ if (named_or_decorated_ids_.find(id) == named_or_decorated_ids_.end())
+ return;
+ analysis::UseList* uses = def_use_mgr_->GetUses(id);
+ if (uses == nullptr)
+ return;
+ std::list<ir::Instruction*> killList;
+ for (auto u : *uses) {
+ const SpvOp op = u.inst->opcode();
+ if (op == SpvOpName || IsDecorate(op))
+ killList.push_back(u.inst);
+ }
+ for (auto kip : killList)
+ def_use_mgr_->KillInst(kip);
+}
+
+void LocalSingleBlockLoadStoreElimPass::KillNamesAndDecorates(
+ ir::Instruction* inst) {
+ const uint32_t rId = inst->result_id();
+ if (rId == 0)
+ return;
+ KillNamesAndDecorates(rId);
+}
+
void LocalSingleBlockLoadStoreElimPass::DCEInst(ir::Instruction* inst) {
std::queue<ir::Instruction*> deadInsts;
deadInsts.push(inst);
@@ -184,12 +244,12 @@ void LocalSingleBlockLoadStoreElimPass::DCEInst(ir::Instruction* inst) {
// Remember variable if dead load
if (di->opcode() == SpvOpLoad)
(void) GetPtr(di, &varId);
+ KillNamesAndDecorates(di);
def_use_mgr_->KillInst(di);
// For all operands with no remaining uses, add their instruction
// to the dead instruction queue.
for (auto id : ids) {
- analysis::UseList* uses = def_use_mgr_->GetUses(id);
- if (uses == nullptr)
+ if (HasOnlyNamesAndDecorates(id))
deadInsts.push(def_use_mgr_->GetDef(id));
}
// if a load was deleted and it was the variable's
@@ -200,14 +260,24 @@ void LocalSingleBlockLoadStoreElimPass::DCEInst(ir::Instruction* inst) {
}
}
+bool LocalSingleBlockLoadStoreElimPass::HasOnlySupportedRefs(uint32_t ptrId) {
+ if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end())
+ return true;
+ analysis::UseList* uses = def_use_mgr_->GetUses(ptrId);
+ assert(uses != nullptr);
+ for (auto u : *uses) {
+ SpvOp op = u.inst->opcode();
+ if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
+ if (!HasOnlySupportedRefs(u.inst->result_id())) return false;
+ } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName)
+ return false;
+ }
+ supported_ref_ptrs_.insert(ptrId);
+ return true;
+}
+
bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim(
ir::Function* func) {
- // Verify no CopyObject ops in function. This is a pre-SSA pass and
- // is generally not useful for code already in CSSA form.
- for (auto& blk : *func)
- for (auto& inst : blk)
- if (inst.opcode() == SpvOpCopyObject)
- return false;
// Perform local store/load and load/load elimination on each block
bool modified = false;
for (auto bi = func->begin(); bi != func->end(); ++bi) {
@@ -222,6 +292,8 @@ bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim(
ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
if (!IsTargetVar(varId))
continue;
+ if (!HasOnlySupportedRefs(varId))
+ continue;
// Register the store
if (ptrInst->opcode() == SpvOpVariable) {
// if not pinned, look for WAW
@@ -246,12 +318,14 @@ bool LocalSingleBlockLoadStoreElimPass::LocalSingleBlockLoadStoreElim(
ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
if (!IsTargetVar(varId))
continue;
+ if (!HasOnlySupportedRefs(varId))
+ continue;
// Look for previous store or load
uint32_t replId = 0;
if (ptrInst->opcode() == SpvOpVariable) {
auto si = var2store_.find(varId);
if (si != var2store_.end()) {
- replId = si->second->GetSingleWordInOperand(kSpvStoreValId);
+ replId = si->second->GetSingleWordInOperand(kStoreValIdInIdx);
}
else {
auto li = var2load_.find(varId);
@@ -308,24 +382,61 @@ void LocalSingleBlockLoadStoreElimPass::Initialize(ir::Module* module) {
seen_target_vars_.clear();
seen_non_target_vars_.clear();
- // TODO(): Reuse def/use from previous passes
+ // Clear collections
+ supported_ref_ptrs_.clear();
+
+ // TODO(greg-lunarg): Reuse def/use from previous passes
def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
// Start new ids with next availablein module
next_id_ = module_->id_bound();
+ // Initialize extensions whitelist
+ InitExtensions();
};
+void LocalSingleBlockLoadStoreElimPass::FindNamedOrDecoratedIds() {
+ for (auto& di : module_->debugs())
+ if (di.opcode() == SpvOpName)
+ named_or_decorated_ids_.insert(di.GetSingleWordInOperand(0));
+ for (auto& ai : module_->annotations())
+ if (ai.opcode() == SpvOpDecorate || ai.opcode() == SpvOpDecorateId)
+ named_or_decorated_ids_.insert(ai.GetSingleWordInOperand(0));
+}
+
+bool LocalSingleBlockLoadStoreElimPass::AllExtensionsSupported() const {
+ // If any extension not in whitelist, return false
+ for (auto& ei : module_->extensions()) {
+ const char* extName = reinterpret_cast<const char*>(
+ &ei.GetInOperand(0).words[0]);
+ if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
+ return false;
+ }
+ return true;
+}
+
Pass::Status LocalSingleBlockLoadStoreElimPass::ProcessImpl() {
// Assumes logical addressing only
if (module_->HasCapability(SpvCapabilityAddresses))
return Status::SuccessWithoutChange;
+ // Do not process if module contains OpGroupDecorate. Additional
+ // support required in KillNamesAndDecorates().
+ // TODO(greg-lunarg): Add support for OpGroupDecorate
+ for (auto& ai : module_->annotations())
+ if (ai.opcode() == SpvOpGroupDecorate)
+ return Status::SuccessWithoutChange;
+ // If any extensions in the module are not explicitly supported,
+ // return unmodified.
+ if (!AllExtensionsSupported())
+ return Status::SuccessWithoutChange;
+ // Collect all named and decorated ids
+ FindNamedOrDecoratedIds();
+ // Process all entry point functions
bool modified = false;
- // Call Mem2Reg on all remaining functions.
for (auto& e : module_->entry_points()) {
ir::Function* fn =
- id2function_[e.GetSingleWordOperand(kSpvEntryPointFunctionId)];
- modified = modified || LocalSingleBlockLoadStoreElim(fn);
+ id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
+ modified = LocalSingleBlockLoadStoreElim(fn) || modified;
}
FinalizeNextId(module_);
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
@@ -339,6 +450,34 @@ Pass::Status LocalSingleBlockLoadStoreElimPass::Process(ir::Module* module) {
return ProcessImpl();
}
+void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
+ extensions_whitelist_.clear();
+ extensions_whitelist_.insert({
+ "SPV_AMD_shader_explicit_vertex_parameter",
+ "SPV_AMD_shader_trinary_minmax",
+ "SPV_AMD_gcn_shader",
+ "SPV_KHR_shader_ballot",
+ "SPV_AMD_shader_ballot",
+ "SPV_AMD_gpu_shader_half_float",
+ "SPV_KHR_shader_draw_parameters",
+ "SPV_KHR_subgroup_vote",
+ "SPV_KHR_16bit_storage",
+ "SPV_KHR_device_group",
+ "SPV_KHR_multiview",
+ "SPV_NVX_multiview_per_view_attributes",
+ "SPV_NV_viewport_array2",
+ "SPV_NV_stereo_view_rendering",
+ "SPV_NV_sample_mask_override_coverage",
+ "SPV_NV_geometry_shader_passthrough",
+ "SPV_AMD_texture_gather_bias_lod",
+ "SPV_KHR_storage_buffer_storage_class",
+ // SPV_KHR_variable_pointers
+ // Currently do not support extended pointer expressions
+ "SPV_AMD_gpu_shader_int16",
+ "SPV_KHR_post_depth_coverage",
+ "SPV_KHR_shader_atomic_counter_ops",
+ });
+}
+
} // namespace opt
} // namespace spvtools
-
diff --git a/source/opt/local_single_block_elim_pass.h b/source/opt/local_single_block_elim_pass.h
index b5a14f42..674b699f 100644
--- a/source/opt/local_single_block_elim_pass.h
+++ b/source/opt/local_single_block_elim_pass.h
@@ -81,10 +81,31 @@ class LocalSingleBlockLoadStoreElimPass : public Pass {
// Add stores using |ptr_id| to |insts|
void AddStores(uint32_t ptr_id, std::queue<ir::Instruction*>* insts);
+ // Return true if |op| is supported decorate.
+ inline bool IsDecorate(uint32_t op) const {
+ return (op == SpvOpDecorate || op == SpvOpDecorateId);
+ }
+
+ // Return true if all uses of |id| are only name or decorate ops.
+ bool HasOnlyNamesAndDecorates(uint32_t id) const;
+
+ // Kill all name and decorate ops using |inst|
+ void KillNamesAndDecorates(ir::Instruction* inst);
+
+ // Kill all name and decorate ops using |id|
+ void KillNamesAndDecorates(uint32_t id);
+
+ // Collect all named or decorated ids in module
+ void FindNamedOrDecoratedIds();
+
// Delete |inst| and iterate DCE on all its operands. Won't delete
// labels.
void DCEInst(ir::Instruction* inst);
+ // Return true if all uses of |varId| are only through supported reference
+ // operations ie. loads and store. Also cache in supported_ref_ptrs_;
+ bool HasOnlySupportedRefs(uint32_t varId);
+
// On all entry point functions, within each basic block, eliminate
// loads and stores to function variables where possible. For
// loads, if previous load or store to same variable, replace
@@ -103,6 +124,12 @@ class LocalSingleBlockLoadStoreElimPass : public Pass {
return next_id_++;
}
+ // Initialize extensions whitelist
+ void InitExtensions();
+
+ // Return true if all extensions in this module are supported by this pass.
+ bool AllExtensionsSupported() const;
+
void Initialize(ir::Module* module);
Pass::Status ProcessImpl();
@@ -142,6 +169,16 @@ class LocalSingleBlockLoadStoreElimPass : public Pass {
// from this set each time a new store of that variable is encountered.
std::unordered_set<uint32_t> pinned_vars_;
+ // named or decorated ids
+ std::unordered_set<uint32_t> named_or_decorated_ids_;
+
+ // Extensions supported by this pass.
+ std::unordered_set<std::string> extensions_whitelist_;
+
+ // Variables that are only referenced by supported operations for this
+ // pass ie. loads and stores.
+ std::unordered_set<uint32_t> supported_ref_ptrs_;
+
// Next unused ID
uint32_t next_id_;
};
diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp
new file mode 100644
index 00000000..ed11c441
--- /dev/null
+++ b/source/opt/local_single_store_elim_pass.cpp
@@ -0,0 +1,585 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "local_single_store_elim_pass.h"
+
+#include "cfa.h"
+#include "iterator.h"
+#include "spirv/1.0/GLSL.std.450.h"
+
+// Universal Limit of ResultID + 1
+static const int kInvalidId = 0x400000;
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+
+const uint32_t kEntryPointFunctionIdInIdx = 1;
+const uint32_t kStorePtrIdInIdx = 0;
+const uint32_t kStoreValIdInIdx = 1;
+const uint32_t kLoadPtrIdInIdx = 0;
+const uint32_t kAccessChainPtrIdInIdx = 0;
+const uint32_t kTypePointerStorageClassInIdx = 0;
+const uint32_t kTypePointerTypeIdInIdx = 1;
+const uint32_t kCopyObjectOperandInIdx = 0;
+
+} // anonymous namespace
+
+bool LocalSingleStoreElimPass::IsNonPtrAccessChain(const SpvOp opcode) const {
+ return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
+}
+
+bool LocalSingleStoreElimPass::IsMathType(
+ const ir::Instruction* typeInst) const {
+ switch (typeInst->opcode()) {
+ case SpvOpTypeInt:
+ case SpvOpTypeFloat:
+ case SpvOpTypeBool:
+ case SpvOpTypeVector:
+ case SpvOpTypeMatrix:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+bool LocalSingleStoreElimPass::IsTargetType(
+ const ir::Instruction* typeInst) const {
+ if (IsMathType(typeInst))
+ return true;
+ if (typeInst->opcode() == SpvOpTypeArray)
+ return IsMathType(def_use_mgr_->GetDef(typeInst->GetSingleWordOperand(1)));
+ if (typeInst->opcode() != SpvOpTypeStruct)
+ return false;
+ // All struct members must be math type
+ int nonMathComp = 0;
+ typeInst->ForEachInId([&nonMathComp,this](const uint32_t* tid) {
+ ir::Instruction* compTypeInst = def_use_mgr_->GetDef(*tid);
+ if (!IsMathType(compTypeInst)) ++nonMathComp;
+ });
+ return nonMathComp == 0;
+}
+
+ir::Instruction* LocalSingleStoreElimPass::GetPtr(
+ ir::Instruction* ip, uint32_t* varId) {
+ const SpvOp op = ip->opcode();
+ assert(op == SpvOpStore || op == SpvOpLoad);
+ *varId = ip->GetSingleWordInOperand(
+ op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx);
+ ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
+ while (ptrInst->opcode() == SpvOpCopyObject) {
+ *varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
+ ptrInst = def_use_mgr_->GetDef(*varId);
+ }
+ ir::Instruction* varInst = ptrInst;
+ while (varInst->opcode() != SpvOpVariable) {
+ if (IsNonPtrAccessChain(varInst->opcode())) {
+ *varId = varInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
+ }
+ else {
+ assert(varInst->opcode() == SpvOpCopyObject);
+ *varId = varInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
+ }
+ varInst = def_use_mgr_->GetDef(*varId);
+ }
+ return ptrInst;
+}
+
+bool LocalSingleStoreElimPass::IsTargetVar(uint32_t varId) {
+ if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end())
+ return false;
+ if (seen_target_vars_.find(varId) != seen_target_vars_.end())
+ return true;
+ const ir::Instruction* varInst = def_use_mgr_->GetDef(varId);
+ assert(varInst->opcode() == SpvOpVariable);
+ const uint32_t varTypeId = varInst->type_id();
+ const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
+ if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
+ SpvStorageClassFunction) {
+ seen_non_target_vars_.insert(varId);
+ return false;
+ }
+ const uint32_t varPteTypeId =
+ varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
+ ir::Instruction* varPteTypeInst = def_use_mgr_->GetDef(varPteTypeId);
+ if (!IsTargetType(varPteTypeInst)) {
+ seen_non_target_vars_.insert(varId);
+ return false;
+ }
+ seen_target_vars_.insert(varId);
+ return true;
+}
+
+bool LocalSingleStoreElimPass::HasOnlySupportedRefs(uint32_t ptrId) {
+ if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end())
+ return true;
+ analysis::UseList* uses = def_use_mgr_->GetUses(ptrId);
+ assert(uses != nullptr);
+ for (auto u : *uses) {
+ SpvOp op = u.inst->opcode();
+ if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
+ if (!HasOnlySupportedRefs(u.inst->result_id())) return false;
+ } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName)
+ return false;
+ }
+ supported_ref_ptrs_.insert(ptrId);
+ return true;
+}
+
+void LocalSingleStoreElimPass::SingleStoreAnalyze(ir::Function* func) {
+ ssa_var2store_.clear();
+ non_ssa_vars_.clear();
+ store2idx_.clear();
+ store2blk_.clear();
+ for (auto bi = func->begin(); bi != func->end(); ++bi) {
+ uint32_t instIdx = 0;
+ for (auto ii = bi->begin(); ii != bi->end(); ++ii, ++instIdx) {
+ switch (ii->opcode()) {
+ case SpvOpStore: {
+ // Verify store variable is target type
+ uint32_t varId;
+ ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
+ if (non_ssa_vars_.find(varId) != non_ssa_vars_.end())
+ continue;
+ if (!HasOnlySupportedRefs(varId)) {
+ non_ssa_vars_.insert(varId);
+ continue;
+ }
+ if (ptrInst->opcode() != SpvOpVariable) {
+ non_ssa_vars_.insert(varId);
+ ssa_var2store_.erase(varId);
+ continue;
+ }
+ // Verify target type and function storage class
+ if (!IsTargetVar(varId)) {
+ non_ssa_vars_.insert(varId);
+ continue;
+ }
+ // Ignore variables with multiple stores
+ if (ssa_var2store_.find(varId) != ssa_var2store_.end()) {
+ non_ssa_vars_.insert(varId);
+ ssa_var2store_.erase(varId);
+ continue;
+ }
+        // Remember pointer to variable's store and its
+ // ordinal position in block
+ ssa_var2store_[varId] = &*ii;
+ store2idx_[&*ii] = instIdx;
+ store2blk_[&*ii] = &*bi;
+ } break;
+ default:
+ break;
+ } // switch
+ }
+ }
+}
+
+void LocalSingleStoreElimPass::ReplaceAndDeleteLoad(
+ ir::Instruction* loadInst, uint32_t replId) {
+ const uint32_t loadId = loadInst->result_id();
+ KillNamesAndDecorates(loadId);
+ (void) def_use_mgr_->ReplaceAllUsesWith(loadId, replId);
+ DCEInst(loadInst);
+}
+
+LocalSingleStoreElimPass::GetBlocksFunction
+LocalSingleStoreElimPass::AugmentedCFGSuccessorsFunction() const {
+ return [this](const ir::BasicBlock* block) {
+ auto asmi = augmented_successors_map_.find(block);
+ if (asmi != augmented_successors_map_.end())
+ return &(*asmi).second;
+ auto smi = successors_map_.find(block);
+ return &(*smi).second;
+ };
+}
+
+LocalSingleStoreElimPass::GetBlocksFunction
+LocalSingleStoreElimPass::AugmentedCFGPredecessorsFunction() const {
+ return [this](const ir::BasicBlock* block) {
+ auto apmi = augmented_predecessors_map_.find(block);
+ if (apmi != augmented_predecessors_map_.end())
+ return &(*apmi).second;
+ auto pmi = predecessors_map_.find(block);
+ return &(*pmi).second;
+ };
+}
+
+void LocalSingleStoreElimPass::CalculateImmediateDominators(
+ ir::Function* func) {
+ // Compute CFG
+ vector<ir::BasicBlock*> ordered_blocks;
+ predecessors_map_.clear();
+ successors_map_.clear();
+ for (auto& blk : *func) {
+ ordered_blocks.push_back(&blk);
+ blk.ForEachSuccessorLabel([&blk, &ordered_blocks, this](uint32_t sbid) {
+ successors_map_[&blk].push_back(label2block_[sbid]);
+ predecessors_map_[label2block_[sbid]].push_back(&blk);
+ });
+ }
+ // Compute Augmented CFG
+ augmented_successors_map_.clear();
+ augmented_predecessors_map_.clear();
+ successors_map_[&pseudo_exit_block_] = {};
+ predecessors_map_[&pseudo_entry_block_] = {};
+ auto succ_func = [this](const ir::BasicBlock* b)
+ { return &successors_map_[b]; };
+ auto pred_func = [this](const ir::BasicBlock* b)
+ { return &predecessors_map_[b]; };
+ CFA<ir::BasicBlock>::ComputeAugmentedCFG(
+ ordered_blocks,
+ &pseudo_entry_block_,
+ &pseudo_exit_block_,
+ &augmented_successors_map_,
+ &augmented_predecessors_map_,
+ succ_func,
+ pred_func);
+ // Compute Dominators
+ vector<const ir::BasicBlock*> postorder;
+ auto ignore_block = [](cbb_ptr) {};
+ auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
+ spvtools::CFA<ir::BasicBlock>::DepthFirstTraversal(
+ ordered_blocks[0], AugmentedCFGSuccessorsFunction(),
+ ignore_block, [&](cbb_ptr b) { postorder.push_back(b); },
+ ignore_edge);
+ auto edges = spvtools::CFA<ir::BasicBlock>::CalculateDominators(
+ postorder, AugmentedCFGPredecessorsFunction());
+ idom_.clear();
+ for (auto edge : edges)
+ idom_[edge.first] = edge.second;
+}
+
+bool LocalSingleStoreElimPass::Dominates(
+ ir::BasicBlock* blk0, uint32_t idx0,
+ ir::BasicBlock* blk1, uint32_t idx1) {
+ if (blk0 == blk1)
+ return idx0 <= idx1;
+ ir::BasicBlock* b = blk1;
+ while (idom_[b] != b) {
+ b = idom_[b];
+ if (b == blk0)
+ return true;
+ }
+ return false;
+}
+
+bool LocalSingleStoreElimPass::SingleStoreProcess(ir::Function* func) {
+ CalculateImmediateDominators(func);
+ bool modified = false;
+ for (auto bi = func->begin(); bi != func->end(); ++bi) {
+ uint32_t instIdx = 0;
+ for (auto ii = bi->begin(); ii != bi->end(); ++ii, ++instIdx) {
+ if (ii->opcode() != SpvOpLoad)
+ continue;
+ uint32_t varId;
+ ir::Instruction* ptrInst = GetPtr(&*ii, &varId);
+ // Skip access chain loads
+ if (ptrInst->opcode() != SpvOpVariable)
+ continue;
+ const auto vsi = ssa_var2store_.find(varId);
+ if (vsi == ssa_var2store_.end())
+ continue;
+ if (non_ssa_vars_.find(varId) != non_ssa_vars_.end())
+ continue;
+ // store must dominate load
+ if (!Dominates(store2blk_[vsi->second], store2idx_[vsi->second], &*bi, instIdx))
+ continue;
+ // Use store value as replacement id
+ uint32_t replId = vsi->second->GetSingleWordInOperand(kStoreValIdInIdx);
+ // replace all instances of the load's id with the SSA value's id
+ ReplaceAndDeleteLoad(&*ii, replId);
+ modified = true;
+ }
+ }
+ return modified;
+}
+
+bool LocalSingleStoreElimPass::HasLoads(uint32_t varId) const {
+ analysis::UseList* uses = def_use_mgr_->GetUses(varId);
+ if (uses == nullptr)
+ return false;
+ for (auto u : *uses) {
+ SpvOp op = u.inst->opcode();
+    // TODO(): The following is slightly conservative. There could
+    // be better handling of non-store/name references.
+ if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
+ if (HasLoads(u.inst->result_id()))
+ return true;
+ }
+ else if (op != SpvOpStore && op != SpvOpName)
+ return true;
+ }
+ return false;
+}
+
+bool LocalSingleStoreElimPass::IsLiveVar(uint32_t varId) const {
+ // non-function scope vars are live
+ const ir::Instruction* varInst = def_use_mgr_->GetDef(varId);
+ assert(varInst->opcode() == SpvOpVariable);
+ const uint32_t varTypeId = varInst->type_id();
+ const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
+ if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
+ SpvStorageClassFunction)
+ return true;
+ // test if variable is loaded from
+ return HasLoads(varId);
+}
+
+bool LocalSingleStoreElimPass::IsLiveStore(ir::Instruction* storeInst) {
+ // get store's variable
+ uint32_t varId;
+ (void) GetPtr(storeInst, &varId);
+ return IsLiveVar(varId);
+}
+
+void LocalSingleStoreElimPass::AddStores(
+ uint32_t ptr_id, std::queue<ir::Instruction*>* insts) {
+ analysis::UseList* uses = def_use_mgr_->GetUses(ptr_id);
+ if (uses != nullptr) {
+ for (auto u : *uses) {
+ if (IsNonPtrAccessChain(u.inst->opcode()))
+ AddStores(u.inst->result_id(), insts);
+ else if (u.inst->opcode() == SpvOpStore)
+ insts->push(u.inst);
+ }
+ }
+}
+
+bool LocalSingleStoreElimPass::HasOnlyNamesAndDecorates(
+ uint32_t id) const {
+ analysis::UseList* uses = def_use_mgr_->GetUses(id);
+ if (uses == nullptr)
+ return true;
+ if (named_or_decorated_ids_.find(id) == named_or_decorated_ids_.end())
+ return false;
+ for (auto u : *uses) {
+ const SpvOp op = u.inst->opcode();
+ if (op != SpvOpName && !IsDecorate(op))
+ return false;
+ }
+ return true;
+}
+
+void LocalSingleStoreElimPass::KillNamesAndDecorates(uint32_t id) {
+ // TODO(greg-lunarg): Remove id from any OpGroupDecorate and
+ // kill if no other operands.
+ if (named_or_decorated_ids_.find(id) == named_or_decorated_ids_.end())
+ return;
+ analysis::UseList* uses = def_use_mgr_->GetUses(id);
+ if (uses == nullptr)
+ return;
+ std::list<ir::Instruction*> killList;
+ for (auto u : *uses) {
+ const SpvOp op = u.inst->opcode();
+ if (op == SpvOpName || IsDecorate(op))
+ killList.push_back(u.inst);
+ }
+ for (auto kip : killList)
+ def_use_mgr_->KillInst(kip);
+}
+
+void LocalSingleStoreElimPass::KillNamesAndDecorates(
+ ir::Instruction* inst) {
+ const uint32_t rId = inst->result_id();
+ if (rId == 0)
+ return;
+ KillNamesAndDecorates(rId);
+}
+
+void LocalSingleStoreElimPass::DCEInst(ir::Instruction* inst) {
+ std::queue<ir::Instruction*> deadInsts;
+ deadInsts.push(inst);
+ while (!deadInsts.empty()) {
+ ir::Instruction* di = deadInsts.front();
+ // Don't delete labels
+ if (di->opcode() == SpvOpLabel) {
+ deadInsts.pop();
+ continue;
+ }
+ // Remember operands
+ std::vector<uint32_t> ids;
+ di->ForEachInId([&ids](uint32_t* iid) {
+ ids.push_back(*iid);
+ });
+ uint32_t varId = 0;
+ // Remember variable if dead load
+ if (di->opcode() == SpvOpLoad)
+ (void) GetPtr(di, &varId);
+ KillNamesAndDecorates(di);
+ def_use_mgr_->KillInst(di);
+ // For all operands with no remaining uses, add their instruction
+ // to the dead instruction queue.
+ for (auto id : ids) {
+ if (HasOnlyNamesAndDecorates(id))
+ deadInsts.push(def_use_mgr_->GetDef(id));
+ }
+ // if a load was deleted and it was the variable's
+ // last load, add all its stores to dead queue
+ if (varId != 0 && !IsLiveVar(varId))
+ AddStores(varId, &deadInsts);
+ deadInsts.pop();
+ }
+}
+
+bool LocalSingleStoreElimPass::SingleStoreDCE() {
+ bool modified = false;
+ for (auto v : ssa_var2store_) {
+ // check that it hasn't already been DCE'd
+ if (v.second->opcode() != SpvOpStore)
+ continue;
+ if (non_ssa_vars_.find(v.first) != non_ssa_vars_.end())
+ continue;
+ if (!IsLiveStore(v.second)) {
+ DCEInst(v.second);
+ modified = true;
+ }
+ }
+ return modified;
+}
+
+bool LocalSingleStoreElimPass::LocalSingleStoreElim(ir::Function* func) {
+ bool modified = false;
+ SingleStoreAnalyze(func);
+ if (ssa_var2store_.empty())
+ return false;
+ modified |= SingleStoreProcess(func);
+ modified |= SingleStoreDCE();
+ return modified;
+}
+
+void LocalSingleStoreElimPass::Initialize(ir::Module* module) {
+ module_ = module;
+
+ // Initialize function and block maps
+ id2function_.clear();
+ label2block_.clear();
+ for (auto& fn : *module_) {
+ id2function_[fn.result_id()] = &fn;
+ for (auto& blk : fn) {
+ uint32_t bid = blk.id();
+ label2block_[bid] = &blk;
+ }
+ }
+
+ // Initialize Target Type Caches
+ seen_target_vars_.clear();
+ seen_non_target_vars_.clear();
+
+ // Initialize Supported Ref Pointer Cache
+ supported_ref_ptrs_.clear();
+
+ // TODO: Reuse def/use (and other state) from previous passes
+ def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
+
+ // Initialize next unused Id
+ next_id_ = module_->id_bound();
+
+ // Initialize extension whitelist
+ InitExtensions();
+};
+
+void LocalSingleStoreElimPass::FindNamedOrDecoratedIds() {
+ for (auto& di : module_->debugs())
+ if (di.opcode() == SpvOpName)
+ named_or_decorated_ids_.insert(di.GetSingleWordInOperand(0));
+ for (auto& ai : module_->annotations())
+ if (ai.opcode() == SpvOpDecorate || ai.opcode() == SpvOpDecorateId)
+ named_or_decorated_ids_.insert(ai.GetSingleWordInOperand(0));
+}
+
+bool LocalSingleStoreElimPass::AllExtensionsSupported() const {
+ // If any extension not in whitelist, return false
+ for (auto& ei : module_->extensions()) {
+ const char* extName = reinterpret_cast<const char*>(
+ &ei.GetInOperand(0).words[0]);
+ if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
+ return false;
+ }
+ return true;
+}
+
+Pass::Status LocalSingleStoreElimPass::ProcessImpl() {
+ // Assumes logical addressing only
+ if (module_->HasCapability(SpvCapabilityAddresses))
+ return Status::SuccessWithoutChange;
+ // Do not process if module contains OpGroupDecorate. Additional
+ // support required in KillNamesAndDecorates().
+ // TODO(greg-lunarg): Add support for OpGroupDecorate
+ for (auto& ai : module_->annotations())
+ if (ai.opcode() == SpvOpGroupDecorate)
+ return Status::SuccessWithoutChange;
+ // Do not process if any disallowed extensions are enabled
+ if (!AllExtensionsSupported())
+ return Status::SuccessWithoutChange;
+ // Collect all named and decorated ids
+ FindNamedOrDecoratedIds();
+ // Process all entry point functions
+ bool modified = false;
+ for (auto& e : module_->entry_points()) {
+ ir::Function* fn =
+ id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
+ modified = LocalSingleStoreElim(fn) || modified;
+ }
+ FinalizeNextId(module_);
+ return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+LocalSingleStoreElimPass::LocalSingleStoreElimPass()
+ : module_(nullptr), def_use_mgr_(nullptr),
+ pseudo_entry_block_(std::unique_ptr<ir::Instruction>(
+ new ir::Instruction(SpvOpLabel, 0, 0, {}))),
+ pseudo_exit_block_(std::unique_ptr<ir::Instruction>(
+ new ir::Instruction(SpvOpLabel, 0, kInvalidId, {}))),
+ next_id_(0) {}
+
+Pass::Status LocalSingleStoreElimPass::Process(ir::Module* module) {
+ Initialize(module);
+ return ProcessImpl();
+}
+
+void LocalSingleStoreElimPass::InitExtensions() {
+ extensions_whitelist_.clear();
+ extensions_whitelist_.insert({
+ "SPV_AMD_shader_explicit_vertex_parameter",
+ "SPV_AMD_shader_trinary_minmax",
+ "SPV_AMD_gcn_shader",
+ "SPV_KHR_shader_ballot",
+ "SPV_AMD_shader_ballot",
+ "SPV_AMD_gpu_shader_half_float",
+ "SPV_KHR_shader_draw_parameters",
+ "SPV_KHR_subgroup_vote",
+ "SPV_KHR_16bit_storage",
+ "SPV_KHR_device_group",
+ "SPV_KHR_multiview",
+ "SPV_NVX_multiview_per_view_attributes",
+ "SPV_NV_viewport_array2",
+ "SPV_NV_stereo_view_rendering",
+ "SPV_NV_sample_mask_override_coverage",
+ "SPV_NV_geometry_shader_passthrough",
+ "SPV_AMD_texture_gather_bias_lod",
+ "SPV_KHR_storage_buffer_storage_class",
+ // SPV_KHR_variable_pointers
+ // Currently do not support extended pointer expressions
+ "SPV_AMD_gpu_shader_int16",
+ "SPV_KHR_post_depth_coverage",
+ "SPV_KHR_shader_atomic_counter_ops",
+ });
+}
+
+} // namespace opt
+} // namespace spvtools
diff --git a/source/opt/local_single_store_elim_pass.h b/source/opt/local_single_store_elim_pass.h
new file mode 100644
index 00000000..84b9ea6f
--- /dev/null
+++ b/source/opt/local_single_store_elim_pass.h
@@ -0,0 +1,248 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_
+#define LIBSPIRV_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_
+
+
+#include <algorithm>
+#include <map>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "basic_block.h"
+#include "def_use_manager.h"
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class LocalSingleStoreElimPass : public Pass {
+ using cbb_ptr = const ir::BasicBlock*;
+
+ public:
+ LocalSingleStoreElimPass();
+ const char* name() const override { return "eliminate-local-single-store"; }
+ Status Process(ir::Module*) override;
+
+ private:
+ // Returns true if |opcode| is a non-ptr access chain op
+ bool IsNonPtrAccessChain(const SpvOp opcode) const;
+
+ // Returns true if |typeInst| is a scalar type
+ // or a vector or matrix
+ bool IsMathType(const ir::Instruction* typeInst) const;
+
+ // Returns true if |typeInst| is a math type or a struct or array
+ // of a math type.
+ bool IsTargetType(const ir::Instruction* typeInst) const;
+
+ // Given a load or store |ip|, return the pointer instruction.
+ // Also return the base variable's id in |varId|.
+ ir::Instruction* GetPtr(ir::Instruction* ip, uint32_t* varId);
+
+ // Return true if |varId| is a previously identified target variable.
+ // Return false if |varId| is a previously identified non-target variable.
+ // If variable is not cached, return true if variable is a function scope
+ // variable of target type, false otherwise. Updates caches of target
+ // and non-target variables.
+ bool IsTargetVar(uint32_t varId);
+
+ // Return true if all refs through |ptrId| are only loads or stores and
+ // cache ptrId in supported_ref_ptrs_.
+ bool HasOnlySupportedRefs(uint32_t ptrId);
+
+ // Find all function scope variables in |func| that are stored to
+ // only once (SSA) and map to their stored value id. Only analyze
+ // variables of scalar, vector, matrix types and struct and array
+  // types comprising only these types. Currently this analysis
+  // is not done in the presence of function calls. TODO(): Allow
+ // analysis in the presence of function calls.
+ void SingleStoreAnalyze(ir::Function* func);
+
+ // Replace all instances of |loadInst|'s id with |replId| and delete
+ // |loadInst|.
+ void ReplaceAndDeleteLoad(ir::Instruction* loadInst, uint32_t replId);
+
+ using GetBlocksFunction =
+ std::function<const std::vector<ir::BasicBlock*>*(const ir::BasicBlock*)>;
+
+ /// Returns the block successors function for the augmented CFG.
+ GetBlocksFunction AugmentedCFGSuccessorsFunction() const;
+
+ /// Returns the block predecessors function for the augmented CFG.
+ GetBlocksFunction AugmentedCFGPredecessorsFunction() const;
+
+ // Calculate immediate dominators for |func|'s CFG. Leaves result
+ // in idom_. Entries for augmented CFG (pseudo blocks) are not created.
+ void CalculateImmediateDominators(ir::Function* func);
+
+ // Return true if instruction in |blk0| at ordinal position |idx0|
+ // dominates instruction in |blk1| at position |idx1|.
+ bool Dominates(ir::BasicBlock* blk0, uint32_t idx0,
+ ir::BasicBlock* blk1, uint32_t idx1);
+
+ // For each load of an SSA variable in |func|, replace all uses of
+ // the load with the value stored if the store dominates the load.
+ // Assumes that SingleStoreAnalyze() has just been run. Return true
+ // if any instructions are modified.
+ bool SingleStoreProcess(ir::Function* func);
+
+ // Return true if any instruction loads from |varId|
+ bool HasLoads(uint32_t varId) const;
+
+ // Return true if |varId| is not a function variable or if it has
+ // a load
+ bool IsLiveVar(uint32_t varId) const;
+
+  // Return true if |storeInst|'s base variable is not a function scope
+  // variable or if that variable has a load
+ bool IsLiveStore(ir::Instruction* storeInst);
+
+ // Add stores using |ptr_id| to |insts|
+ void AddStores(uint32_t ptr_id, std::queue<ir::Instruction*>* insts);
+
+  // Return true if |op| is a supported decorate op.
+ inline bool IsDecorate(uint32_t op) const {
+ return (op == SpvOpDecorate || op == SpvOpDecorateId);
+ }
+
+ // Return true if all uses of |id| are only name or decorate ops.
+ bool HasOnlyNamesAndDecorates(uint32_t id) const;
+
+ // Kill all name and decorate ops using |inst|
+ void KillNamesAndDecorates(ir::Instruction* inst);
+
+ // Kill all name and decorate ops using |id|
+ void KillNamesAndDecorates(uint32_t id);
+
+ // Collect all named or decorated ids in module
+ void FindNamedOrDecoratedIds();
+
+ // Delete |inst| and iterate DCE on all its operands if they are now
+ // useless. If a load is deleted and its variable has no other loads,
+ // delete all its variable's stores.
+ void DCEInst(ir::Instruction* inst);
+
+ // Remove all stores to useless SSA variables. Remove useless
+ // access chains and variables as well. Assumes SingleStoreAnalyze
+ // and SingleStoreProcess has been run.
+ bool SingleStoreDCE();
+
+ // Do "single-store" optimization of function variables defined only
+ // with a single non-access-chain store in |func|. Replace all their
+ // non-access-chain loads with the value that is stored and eliminate
+ // any resulting dead code.
+ bool LocalSingleStoreElim(ir::Function* func);
+
+ // Initialize extensions whitelist
+ void InitExtensions();
+
+ // Return true if all extensions in this module are allowed by this pass.
+ bool AllExtensionsSupported() const;
+
+ // Save next available id into |module|.
+ inline void FinalizeNextId(ir::Module* module) {
+ module->SetIdBound(next_id_);
+ }
+
+ // Return next available id and generate next.
+ inline uint32_t TakeNextId() {
+ return next_id_++;
+ }
+
+ void Initialize(ir::Module* module);
+ Pass::Status ProcessImpl();
+
+ // Module this pass is processing
+ ir::Module* module_;
+
+ // Def-Uses for the module we are processing
+ std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
+
+ // Map from function's result id to function
+ std::unordered_map<uint32_t, ir::Function*> id2function_;
+
+ // Map from block's label id to block
+ std::unordered_map<uint32_t, ir::BasicBlock*> label2block_;
+
+ // Map from SSA Variable to its single store
+ std::unordered_map<uint32_t, ir::Instruction*> ssa_var2store_;
+
+ // Map from store to its ordinal position in its block.
+ std::unordered_map<ir::Instruction*, uint32_t> store2idx_;
+
+ // Map from store to its block.
+ std::unordered_map<ir::Instruction*, ir::BasicBlock*> store2blk_;
+
+ // Set of non-SSA Variables
+ std::unordered_set<uint32_t> non_ssa_vars_;
+
+ // Cache of previously seen target types
+ std::unordered_set<uint32_t> seen_target_vars_;
+
+ // Cache of previously seen non-target types
+ std::unordered_set<uint32_t> seen_non_target_vars_;
+
+  // Variables with only supported references, i.e. loads and stores using
+ // variable directly or through non-ptr access chains.
+ std::unordered_set<uint32_t> supported_ref_ptrs_;
+
+ // Augmented CFG Entry Block
+ ir::BasicBlock pseudo_entry_block_;
+
+ // Augmented CFG Exit Block
+ ir::BasicBlock pseudo_exit_block_;
+
+ // CFG Predecessors
+ std::unordered_map<const ir::BasicBlock*, std::vector<ir::BasicBlock*>>
+ predecessors_map_;
+
+ // CFG Successors
+ std::unordered_map<const ir::BasicBlock*, std::vector<ir::BasicBlock*>>
+ successors_map_;
+
+ // CFG Augmented Predecessors
+ std::unordered_map<const ir::BasicBlock*, std::vector<ir::BasicBlock*>>
+ augmented_predecessors_map_;
+
+ // CFG Augmented Successors
+ std::unordered_map<const ir::BasicBlock*, std::vector<ir::BasicBlock*>>
+ augmented_successors_map_;
+
+ // Immediate Dominator Map
+ // If block has no idom it points to itself.
+ std::unordered_map<ir::BasicBlock*, ir::BasicBlock*> idom_;
+
+ // named or decorated ids
+ std::unordered_set<uint32_t> named_or_decorated_ids_;
+
+ // Extensions supported by this pass.
+ std::unordered_set<std::string> extensions_whitelist_;
+
+ // Next unused ID
+ uint32_t next_id_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_LOCAL_SINGLE_STORE_ELIM_PASS_H_
+
diff --git a/source/opt/local_ssa_elim_pass.cpp b/source/opt/local_ssa_elim_pass.cpp
new file mode 100644
index 00000000..60cb5c97
--- /dev/null
+++ b/source/opt/local_ssa_elim_pass.cpp
@@ -0,0 +1,825 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "local_ssa_elim_pass.h"
+
+#include "iterator.h"
+#include "cfa.h"
+
+namespace spvtools {
+namespace opt {
+
+namespace {
+
+// Input-operand ("InIdx") word indices for the instructions this pass
+// inspects. These count only input operands; result-id and result-type
+// words are excluded.
+const uint32_t kEntryPointFunctionIdInIdx = 1;
+const uint32_t kStorePtrIdInIdx = 0;
+const uint32_t kStoreValIdInIdx = 1;
+const uint32_t kLoadPtrIdInIdx = 0;
+const uint32_t kAccessChainPtrIdInIdx = 0;
+const uint32_t kTypePointerStorageClassInIdx = 0;
+const uint32_t kTypePointerTypeIdInIdx = 1;
+const uint32_t kSelectionMergeMergeBlockIdInIdx = 0;
+const uint32_t kLoopMergeMergeBlockIdInIdx = 0;
+const uint32_t kLoopMergeContinueBlockIdInIdx = 1;
+const uint32_t kCopyObjectOperandInIdx = 0;
+
+} // anonymous namespace
+
+// Returns true if |opcode| is OpAccessChain or OpInBoundsAccessChain,
+// i.e. an access chain whose indices are not pointer arithmetic
+// (OpPtrAccessChain is deliberately excluded).
+bool LocalMultiStoreElimPass::IsNonPtrAccessChain(
+    const SpvOp opcode) const {
+  return opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain;
+}
+
+// Returns true if |typeInst| is a scalar numeric/boolean type or a
+// vector or matrix of such. These are the component types the pass is
+// willing to rewrite into SSA values.
+bool LocalMultiStoreElimPass::IsMathType(
+    const ir::Instruction* typeInst) const {
+  switch (typeInst->opcode()) {
+    case SpvOpTypeInt:
+    case SpvOpTypeFloat:
+    case SpvOpTypeBool:
+    case SpvOpTypeVector:
+    case SpvOpTypeMatrix:
+      return true;
+    default:
+      break;
+  }
+  return false;
+}
+
+// Returns true if |typeInst| is a math type, an array of a math type,
+// or a struct whose members are all math types. Nested aggregates
+// (array-of-struct, struct-of-array) are not accepted.
+bool LocalMultiStoreElimPass::IsTargetType(
+    const ir::Instruction* typeInst) const {
+  if (IsMathType(typeInst))
+    return true;
+  if (typeInst->opcode() == SpvOpTypeArray)
+    // Operand 1 of OpTypeArray is the element type id.
+    return IsMathType(def_use_mgr_->GetDef(typeInst->GetSingleWordOperand(1)));
+  if (typeInst->opcode() != SpvOpTypeStruct)
+    return false;
+  // All struct members must be math type
+  int nonMathComp = 0;
+  typeInst->ForEachInId([&nonMathComp,this](const uint32_t* tid) {
+    const ir::Instruction* compTypeInst = def_use_mgr_->GetDef(*tid);
+    if (!IsMathType(compTypeInst)) ++nonMathComp;
+  });
+  return nonMathComp == 0;
+}
+
+// Given a load or store |ip|, returns its pointer instruction after
+// stripping any OpCopyObject wrappers. Also returns in |varId| the id of
+// the base OpVariable reached by further following access chains and
+// copies from that pointer.
+ir::Instruction* LocalMultiStoreElimPass::GetPtr(
+    ir::Instruction* ip, uint32_t* varId) {
+  const SpvOp op = ip->opcode();
+  assert(op == SpvOpStore || op == SpvOpLoad);
+  *varId = ip->GetSingleWordInOperand(
+      op == SpvOpStore ? kStorePtrIdInIdx : kLoadPtrIdInIdx);
+  ir::Instruction* ptrInst = def_use_mgr_->GetDef(*varId);
+  // Skip copies so the returned pointer is the underlying chain/variable.
+  while (ptrInst->opcode() == SpvOpCopyObject) {
+    *varId = ptrInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
+    ptrInst = def_use_mgr_->GetDef(*varId);
+  }
+  ir::Instruction* varInst = ptrInst;
+  // Walk back through access chains and copies to the base variable.
+  while (varInst->opcode() != SpvOpVariable) {
+    if (IsNonPtrAccessChain(varInst->opcode())) {
+      *varId = varInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
+    }
+    else {
+      assert(varInst->opcode() == SpvOpCopyObject);
+      *varId = varInst->GetSingleWordInOperand(kCopyObjectOperandInIdx);
+    }
+    varInst = def_use_mgr_->GetDef(*varId);
+  }
+  return ptrInst;
+}
+
+// Returns true if |varId| is a function-scope variable of a target
+// (math-like) type. Results are memoized in seen_target_vars_ /
+// seen_non_target_vars_ so each variable is classified only once.
+bool LocalMultiStoreElimPass::IsTargetVar(uint32_t varId) {
+  if (seen_non_target_vars_.find(varId) != seen_non_target_vars_.end())
+    return false;
+  if (seen_target_vars_.find(varId) != seen_target_vars_.end())
+    return true;
+  const ir::Instruction* varInst = def_use_mgr_->GetDef(varId);
+  assert(varInst->opcode() == SpvOpVariable);
+  // A variable's type is a pointer; check its storage class operand.
+  const uint32_t varTypeId = varInst->type_id();
+  const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
+  if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
+    SpvStorageClassFunction) {
+    seen_non_target_vars_.insert(varId);
+    return false;
+  }
+  const uint32_t varPteTypeId =
+    varTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
+  ir::Instruction* varPteTypeInst = def_use_mgr_->GetDef(varPteTypeId);
+  if (!IsTargetType(varPteTypeInst)) {
+    seen_non_target_vars_.insert(varId);
+    return false;
+  }
+  seen_target_vars_.insert(varId);
+  return true;
+}
+
+// Returns true if |ptrId|, or any pointer derived from it via access
+// chain or OpCopyObject, has a use that may read it. Any use other than
+// a store, name, or decoration is conservatively treated as a load.
+bool LocalMultiStoreElimPass::HasLoads(uint32_t ptrId) const {
+  analysis::UseList* uses = def_use_mgr_->GetUses(ptrId);
+  if (uses == nullptr)
+    return false;
+  for (auto u : *uses) {
+    const SpvOp op = u.inst->opcode();
+    if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
+      // Recurse into derived pointers.
+      if (HasLoads(u.inst->result_id()))
+        return true;
+    }
+    else {
+      // Conservatively assume that any non-store use is a load
+      // TODO(greg-lunarg): Improve analysis around function calls, etc
+      if (op != SpvOpStore && op != SpvOpName && !IsDecorate(op))
+        return true;
+    }
+  }
+  return false;
+}
+
+// Returns true if variable |varId| must be kept alive: either it is not
+// function-scope (so it may be observed elsewhere) or it is loaded from.
+bool LocalMultiStoreElimPass::IsLiveVar(uint32_t varId) const {
+  // non-function scope vars are live
+  const ir::Instruction* varInst = def_use_mgr_->GetDef(varId);
+  assert(varInst->opcode() == SpvOpVariable);
+  const uint32_t varTypeId = varInst->type_id();
+  const ir::Instruction* varTypeInst = def_use_mgr_->GetDef(varTypeId);
+  if (varTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx) !=
+      SpvStorageClassFunction)
+    return true;
+  // test if variable is loaded from
+  return HasLoads(varId);
+}
+
+// Appends to |insts| every store through |ptr_id|, including stores
+// through access chains derived from it (recursively).
+void LocalMultiStoreElimPass::AddStores(
+    uint32_t ptr_id, std::queue<ir::Instruction*>* insts) {
+  analysis::UseList* uses = def_use_mgr_->GetUses(ptr_id);
+  if (uses != nullptr) {
+    for (auto u : *uses) {
+      if (IsNonPtrAccessChain(u.inst->opcode()))
+        AddStores(u.inst->result_id(), insts);
+      else if (u.inst->opcode() == SpvOpStore)
+        insts->push(u.inst);
+    }
+  }
+}
+
+// Returns true if every remaining use of |id| is an OpName or a
+// decoration (or it has no uses at all), i.e. the defining instruction
+// is effectively dead.
+bool LocalMultiStoreElimPass::HasOnlyNamesAndDecorates(uint32_t id) const {
+  analysis::UseList* uses = def_use_mgr_->GetUses(id);
+  if (uses == nullptr)
+    return true;
+  // If the id has uses but was never named/decorated, at least one use
+  // must be a real reference; no need to scan the list.
+  if (named_or_decorated_ids_.find(id) == named_or_decorated_ids_.end())
+    return false;
+  for (auto u : *uses) {
+    const SpvOp op = u.inst->opcode();
+    if (op != SpvOpName && !IsDecorate(op))
+      return false;
+  }
+  return true;
+}
+
+// Deletes all OpName and decoration instructions that reference |id|.
+// No-op if |id| was never named or decorated.
+void LocalMultiStoreElimPass::KillNamesAndDecorates(uint32_t id) {
+  // TODO(greg-lunarg): Remove id from any OpGroupDecorate and
+  // kill if no other operands.
+  if (named_or_decorated_ids_.find(id) == named_or_decorated_ids_.end())
+    return;
+  analysis::UseList* uses = def_use_mgr_->GetUses(id);
+  if (uses == nullptr)
+    return;
+  // Collect first, then kill: killing while iterating would invalidate
+  // the use list we are walking.
+  std::list<ir::Instruction*> killList;
+  for (auto u : *uses) {
+    const SpvOp op = u.inst->opcode();
+    if (op != SpvOpName && !IsDecorate(op))
+      continue;
+    killList.push_back(u.inst);
+  }
+  for (auto kip : killList)
+    def_use_mgr_->KillInst(kip);
+}
+
+// Convenience overload: kills names/decorations of |inst|'s result id,
+// if it has one.
+void LocalMultiStoreElimPass::KillNamesAndDecorates(ir::Instruction* inst) {
+  const uint32_t rId = inst->result_id();
+  if (rId == 0)
+    return;
+  KillNamesAndDecorates(rId);
+}
+
+// Deletes |inst| and then transitively deletes any of its operand
+// definitions that become dead (only names/decorations left). If a dead
+// load removes a variable's last load, that variable's stores are also
+// queued for deletion.
+void LocalMultiStoreElimPass::DCEInst(ir::Instruction* inst) {
+  std::queue<ir::Instruction*> deadInsts;
+  deadInsts.push(inst);
+  while (!deadInsts.empty()) {
+    ir::Instruction* di = deadInsts.front();
+    // Don't delete labels
+    if (di->opcode() == SpvOpLabel) {
+      deadInsts.pop();
+      continue;
+    }
+    // Remember operands
+    std::vector<uint32_t> ids;
+    di->ForEachInId([&ids](uint32_t* iid) {
+      ids.push_back(*iid);
+    });
+    uint32_t varId = 0;
+    // Remember variable if dead load
+    if (di->opcode() == SpvOpLoad)
+      (void) GetPtr(di, &varId);
+    KillNamesAndDecorates(di);
+    def_use_mgr_->KillInst(di);
+    // For all operands with no remaining uses, add their instruction
+    // to the dead instruction queue.
+    for (auto id : ids)
+      if (HasOnlyNamesAndDecorates(id))
+        deadInsts.push(def_use_mgr_->GetDef(id));
+    // if a load was deleted and it was the variable's
+    // last load, add all its stores to dead queue
+    if (varId != 0 && !IsLiveVar(varId))
+      AddStores(varId, &deadInsts);
+    deadInsts.pop();
+  }
+}
+
+// Returns true if every use of |varId| is a direct load, store, name,
+// or decoration (no access chains or copies). Positive results are
+// cached in supported_ref_vars_.
+bool LocalMultiStoreElimPass::HasOnlySupportedRefs(uint32_t varId) {
+  if (supported_ref_vars_.find(varId) != supported_ref_vars_.end())
+    return true;
+  analysis::UseList* uses = def_use_mgr_->GetUses(varId);
+  if (uses == nullptr)
+    return true;
+  for (auto u : *uses) {
+    const SpvOp op = u.inst->opcode();
+    if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
+        !IsDecorate(op))
+      return false;
+  }
+  supported_ref_vars_.insert(varId);
+  return true;
+}
+
+// Prepares per-function state for SSA rewriting: builds the
+// label -> predecessors map and demotes any referenced variable that has
+// unsupported (non-load/store) uses from target to non-target.
+void LocalMultiStoreElimPass::InitSSARewrite(ir::Function& func) {
+  // Init predecessors
+  label2preds_.clear();
+  for (auto& blk : func) {
+    uint32_t blkId = blk.id();
+    blk.ForEachSuccessorLabel([&blkId, this](uint32_t sbid) {
+      label2preds_[sbid].push_back(blkId);
+    });
+  }
+  // Collect target (and non-) variable sets. Remove variables with
+  // non-load/store refs from target variable set
+  for (auto& blk : func) {
+    for (auto& inst : blk) {
+      switch (inst.opcode()) {
+        case SpvOpStore:
+        case SpvOpLoad: {
+          uint32_t varId;
+          (void) GetPtr(&inst, &varId);
+          if (!IsTargetVar(varId))
+            break;
+          if (HasOnlySupportedRefs(varId))
+            break;
+          // Unsupported reference found: reclassify as non-target.
+          seen_non_target_vars_.insert(varId);
+          seen_target_vars_.erase(varId);
+        } break;
+        default:
+          break;
+      }
+    }
+  }
+}
+
+// If |blk| ends with a merge instruction (second-to-last instruction),
+// returns its merge-block id, else 0. For a loop header, also returns
+// the continue-block id in |cbid| (0 otherwise).
+uint32_t LocalMultiStoreElimPass::MergeBlockIdIfAny(const ir::BasicBlock& blk,
+    uint32_t* cbid) {
+  auto merge_ii = blk.cend();
+  --merge_ii;
+  *cbid = 0;
+  uint32_t mbid = 0;
+  if (merge_ii != blk.cbegin()) {
+    --merge_ii;
+    if (merge_ii->opcode() == SpvOpLoopMerge) {
+      mbid = merge_ii->GetSingleWordInOperand(kLoopMergeMergeBlockIdInIdx);
+      *cbid = merge_ii->GetSingleWordInOperand(kLoopMergeContinueBlockIdInIdx);
+    }
+    else if (merge_ii->opcode() == SpvOpSelectionMerge) {
+      mbid = merge_ii->GetSingleWordInOperand(kSelectionMergeMergeBlockIdInIdx);
+    }
+  }
+  return mbid;
+}
+
+// Populates block2structured_succs_ for |func|: merge (and continue)
+// blocks of a header are listed before its actual CFG successors so a
+// depth-first traversal yields a structured order. Blocks with no
+// predecessors become successors of the pseudo entry block.
+void LocalMultiStoreElimPass::ComputeStructuredSuccessors(ir::Function* func) {
+  for (auto& blk : *func) {
+    // If no predecessors in function, make successor to pseudo entry
+    if (label2preds_[blk.id()].size() == 0)
+      block2structured_succs_[&pseudo_entry_block_].push_back(&blk);
+    // If header, make merge block first successor.
+    uint32_t cbid;
+    const uint32_t mbid = MergeBlockIdIfAny(blk, &cbid);
+    if (mbid != 0) {
+      block2structured_succs_[&blk].push_back(id2block_[mbid]);
+      if (cbid != 0)
+        block2structured_succs_[&blk].push_back(id2block_[cbid]);
+    }
+    // add true successors
+    blk.ForEachSuccessorLabel([&blk, this](uint32_t sbid) {
+      block2structured_succs_[&blk].push_back(id2block_[sbid]);
+    });
+  }
+}
+
+// Computes a structured traversal order for |func| into |order| by
+// running a depth-first traversal over the structured successors and
+// collecting blocks in reverse post-order (push_front at post-visit).
+void LocalMultiStoreElimPass::ComputeStructuredOrder(
+    ir::Function* func, std::list<ir::BasicBlock*>* order) {
+  // Compute structured successors and do DFS
+  ComputeStructuredSuccessors(func);
+  auto ignore_block = [](cbb_ptr) {};
+  auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
+  auto get_structured_successors = [this](const ir::BasicBlock* block) {
+    return &(block2structured_succs_[block]); };
+  // TODO(greg-lunarg): Get rid of const_cast by making moving const
+  // out of the cfa.h prototypes and into the invoking code.
+  auto post_order = [&](cbb_ptr b) {
+    order->push_front(const_cast<ir::BasicBlock*>(b)); };
+
+  spvtools::CFA<ir::BasicBlock>::DepthFirstTraversal(
+      &pseudo_entry_block_, get_structured_successors, ignore_block,
+      post_order, ignore_edge);
+}
+
+// Initializes the SSA map of a single-predecessor block by copying the
+// (already computed) map of its sole predecessor.
+void LocalMultiStoreElimPass::SSABlockInitSinglePred(ir::BasicBlock* block_ptr) {
+  // Copy map entry from single predecessor
+  const uint32_t label = block_ptr->id();
+  const uint32_t predLabel = label2preds_[label].front();
+  assert(visitedBlocks_.find(predLabel) != visitedBlocks_.end());
+  label2ssa_map_[label] = label2ssa_map_[predLabel];
+}
+
+// Returns whether |var_id| may be live after |label|. Currently always
+// true (maximally conservative); see TODO below.
+bool LocalMultiStoreElimPass::IsLiveAfter(uint32_t var_id, uint32_t label) const {
+  // For now, return very conservative result: true. This will result in
+  // correct, but possibly unused, phi code to be generated. A subsequent
+  // DCE pass should eliminate this code.
+  // TODO(greg-lunarg): Return more accurate information
+  (void) var_id;
+  (void) label;
+  return true;
+}
+
+// Returns the id of an OpUndef of type |type_id|, creating one as a
+// module-level global value on first request and caching it in
+// type2undefs_ thereafter.
+uint32_t LocalMultiStoreElimPass::Type2Undef(uint32_t type_id) {
+  const auto uitr = type2undefs_.find(type_id);
+  if (uitr != type2undefs_.end())
+    return uitr->second;
+  const uint32_t undefId = TakeNextId();
+  std::unique_ptr<ir::Instruction> undef_inst(
+    new ir::Instruction(SpvOpUndef, type_id, undefId, {}));
+  def_use_mgr_->AnalyzeInstDefUse(&*undef_inst);
+  module_->AddGlobalValue(std::move(undef_inst));
+  type2undefs_[type_id] = undefId;
+  return undefId;
+}
+
+// Returns the pointee type id of pointer-valued instruction |ptrInst|
+// (its type is expected to be an OpTypePointer).
+uint32_t LocalMultiStoreElimPass::GetPointeeTypeId(
+    const ir::Instruction* ptrInst) const {
+  const uint32_t ptrTypeId = ptrInst->type_id();
+  const ir::Instruction* ptrTypeInst = def_use_mgr_->GetDef(ptrTypeId);
+  return ptrTypeInst->GetSingleWordInOperand(kTypePointerTypeIdInIdx);
+}
+
+// Initializes the SSA map of a loop-header block. Inserts phis for every
+// variable stored in the loop or differing across non-backedge
+// predecessors; the backedge phi operand is the variable id itself and
+// is patched later (see PatchPhis) once the backedge block is processed.
+void LocalMultiStoreElimPass::SSABlockInitLoopHeader(
+    std::list<ir::BasicBlock*>::iterator block_itr) {
+  const uint32_t label = (*block_itr)->id();
+  // Determine backedge label. The backedge predecessor is the one not
+  // yet visited in structured order.
+  uint32_t backLabel = 0;
+  for (uint32_t predLabel : label2preds_[label])
+    if (visitedBlocks_.find(predLabel) == visitedBlocks_.end()) {
+      assert(backLabel == 0);
+      backLabel = predLabel;
+      break;
+    }
+  assert(backLabel != 0);
+  // Determine merge block: operand of the OpLoopMerge, which is the
+  // second-to-last instruction of the header.
+  auto mergeInst = (*block_itr)->end();
+  --mergeInst;
+  --mergeInst;
+  uint32_t mergeLabel = mergeInst->GetSingleWordInOperand(
+      kLoopMergeMergeBlockIdInIdx);
+  // Collect all live variables and a default value for each across all
+  // non-backedge predecessors. Must be ordered map because phis are
+  // generated based on order and test results will otherwise vary across
+  // platforms.
+  std::map<uint32_t, uint32_t> liveVars;
+  for (uint32_t predLabel : label2preds_[label]) {
+    for (auto var_val : label2ssa_map_[predLabel]) {
+      uint32_t varId = var_val.first;
+      liveVars[varId] = var_val.second;
+    }
+  }
+  // Add all stored variables in loop. Set their default value id to zero.
+  for (auto bi = block_itr; (*bi)->id() != mergeLabel; ++bi) {
+    ir::BasicBlock* bp = *bi;
+    for (auto ii = bp->begin(); ii != bp->end(); ++ii) {
+      if (ii->opcode() != SpvOpStore)
+        continue;
+      uint32_t varId;
+      (void) GetPtr(&*ii, &varId);
+      if (!IsTargetVar(varId))
+        continue;
+      liveVars[varId] = 0;
+    }
+  }
+  // Insert phi for all live variables that require them. All variables
+  // defined in loop require a phi. Otherwise all variables
+  // with differing predecessor values require a phi.
+  auto insertItr = (*block_itr)->begin();
+  for (auto var_val : liveVars) {
+    const uint32_t varId = var_val.first;
+    if (!IsLiveAfter(varId, label))
+      continue;
+    const uint32_t val0Id = var_val.second;
+    bool needsPhi = false;
+    if (val0Id != 0) {
+      for (uint32_t predLabel : label2preds_[label]) {
+        // Skip back edge predecessor.
+        if (predLabel == backLabel)
+          continue;
+        const auto var_val_itr = label2ssa_map_[predLabel].find(varId);
+        // Missing (undef) values always cause difference with (defined) value
+        if (var_val_itr == label2ssa_map_[predLabel].end()) {
+          needsPhi = true;
+          break;
+        }
+        if (var_val_itr->second != val0Id) {
+          needsPhi = true;
+          break;
+        }
+      }
+    }
+    else {
+      // val0Id == 0 marks a variable stored inside the loop: always phi.
+      needsPhi = true;
+    }
+    // If val is the same for all predecessors, enter it in map
+    if (!needsPhi) {
+      label2ssa_map_[label].insert(var_val);
+      continue;
+    }
+    // Val differs across predecessors. Add phi op to block and
+    // add its result id to the map. For back edge predecessor,
+    // use the variable id. We will patch this after visiting back
+    // edge predecessor. For predecessors that do not define a value,
+    // use undef.
+    std::vector<ir::Operand> phi_in_operands;
+    uint32_t typeId = GetPointeeTypeId(def_use_mgr_->GetDef(varId));
+    for (uint32_t predLabel : label2preds_[label]) {
+      uint32_t valId;
+      if (predLabel == backLabel) {
+        valId = varId;
+      }
+      else {
+        const auto var_val_itr = label2ssa_map_[predLabel].find(varId);
+        if (var_val_itr == label2ssa_map_[predLabel].end())
+          valId = Type2Undef(typeId);
+        else
+          valId = var_val_itr->second;
+      }
+      phi_in_operands.push_back(
+        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}});
+      phi_in_operands.push_back(
+        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {predLabel}});
+    }
+    const uint32_t phiId = TakeNextId();
+    std::unique_ptr<ir::Instruction> newPhi(
+      new ir::Instruction(SpvOpPhi, typeId, phiId, phi_in_operands));
+    // Only analyze the phi define now; analyze the phi uses after the
+    // phi backedge predecessor value is patched.
+    def_use_mgr_->AnalyzeInstDef(&*newPhi);
+    insertItr = insertItr.InsertBefore(std::move(newPhi));
+    ++insertItr;
+    label2ssa_map_[label].insert({ varId, phiId });
+  }
+}
+
+// Initializes the SSA map of a (non-loop-header) block with multiple
+// predecessors: values identical across all predecessors are copied
+// through; differing or partially-missing values get a phi, with undef
+// for predecessors that do not define the variable.
+void LocalMultiStoreElimPass::SSABlockInitMultiPred(ir::BasicBlock* block_ptr) {
+  const uint32_t label = block_ptr->id();
+  // Collect all live variables and a default value for each across all
+  // predecessors. Must be ordered map because phis are generated based on
+  // order and test results will otherwise vary across platforms.
+  std::map<uint32_t, uint32_t> liveVars;
+  for (uint32_t predLabel : label2preds_[label]) {
+    assert(visitedBlocks_.find(predLabel) != visitedBlocks_.end());
+    for (auto var_val : label2ssa_map_[predLabel]) {
+      const uint32_t varId = var_val.first;
+      liveVars[varId] = var_val.second;
+    }
+  }
+  // For each live variable, look for a difference in values across
+  // predecessors that would require a phi and insert one.
+  auto insertItr = block_ptr->begin();
+  for (auto var_val : liveVars) {
+    const uint32_t varId = var_val.first;
+    if (!IsLiveAfter(varId, label))
+      continue;
+    const uint32_t val0Id = var_val.second;
+    bool differs = false;
+    for (uint32_t predLabel : label2preds_[label]) {
+      const auto var_val_itr = label2ssa_map_[predLabel].find(varId);
+      // Missing values cause a difference because we'll need to create an
+      // undef for that predecessor.
+      if (var_val_itr == label2ssa_map_[predLabel].end()) {
+        differs = true;
+        break;
+      }
+      if (var_val_itr->second != val0Id) {
+        differs = true;
+        break;
+      }
+    }
+    // If val is the same for all predecessors, enter it in map
+    if (!differs) {
+      label2ssa_map_[label].insert(var_val);
+      continue;
+    }
+    // Val differs across predecessors. Add phi op to block and
+    // add its result id to the map
+    std::vector<ir::Operand> phi_in_operands;
+    const uint32_t typeId = GetPointeeTypeId(def_use_mgr_->GetDef(varId));
+    for (uint32_t predLabel : label2preds_[label]) {
+      const auto var_val_itr = label2ssa_map_[predLabel].find(varId);
+      // If variable not defined on this path, use undef
+      const uint32_t valId = (var_val_itr != label2ssa_map_[predLabel].end()) ?
+          var_val_itr->second : Type2Undef(typeId);
+      phi_in_operands.push_back(
+        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}});
+      phi_in_operands.push_back(
+        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {predLabel}});
+    }
+    const uint32_t phiId = TakeNextId();
+    std::unique_ptr<ir::Instruction> newPhi(
+      new ir::Instruction(SpvOpPhi, typeId, phiId, phi_in_operands));
+    // Unlike the loop-header case, all phi operands are final here, so
+    // defs and uses can be analyzed together.
+    def_use_mgr_->AnalyzeInstDefUse(&*newPhi);
+    insertItr = insertItr.InsertBefore(std::move(newPhi));
+    ++insertItr;
+    label2ssa_map_[label].insert({varId, phiId});
+  }
+}
+
+// Returns true if |block_ptr| is a loop header, i.e. its second-to-last
+// instruction is an OpLoopMerge.
+bool LocalMultiStoreElimPass::IsLoopHeader(ir::BasicBlock* block_ptr) const {
+  auto iItr = block_ptr->end();
+  --iItr;
+  if (iItr == block_ptr->begin())
+    return false;
+  --iItr;
+  return iItr->opcode() == SpvOpLoopMerge;
+}
+
+// Initializes the SSA map of the block at |block_itr|, dispatching on
+// its predecessor structure: none (entry), single, loop header, or
+// general multi-predecessor merge.
+void LocalMultiStoreElimPass::SSABlockInit(
+    std::list<ir::BasicBlock*>::iterator block_itr) {
+  const size_t numPreds = label2preds_[(*block_itr)->id()].size();
+  if (numPreds == 0)
+    return;
+  if (numPreds == 1)
+    SSABlockInitSinglePred(*block_itr);
+  else if (IsLoopHeader(*block_itr))
+    SSABlockInitLoopHeader(block_itr);
+  else
+    SSABlockInitMultiPred(*block_itr);
+}
+
+// After the backedge block |back_id| has been processed, rewrites the
+// placeholder backedge operand (the variable id) of every phi at the top
+// of loop header |header_id| with the value live out of the backedge
+// block, or undef if none.
+void LocalMultiStoreElimPass::PatchPhis(uint32_t header_id, uint32_t back_id) {
+  ir::BasicBlock* header = id2block_[header_id];
+  auto phiItr = header->begin();
+  for (; phiItr->opcode() == SpvOpPhi; ++phiItr) {
+    // Phi in-operands alternate (value, label); find the value-operand
+    // index paired with the backedge label.
+    // NOTE(review): |idx| is left uninitialized if no operand matches
+    // back_id — presumably every header phi carries the backedge pair;
+    // confirm against SSABlockInitLoopHeader.
+    uint32_t cnt = 0;
+    uint32_t idx;
+    phiItr->ForEachInId([&cnt,&back_id,&idx](uint32_t* iid) {
+      if (cnt % 2 == 1 && *iid == back_id) idx = cnt - 1;
+      ++cnt;
+    });
+    // Use undef if variable not in backedge predecessor map
+    const uint32_t varId = phiItr->GetSingleWordInOperand(idx);
+    const auto valItr = label2ssa_map_[back_id].find(varId);
+    uint32_t valId = (valItr != label2ssa_map_[back_id].end()) ?
+        valItr->second :
+        Type2Undef(GetPointeeTypeId(def_use_mgr_->GetDef(varId)));
+    phiItr->SetInOperand(idx, { valId });
+    // Analyze uses now that they are complete
+    def_use_mgr_->AnalyzeInstUse(&*phiItr);
+  }
+}
+
+// Rewrites all loads/stores of target local variables in |func| into SSA
+// form: walks blocks in structured order, tracks the current value of
+// each variable per block, replaces loads with those values (inserting
+// phis at merges), then deletes the now-dead stores. Returns true if
+// the function was modified.
+bool LocalMultiStoreElimPass::EliminateMultiStoreLocal(ir::Function* func) {
+  InitSSARewrite(*func);
+  // Process all blocks in structured order. This is just one way (the
+  // simplest?) to make sure all predecessors blocks are processed before
+  // a block itself.
+  std::list<ir::BasicBlock*> structuredOrder;
+  ComputeStructuredOrder(func, &structuredOrder);
+  bool modified = false;
+  for (auto bi = structuredOrder.begin(); bi != structuredOrder.end(); ++bi) {
+    // Skip pseudo entry block
+    if (*bi == &pseudo_entry_block_)
+      continue;
+    // Initialize this block's label2ssa_map_ entry using predecessor maps.
+    // Then process all stores and loads of targeted variables.
+    SSABlockInit(bi);
+    ir::BasicBlock* bp = *bi;
+    const uint32_t label = bp->id();
+    for (auto ii = bp->begin(); ii != bp->end(); ++ii) {
+      switch (ii->opcode()) {
+        case SpvOpStore: {
+          uint32_t varId;
+          (void) GetPtr(&*ii, &varId);
+          if (!IsTargetVar(varId))
+            break;
+          // Register new stored value for the variable
+          label2ssa_map_[label][varId] =
+              ii->GetSingleWordInOperand(kStoreValIdInIdx);
+        } break;
+        case SpvOpLoad: {
+          uint32_t varId;
+          (void) GetPtr(&*ii, &varId);
+          if (!IsTargetVar(varId))
+            break;
+          uint32_t replId = 0;
+          const auto ssaItr = label2ssa_map_.find(label);
+          if (ssaItr != label2ssa_map_.end()) {
+            const auto valItr = ssaItr->second.find(varId);
+            if (valItr != ssaItr->second.end())
+              replId = valItr->second;
+          }
+          // If variable is not defined, use undef
+          if (replId == 0) {
+            replId = Type2Undef(GetPointeeTypeId(def_use_mgr_->GetDef(varId)));
+          }
+          // Replace load's id with the last stored value id for variable
+          // and delete load. Kill any names or decorates using id before
+          // replacing to prevent incorrect replacement in those instructions.
+          const uint32_t loadId = ii->result_id();
+          KillNamesAndDecorates(loadId);
+          (void)def_use_mgr_->ReplaceAllUsesWith(loadId, replId);
+          def_use_mgr_->KillInst(&*ii);
+          modified = true;
+        } break;
+        default: {
+        } break;
+      }
+    }
+    visitedBlocks_.insert(label);
+    // Look for successor backedge and patch phis in loop header
+    // if found. A successor already visited in structured order must be
+    // a loop header reached via this block's backedge.
+    uint32_t header = 0;
+    bp->ForEachSuccessorLabel([&header,this](uint32_t succ) {
+      if (visitedBlocks_.find(succ) == visitedBlocks_.end()) return;
+      assert(header == 0);
+      header = succ;
+    });
+    if (header != 0)
+      PatchPhis(header, label);
+  }
+  // Remove all target variable stores.
+  for (auto bi = func->begin(); bi != func->end(); ++bi) {
+    for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
+      if (ii->opcode() != SpvOpStore)
+        continue;
+      uint32_t varId;
+      (void) GetPtr(&*ii, &varId);
+      if (!IsTargetVar(varId))
+        continue;
+      assert(!HasLoads(varId));
+      DCEInst(&*ii);
+      modified = true;
+    }
+  }
+  return modified;
+}
+
+// Resets all per-module pass state for |module|: rebuilds def/use
+// information and the id->function / id->block maps, clears all caches
+// from any previous run, seeds the fresh-id counter from the module's id
+// bound, and installs the extension whitelist.
+void LocalMultiStoreElimPass::Initialize(ir::Module* module) {
+
+  module_ = module;
+
+  // TODO(greg-lunarg): Reuse def/use from previous passes
+  def_use_mgr_.reset(new analysis::DefUseManager(consumer(), module_));
+
+  // Initialize function and block maps
+  id2function_.clear();
+  id2block_.clear();
+  block2structured_succs_.clear();
+  for (auto& fn : *module_) {
+    id2function_[fn.result_id()] = &fn;
+    for (auto& blk : fn)
+      id2block_[blk.id()] = &blk;
+  }
+
+  // Clear collections
+  seen_target_vars_.clear();
+  seen_non_target_vars_.clear();
+  visitedBlocks_.clear();
+  type2undefs_.clear();
+  supported_ref_vars_.clear();
+  block2structured_succs_.clear();
+  label2preds_.clear();
+  label2ssa_map_.clear();
+
+  // Start new ids with next available in module
+  next_id_ = module_->id_bound();
+
+  // Initialize extension whitelist
+  InitExtensions();
+}
+
+// Returns true if every OpExtension declared by the module is in the
+// pass's whitelist; a single unknown extension disables the pass.
+bool LocalMultiStoreElimPass::AllExtensionsSupported() const {
+  // If any extension not in whitelist, return false
+  for (auto& ei : module_->extensions()) {
+    // The extension name is a literal string in the first operand words.
+    const char* extName = reinterpret_cast<const char*>(
+        &ei.GetInOperand(0).words[0]);
+    if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
+      return false;
+  }
+  return true;
+}
+
+// Records in named_or_decorated_ids_ every id targeted by an OpName or
+// an OpDecorate/OpDecorateId, so later dead-code queries can skip the
+// use-list scan for unnamed, undecorated ids.
+void LocalMultiStoreElimPass::FindNamedOrDecoratedIds() {
+  for (auto& di : module_->debugs())
+    if (di.opcode() == SpvOpName)
+      named_or_decorated_ids_.insert(di.GetSingleWordInOperand(0));
+  for (auto& ai : module_->annotations())
+    if (ai.opcode() == SpvOpDecorate || ai.opcode() == SpvOpDecorateId)
+      named_or_decorated_ids_.insert(ai.GetSingleWordInOperand(0));
+}
+
+// Runs the pass over every entry-point function after checking the
+// module's preconditions (Shader capability, logical addressing, no
+// OpGroupDecorate, only whitelisted extensions). Returns whether the
+// module was changed.
+Pass::Status LocalMultiStoreElimPass::ProcessImpl() {
+  // Assumes all control flow structured.
+  // TODO(greg-lunarg): Do SSA rewrite for non-structured control flow
+  if (!module_->HasCapability(SpvCapabilityShader))
+    return Status::SuccessWithoutChange;
+  // Assumes logical addressing only
+  // TODO(greg-lunarg): Add support for physical addressing
+  if (module_->HasCapability(SpvCapabilityAddresses))
+    return Status::SuccessWithoutChange;
+  // Do not process if module contains OpGroupDecorate. Additional
+  // support required in KillNamesAndDecorates().
+  // TODO(greg-lunarg): Add support for OpGroupDecorate
+  for (auto& ai : module_->annotations())
+    if (ai.opcode() == SpvOpGroupDecorate)
+      return Status::SuccessWithoutChange;
+  // Do not process if any disallowed extensions are enabled
+  if (!AllExtensionsSupported())
+    return Status::SuccessWithoutChange;
+  // Collect all named and decorated ids
+  FindNamedOrDecoratedIds();
+  // Process functions
+  bool modified = false;
+  for (auto& e : module_->entry_points()) {
+    ir::Function* fn =
+        id2function_[e.GetSingleWordInOperand(kEntryPointFunctionIdInIdx)];
+    modified = EliminateMultiStoreLocal(fn) || modified;
+  }
+  FinalizeNextId(module_);
+  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
+}
+
+// Constructor: the pseudo entry block is built around a synthetic
+// OpLabel with id 0 so it can never collide with a real block.
+LocalMultiStoreElimPass::LocalMultiStoreElimPass()
+    : module_(nullptr), def_use_mgr_(nullptr),
+      pseudo_entry_block_(std::unique_ptr<ir::Instruction>(
+          new ir::Instruction(SpvOpLabel, 0, 0, {}))),
+      next_id_(0) {}
+
+// Pass entry point: (re)initialize per-module state, then run.
+Pass::Status LocalMultiStoreElimPass::Process(ir::Module* module) {
+  Initialize(module);
+  return ProcessImpl();
+}
+
+// Rebuilds the whitelist of extensions under which this pass is known
+// to be safe; modules declaring any other extension are left untouched
+// (see AllExtensionsSupported).
+void LocalMultiStoreElimPass::InitExtensions() {
+  extensions_whitelist_.clear();
+  extensions_whitelist_.insert({
+    "SPV_AMD_shader_explicit_vertex_parameter",
+    "SPV_AMD_shader_trinary_minmax",
+    "SPV_AMD_gcn_shader",
+    "SPV_KHR_shader_ballot",
+    "SPV_AMD_shader_ballot",
+    "SPV_AMD_gpu_shader_half_float",
+    "SPV_KHR_shader_draw_parameters",
+    "SPV_KHR_subgroup_vote",
+    "SPV_KHR_16bit_storage",
+    "SPV_KHR_device_group",
+    "SPV_KHR_multiview",
+    "SPV_NVX_multiview_per_view_attributes",
+    "SPV_NV_viewport_array2",
+    "SPV_NV_stereo_view_rendering",
+    "SPV_NV_sample_mask_override_coverage",
+    "SPV_NV_geometry_shader_passthrough",
+    "SPV_AMD_texture_gather_bias_lod",
+    "SPV_KHR_storage_buffer_storage_class",
+    // SPV_KHR_variable_pointers
+    //   Currently do not support extended pointer expressions
+    "SPV_AMD_gpu_shader_int16",
+    "SPV_KHR_post_depth_coverage",
+    "SPV_KHR_shader_atomic_counter_ops",
+  });
+}
+
+} // namespace opt
+} // namespace spvtools
+
diff --git a/source/opt/local_ssa_elim_pass.h b/source/opt/local_ssa_elim_pass.h
new file mode 100644
index 00000000..8237c30e
--- /dev/null
+++ b/source/opt/local_ssa_elim_pass.h
@@ -0,0 +1,268 @@
+// Copyright (c) 2017 The Khronos Group Inc.
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_OPT_LOCAL_SSA_ELIM_PASS_H_
+#define LIBSPIRV_OPT_LOCAL_SSA_ELIM_PASS_H_
+
+
+#include <algorithm>
+#include <map>
+#include <queue>
+#include <utility>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "basic_block.h"
+#include "def_use_manager.h"
+#include "module.h"
+#include "pass.h"
+
+namespace spvtools {
+namespace opt {
+
+// See optimizer.hpp for documentation.
+class LocalMultiStoreElimPass : public Pass {
+ using cbb_ptr = const ir::BasicBlock*;
+
+ public:
+ using GetBlocksFunction =
+ std::function<std::vector<ir::BasicBlock*>*(const ir::BasicBlock*)>;
+
+ LocalMultiStoreElimPass();
+ const char* name() const override { return "eliminate-local-multi-store"; }
+ Status Process(ir::Module*) override;
+
+ private:
+ // Returns true if |opcode| is a non-ptr access chain op
+ bool IsNonPtrAccessChain(const SpvOp opcode) const;
+
+ // Returns true if |typeInst| is a scalar type
+ // or a vector or matrix
+ bool IsMathType(const ir::Instruction* typeInst) const;
+
+ // Returns true if |typeInst| is a math type or a struct or array
+ // of a math type.
+ bool IsTargetType(const ir::Instruction* typeInst) const;
+
+ // Given a load or store |ip|, return the pointer instruction.
+ // Also return the base variable's id in |varId|.
+ ir::Instruction* GetPtr(ir::Instruction* ip, uint32_t* varId);
+
+ // Return true if |varId| is a previously identified target variable.
+ // Return false if |varId| is a previously identified non-target variable.
+ // If variable is not cached, return true if variable is a function scope
+ // variable of target type, false otherwise. Updates caches of target
+ // and non-target variables.
+ bool IsTargetVar(uint32_t varId);
+
+ // Return type id for |ptrInst|'s pointee
+ uint32_t GetPointeeTypeId(const ir::Instruction* ptrInst) const;
+
+ // Replace all instances of |loadInst|'s id with |replId| and delete
+ // |loadInst|.
+ void ReplaceAndDeleteLoad(ir::Instruction* loadInst, uint32_t replId);
+
+ // Return true if any instruction loads from |ptrId|
+ bool HasLoads(uint32_t ptrId) const;
+
+ // Return true if |varId| is not a function variable or if it has
+ // a load
+ bool IsLiveVar(uint32_t varId) const;
+
+ // Add stores using |ptr_id| to |insts|
+ void AddStores(uint32_t ptr_id, std::queue<ir::Instruction*>* insts);
+
+ // Delete |inst| and iterate DCE on all its operands. Won't delete
+ // labels.
+ void DCEInst(ir::Instruction* inst);
+
+ // Return true if all uses of |varId| are only through supported reference
+ // operations ie. loads and store. Also cache in supported_ref_vars_;
+ bool HasOnlySupportedRefs(uint32_t varId);
+
+ // Return true if all uses of |id| are only name or decorate ops.
+ bool HasOnlyNamesAndDecorates(uint32_t id) const;
+
+ // Kill all name and decorate ops using |inst|
+ void KillNamesAndDecorates(ir::Instruction* inst);
+
+ // Kill all name and decorate ops using |id|
+ void KillNamesAndDecorates(uint32_t id);
+
+ // Initialize data structures used by EliminateLocalMultiStore for
+ // function |func|, specifically block predecessors and target variables.
+ void InitSSARewrite(ir::Function& func);
+
+ // Returns the id of the merge block declared by a merge instruction in
+ // this block, if any. If none, returns zero.
+ uint32_t MergeBlockIdIfAny(const ir::BasicBlock& blk, uint32_t* cbid);
+
+ // Compute structured successors for function |func|.
+ // A block's structured successors are the blocks it branches to
+ // together with its declared merge block if it has one.
+ // When order matters, the merge block always appears first.
+ // This assures correct depth first search in the presence of early
+ // returns and kills. If the successor vector contain duplicates
+ // if the merge block, they are safely ignored by DFS.
+ void ComputeStructuredSuccessors(ir::Function* func);
+
+ // Compute structured block order for |func| into |structuredOrder|. This
+ // order has the property that dominators come before all blocks they
+ // dominate and merge blocks come after all blocks that are in the control
+ // constructs of their header.
+ void ComputeStructuredOrder(ir::Function* func,
+ std::list<ir::BasicBlock*>* order);
+
+ // Return true if loop header block
+ bool IsLoopHeader(ir::BasicBlock* block_ptr) const;
+
+ // Initialize label2ssa_map_ entry for block |block_ptr| with single
+ // predecessor.
+ void SSABlockInitSinglePred(ir::BasicBlock* block_ptr);
+
+ // Return true if variable is loaded in block with |label| or in
+ // any succeeding block in structured order.
+ bool IsLiveAfter(uint32_t var_id, uint32_t label) const;
+
+ // Initialize label2ssa_map_ entry for loop header block pointed to
+ // |block_itr| by merging entries from all predecessors. If any value
+ // ids differ for any variable across predecessors, create a phi function
+ // in the block and use that value id for the variable in the new map.
+ // Assumes all predecessors have been visited by EliminateLocalMultiStore
+ // except the back edge. Use a dummy value in the phi for the back edge
+ // until the back edge block is visited and patch the phi value then.
+ void SSABlockInitLoopHeader(std::list<ir::BasicBlock*>::iterator block_itr);
+
+ // Initialize label2ssa_map_ entry for multiple predecessor block
+ // |block_ptr| by merging label2ssa_map_ entries for all predecessors.
+ // If any value ids differ for any variable across predecessors, create
+ // a phi function in the block and use that value id for the variable in
+ // the new map. Assumes all predecessors have been visited by
+ // EliminateLocalMultiStore.
+ void SSABlockInitMultiPred(ir::BasicBlock* block_ptr);
+
+ // Initialize the label2ssa_map entry for a block pointed to by |block_itr|.
+ // Insert phi instructions into block when necessary. All predecessor
+ // blocks must have been visited by EliminateLocalMultiStore except for
+ // backedges.
+ void SSABlockInit(std::list<ir::BasicBlock*>::iterator block_itr);
+
+ // Return undef in function for type. Create and insert an undef after the
+ // first non-variable in the function if it doesn't already exist. Add
+ // undef to function undef map.
+ uint32_t Type2Undef(uint32_t type_id);
+
+ // Patch phis in loop header block now that the map is complete for the
+ // backedge predecessor. Specifically, for each phi, find the value
+ // corresponding to the backedge predecessor. That contains the variable id
+ // that this phi corresponds to. Change this phi operand to the the value
+ // which corresponds to that variable in the predecessor map.
+ void PatchPhis(uint32_t header_id, uint32_t back_id);
+
+ // Initialize extensions whitelist
+ void InitExtensions();
+
+ // Return true if all extensions in this module are allowed by this pass.
+ bool AllExtensionsSupported() const;
+
+ // Collect all named or decorated ids in module
+ void FindNamedOrDecoratedIds();
+
+ // Remove remaining loads and stores of function scope variables only
+ // referenced with non-access-chain loads and stores from function |func|.
+ // Insert Phi functions where necessary. Running LocalAccessChainRemoval,
+ // SingleBlockLocalElim and SingleStoreLocalElim beforehand will improve
+ // the runtime and effectiveness of this function.
+ bool EliminateMultiStoreLocal(ir::Function* func);
+
+ // Return true if |op| is decorate.
+ inline bool IsDecorate(uint32_t op) const {
+ return (op == SpvOpDecorate || op == SpvOpDecorateId);
+ }
+
+ // Save next available id into |module|.
+ inline void FinalizeNextId(ir::Module* module) {
+ module->SetIdBound(next_id_);
+ }
+
+ // Return next available id and calculate next.
+ inline uint32_t TakeNextId() {
+ return next_id_++;
+ }
+
+ void Initialize(ir::Module* module);
+ Pass::Status ProcessImpl();
+
+ // Module this pass is processing
+ ir::Module* module_;
+
+ // Def-Uses for the module we are processing
+ std::unique_ptr<analysis::DefUseManager> def_use_mgr_;
+
+ // Map from function's result id to function
+ std::unordered_map<uint32_t, ir::Function*> id2function_;
+
+ // Map from block's label id to block.
+ std::unordered_map<uint32_t, ir::BasicBlock*> id2block_;
+
+ // Cache of previously seen target types
+ std::unordered_set<uint32_t> seen_target_vars_;
+
+ // Cache of previously seen non-target types
+ std::unordered_set<uint32_t> seen_non_target_vars_;
+
+ // Set of label ids of visited blocks
+ std::unordered_set<uint32_t> visitedBlocks_;
+
+ // Map from type to undef
+ std::unordered_map<uint32_t, uint32_t> type2undefs_;
+
+ // Variables that are only referenced by supported operations for this
+ // pass ie. loads and stores.
+ std::unordered_set<uint32_t> supported_ref_vars_;
+
+ // named or decorated ids
+ std::unordered_set<uint32_t> named_or_decorated_ids_;
+
+ // Map from block to its structured successor blocks. See
+ // ComputeStructuredSuccessors() for definition.
+ std::unordered_map<const ir::BasicBlock*, std::vector<ir::BasicBlock*>>
+ block2structured_succs_;
+
+ // Map from block's label id to its predecessor blocks ids
+ std::unordered_map<uint32_t, std::vector<uint32_t>> label2preds_;
+
+ // Map from block's label id to a map of a variable to its value at the
+ // end of the block.
+ std::unordered_map<uint32_t, std::unordered_map<uint32_t, uint32_t>>
+ label2ssa_map_;
+
+ // Extra block whose successors are all blocks with no predecessors
+ // in function.
+ ir::BasicBlock pseudo_entry_block_;
+
+ // Extensions supported by this pass.
+ std::unordered_set<std::string> extensions_whitelist_;
+
+ // Next unused ID
+ uint32_t next_id_;
+};
+
+} // namespace opt
+} // namespace spvtools
+
+#endif // LIBSPIRV_OPT_LOCAL_SSA_ELIM_PASS_H_
+
diff --git a/source/opt/module.cpp b/source/opt/module.cpp
index dd08ca0e..290e88d4 100644
--- a/source/opt/module.cpp
+++ b/source/opt/module.cpp
@@ -15,6 +15,7 @@
#include "module.h"
#include <algorithm>
+#include <cstring>
#include "operand.h"
#include "reflect.h"
@@ -150,5 +151,13 @@ bool Module::HasCapability(uint32_t cap) {
return false;
}
+uint32_t Module::GetExtInstImportId(const char* extstr) {
+ for (auto& ei : ext_inst_imports_)
+ if (!strcmp(extstr, reinterpret_cast<const char*>(
+ &ei->GetInOperand(0).words[0])))
+ return ei->result_id();
+ return 0;
+}
+
} // namespace ir
} // namespace spvtools
diff --git a/source/opt/module.h b/source/opt/module.h
index 37d49025..e29615fc 100644
--- a/source/opt/module.h
+++ b/source/opt/module.h
@@ -112,6 +112,10 @@ class Module {
IteratorRange<inst_iterator> annotations();
IteratorRange<const_inst_iterator> annotations() const;
+ // Iterators for extension instructions contained in this module.
+ IteratorRange<inst_iterator> extensions();
+ IteratorRange<const_inst_iterator> extensions() const;
+
// Iterators for types, constants and global variables instructions.
inline inst_iterator types_values_begin();
inline inst_iterator types_values_end();
@@ -141,6 +145,10 @@ class Module {
// Returns true if module has capability |cap|
bool HasCapability(uint32_t cap);
+ // Returns id for OpExtInst instruction for extension |extstr|.
+ // Returns 0 if not found.
+ uint32_t GetExtInstImportId(const char* extstr);
+
private:
ModuleHeader header_; // Module header
@@ -235,6 +243,14 @@ inline IteratorRange<Module::const_inst_iterator> Module::annotations() const {
return make_const_range(annotations_);
}
+inline IteratorRange<Module::inst_iterator> Module::extensions() {
+ return make_range(extensions_);
+}
+
+inline IteratorRange<Module::const_inst_iterator> Module::extensions() const {
+ return make_const_range(extensions_);
+}
+
inline Module::inst_iterator Module::types_values_begin() {
return inst_iterator(&types_values_, types_values_.begin());
}
diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp
index cf9a8251..80d86aeb 100644
--- a/source/opt/optimizer.cpp
+++ b/source/opt/optimizer.cpp
@@ -132,6 +132,11 @@ Optimizer::PassToken CreateEliminateDeadConstantPass() {
MakeUnique<opt::EliminateDeadConstantPass>());
}
+Optimizer::PassToken CreateBlockMergePass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::BlockMergePass>());
+}
+
Optimizer::PassToken CreateInlinePass() {
return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::InlinePass>());
}
@@ -146,6 +151,31 @@ Optimizer::PassToken CreateLocalSingleBlockLoadStoreElimPass() {
MakeUnique<opt::LocalSingleBlockLoadStoreElimPass>());
}
+Optimizer::PassToken CreateLocalSingleStoreElimPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::LocalSingleStoreElimPass>());
+}
+
+Optimizer::PassToken CreateInsertExtractElimPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::InsertExtractElimPass>());
+}
+
+Optimizer::PassToken CreateDeadBranchElimPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::DeadBranchElimPass>());
+}
+
+Optimizer::PassToken CreateLocalMultiStoreElimPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::LocalMultiStoreElimPass>());
+}
+
+Optimizer::PassToken CreateAggressiveDCEPass() {
+ return MakeUnique<Optimizer::PassToken::Impl>(
+ MakeUnique<opt::AggressiveDCEPass>());
+}
+
Optimizer::PassToken CreateCompactIdsPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::CompactIdsPass>());
diff --git a/source/opt/passes.h b/source/opt/passes.h
index 61361a7f..f6d69619 100644
--- a/source/opt/passes.h
+++ b/source/opt/passes.h
@@ -17,14 +17,20 @@
// A single header to include all passes.
+#include "block_merge_pass.h"
#include "compact_ids_pass.h"
+#include "dead_branch_elim_pass.h"
#include "eliminate_dead_constant_pass.h"
#include "flatten_decoration_pass.h"
#include "fold_spec_constant_op_and_composite_pass.h"
#include "inline_pass.h"
+#include "insert_extract_elim.h"
#include "local_single_block_elim_pass.h"
+#include "local_single_store_elim_pass.h"
+#include "local_ssa_elim_pass.h"
#include "freeze_spec_constant_value_pass.h"
#include "local_access_chain_convert_pass.h"
+#include "aggressive_dead_code_elim_pass.h"
#include "null_pass.h"
#include "set_spec_constant_default_value_pass.h"
#include "strip_debug_info_pass.h"
diff --git a/source/text.cpp b/source/text.cpp
index 6a6846ea..8c48814b 100644
--- a/source/text.cpp
+++ b/source/text.cpp
@@ -247,7 +247,7 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar,
spvInstructionAddWord(pInst, extInst->ext_inst);
// Prepare to parse the operands for the extended instructions.
- spvPrependOperandTypes(extInst->operandTypes, pExpectedOperands);
+ spvPushOperandTypes(extInst->operandTypes, pExpectedOperands);
} break;
case SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER: {
@@ -271,7 +271,7 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar,
assert(opcodeEntry->hasType);
assert(opcodeEntry->hasResult);
assert(opcodeEntry->numTypes >= 2);
- spvPrependOperandTypes(opcodeEntry->operandTypes + 2, pExpectedOperands);
+ spvPushOperandTypes(opcodeEntry->operandTypes + 2, pExpectedOperands);
} break;
case SPV_OPERAND_TYPE_LITERAL_INTEGER:
@@ -380,7 +380,7 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar,
}
if (auto error = context->binaryEncodeU32(value, pInst)) return error;
// Prepare to parse the operands for this logical operand.
- grammar.prependOperandTypesForMask(type, value, pExpectedOperands);
+ grammar.pushOperandTypesForMask(type, value, pExpectedOperands);
} break;
case SPV_OPERAND_TYPE_OPTIONAL_CIV: {
auto error = spvTextEncodeOperand(
@@ -420,7 +420,7 @@ spv_result_t spvTextEncodeOperand(const libspirv::AssemblyGrammar& grammar,
}
// Prepare to parse the operands for this logical operand.
- spvPrependOperandTypes(entry->operandTypes, pExpectedOperands);
+ spvPushOperandTypes(entry->operandTypes, pExpectedOperands);
} break;
}
return SPV_SUCCESS;
@@ -553,13 +553,14 @@ spv_result_t spvTextEncodeOpcode(const libspirv::AssemblyGrammar& grammar,
// has its own logical operands (such as the LocalSize operand for
// ExecutionMode), or for extended instructions that may have their
// own operands depending on the selected extended instruction.
- spv_operand_pattern_t expectedOperands(
- opcodeEntry->operandTypes,
- opcodeEntry->operandTypes + opcodeEntry->numTypes);
+ spv_operand_pattern_t expectedOperands;
+ expectedOperands.reserve(opcodeEntry->numTypes);
+ for (auto i = 0; i < opcodeEntry->numTypes; i++)
+ expectedOperands.push_back(opcodeEntry->operandTypes[opcodeEntry->numTypes - i - 1]);
while (!expectedOperands.empty()) {
- const spv_operand_type_t type = expectedOperands.front();
- expectedOperands.pop_front();
+ const spv_operand_type_t type = expectedOperands.back();
+ expectedOperands.pop_back();
// Expand optional tuples lazily.
if (spvExpandOperandSequenceOnce(type, &expectedOperands)) continue;
diff --git a/source/util/bit_stream.cpp b/source/util/bit_stream.cpp
index 5dac5638..bfd6af08 100644
--- a/source/util/bit_stream.cpp
+++ b/source/util/bit_stream.cpp
@@ -34,16 +34,21 @@ bool IsLittleEndian() {
return reinterpret_cast<const unsigned char*>(&kFF00)[0] == 0;
}
-// Copies uint8_t buffer to a uint64_t buffer.
+// Copies bytes from the given buffer to a uint64_t buffer.
// Motivation: casting uint64_t* to uint8_t* is ok. Casting in the other
// direction is only advisable if uint8_t* is aligned to 64-bit word boundary.
-std::vector<uint64_t> ToBuffer64(const std::vector<uint8_t>& in) {
+std::vector<uint64_t> ToBuffer64(const void* buffer, size_t num_bytes) {
std::vector<uint64_t> out;
- out.resize((in.size() + 7) / 8, 0);
- memcpy(out.data(), in.data(), in.size());
+ out.resize((num_bytes + 7) / 8, 0);
+ memcpy(out.data(), buffer, num_bytes);
return out;
}
+// Copies uint8_t buffer to a uint64_t buffer.
+std::vector<uint64_t> ToBuffer64(const std::vector<uint8_t>& in) {
+ return ToBuffer64(in.data(), in.size());
+}
+
// Returns uint64_t containing the same bits as |val|.
// Type size must be less than 8 bytes.
template <typename T>
@@ -78,7 +83,9 @@ void WriteVariableWidthInternal(BitWriterInterface* writer, uint64_t val,
assert(max_payload == 64 || (val >> max_payload) == 0);
if (val == 0) {
- writer->WriteBits(0, chunk_length + 1);
+ // Split in two writes for more readable logging.
+ writer->WriteBits(0, chunk_length);
+ writer->WriteBits(0, 1);
return;
}
@@ -193,6 +200,41 @@ bool ReadVariableWidthSigned(BitReaderInterface* reader, T* val,
} // namespace
+size_t Log2U64(uint64_t val) {
+ size_t res = 0;
+
+ if (val & 0xFFFFFFFF00000000) {
+ val >>= 32;
+ res |= 32;
+ }
+
+ if (val & 0xFFFF0000) {
+ val >>= 16;
+ res |= 16;
+ }
+
+ if (val & 0xFF00) {
+ val >>= 8;
+ res |= 8;
+ }
+
+ if (val & 0xF0) {
+ val >>= 4;
+ res |= 4;
+ }
+
+ if (val & 0xC) {
+ val >>= 2;
+ res |= 2;
+ }
+
+ if (val & 0x2) {
+ res |= 1;
+ }
+
+ return res;
+}
+
void BitWriterInterface::WriteVariableWidthU64(uint64_t val,
size_t chunk_length) {
WriteVariableWidthUnsigned(this, val, chunk_length);
@@ -237,6 +279,16 @@ void BitWriterInterface::WriteVariableWidthS8(int8_t val,
WriteVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
}
+void BitWriterInterface::WriteFixedWidth(uint64_t val, uint64_t max_val) {
+ if (val > max_val) {
+ assert(0 && "WriteFixedWidth: value too wide");
+ return;
+ }
+
+ const size_t num_bits = 1 + Log2U64(max_val);
+ WriteBits(val, num_bits);
+}
+
BitWriterWord64::BitWriterWord64(size_t reserve_bits) : end_(0) {
buffer_.reserve(NumBitsToNumWords<64>(reserve_bits));
}
@@ -250,6 +302,8 @@ void BitWriterWord64::WriteBits(uint64_t bits, size_t num_bits) {
bits = GetLowerBits(bits, num_bits);
+ EmitSequence(bits, num_bits);
+
// Offset from the start of the current word.
const size_t offset = end_ % 64;
@@ -320,12 +374,20 @@ bool BitReaderInterface::ReadVariableWidthS8(int8_t* val,
return ReadVariableWidthSigned(this, val, chunk_length, zigzag_exponent);
}
+bool BitReaderInterface::ReadFixedWidth(uint64_t* val, uint64_t max_val) {
+ const size_t num_bits = 1 + Log2U64(max_val);
+ return ReadBits(val, num_bits) == num_bits;
+}
+
BitReaderWord64::BitReaderWord64(std::vector<uint64_t>&& buffer)
: buffer_(std::move(buffer)), pos_(0) {}
BitReaderWord64::BitReaderWord64(const std::vector<uint8_t>& buffer)
: buffer_(ToBuffer64(buffer)), pos_(0) {}
+BitReaderWord64::BitReaderWord64(const void* buffer, size_t num_bytes)
+ : buffer_(ToBuffer64(buffer, num_bytes)), pos_(0) {}
+
size_t BitReaderWord64::ReadBits(uint64_t* bits, size_t num_bits) {
assert(num_bits <= 64);
const bool is_little_endian = IsLittleEndian();
diff --git a/source/util/bit_stream.h b/source/util/bit_stream.h
index a139b633..b626d2c7 100644
--- a/source/util/bit_stream.h
+++ b/source/util/bit_stream.h
@@ -17,14 +17,19 @@
#ifndef LIBSPIRV_UTIL_BIT_STREAM_H_
#define LIBSPIRV_UTIL_BIT_STREAM_H_
+#include <algorithm>
#include <bitset>
#include <cstdint>
+#include <functional>
#include <string>
#include <sstream>
#include <vector>
namespace spvutils {
+// Returns rounded down log2(val). log2(0) is considered 0.
+size_t Log2U64(uint64_t val);
+
// Terminology:
// Bits - usually used for a uint64 word, first bit is the lowest.
// Stream - std::string of '0' and '1', read left-to-right,
@@ -212,6 +217,16 @@ class BitWriterInterface {
WriteBits(bits.to_ullong(), num_bits);
}
+ // Writes bits from value of type |T| to the stream. No encoding is done.
+ // Always writes 8 * sizeof(T) bits.
+ template <typename T>
+ void WriteUnencoded(T val) {
+ static_assert(sizeof(T) <= 64, "Type size too large");
+ uint64_t bits = 0;
+ memcpy(&bits, &val, sizeof(T));
+ WriteBits(bits, sizeof(T) * 8);
+ }
+
// Writes |val| in chunks of size |chunk_length| followed by a signal bit:
// 0 - no more chunks to follow
// 1 - more chunks to follow
@@ -231,6 +246,18 @@ class BitWriterInterface {
void WriteVariableWidthS8(
int8_t val, size_t chunk_length, size_t zigzag_exponent);
+ // Writes |val| using fixed bit width. Bit width is determined by |max_val|:
+ // max_val 0 -> bit width 1
+ // max_val 1 -> bit width 1
+ // max_val 2 -> bit width 2
+ // max_val 3 -> bit width 2
+ // max_val 4 -> bit width 3
+ // max_val 5 -> bit width 3
+ // max_val 8 -> bit width 4
+ // max_val n -> bit width 1 + floor(log2(n))
+ // |val| needs to be <= |max_val|.
+ void WriteFixedWidth(uint64_t val, uint64_t max_val);
+
// Returns number of bits written.
virtual size_t GetNumBits() const = 0;
@@ -277,10 +304,26 @@ class BitWriterWord64 : public BitWriterInterface {
return BufferToStream(buffer_);
}
+ // Sets callback to emit bit sequences after every write.
+ void SetCallback(std::function<void(const std::string&)> callback) {
+ callback_ = callback;
+ }
+
+ protected:
+ // Sends string generated from arguments to callback_ if defined.
+ void EmitSequence(uint64_t bits, size_t num_bits) const {
+ if (callback_)
+ callback_(BitsToStream(bits, num_bits));
+ }
+
private:
std::vector<uint64_t> buffer_;
// Total number of bits written so far. Named 'end' as analogy to std::end().
size_t end_;
+
+ // If not null, the writer will use the callback to emit the written bit
+ // sequence as a string of '0' and '1'.
+ std::function<void(const std::string&)> callback_;
};
// Base class for reading sequences of bits.
@@ -314,6 +357,21 @@ class BitReaderInterface {
return BitsToStream(bits, num_read);
}
+ // Reads 8 * sizeof(T) bits and stores them in |val|.
+ template <typename T>
+ bool ReadUnencoded(T* val) {
+ static_assert(sizeof(T) <= 64, "Type size too large");
+ uint64_t bits = 0;
+ const size_t num_read = ReadBits(&bits, sizeof(T) * 8);
+ if (num_read != sizeof(T) * 8)
+ return false;
+ memcpy(val, &bits, sizeof(T));
+ return true;
+ }
+
+ // Returns number of bits already read.
+ virtual size_t GetNumReadBits() const = 0;
+
// These two functions define 'hard' and 'soft' EOF.
//
// Returns true if the end of the buffer was reached.
@@ -346,6 +404,10 @@ class BitReaderInterface {
bool ReadVariableWidthS8(
int8_t* val, size_t chunk_length, size_t zigzag_exponent);
+ // Reads value written by WriteFixedWidth (|max_val| needs to be the same).
+ // Returns true on success, false if the bit stream ends prematurely.
+ bool ReadFixedWidth(uint64_t* val, uint64_t max_val);
+
BitReaderInterface(const BitReaderInterface&) = delete;
BitReaderInterface& operator=(const BitReaderInterface&) = delete;
};
@@ -362,8 +424,14 @@ class BitReaderWord64 : public BitReaderInterface {
// Consuming the original buffer and casting it to uint64 is difficult,
// as it would potentially cause data misalignment and poor performance.
explicit BitReaderWord64(const std::vector<uint8_t>& buffer);
+ BitReaderWord64(const void* buffer, size_t num_bytes);
size_t ReadBits(uint64_t* bits, size_t num_bits) override;
+
+ size_t GetNumReadBits() const override {
+ return pos_;
+ }
+
bool ReachedEnd() const override;
bool OnlyZeroesLeft() const override;
diff --git a/source/util/huffman_codec.h b/source/util/huffman_codec.h
new file mode 100644
index 00000000..2e74d6b8
--- /dev/null
+++ b/source/util/huffman_codec.h
@@ -0,0 +1,299 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Contains utils for reading, writing and debug printing bit streams.
+
+#ifndef LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
+#define LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
+
+#include <algorithm>
+#include <cassert>
+#include <functional>
+#include <queue>
+#include <iomanip>
+#include <map>
+#include <memory>
+#include <ostream>
+#include <sstream>
+#include <stack>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+namespace spvutils {
+
+// Used to generate and apply a Huffman coding scheme.
+// |Val| is the type of variable being encoded (for example a string or a
+// literal).
+template <class Val>
+class HuffmanCodec {
+ struct Node;
+
+ public:
+ // Creates Huffman codec from a histogramm.
+ // Histogramm counts must not be zero.
+ explicit HuffmanCodec(const std::map<Val, uint32_t>& hist) {
+ if (hist.empty()) return;
+
+ // Heuristic estimate.
+ all_nodes_.reserve(3 * hist.size());
+
+ // The queue is sorted in ascending order by weight (or by node id if
+ // weights are equal).
+ std::vector<Node*> queue_vector;
+ queue_vector.reserve(hist.size());
+ std::priority_queue<Node*, std::vector<Node*>,
+ std::function<bool(const Node*, const Node*)>>
+ queue(LeftIsBigger, std::move(queue_vector));
+
+ // Put all leaves in the queue.
+ for (const auto& pair : hist) {
+ Node* node = CreateNode();
+ node->val = pair.first;
+ node->weight = pair.second;
+ assert(node->weight);
+ queue.push(node);
+ }
+
+ // Form the tree by combining two subtrees with the least weight,
+ // and pushing the root of the new tree in the queue.
+ while (true) {
+ // We push a node at the end of each iteration, so the queue is never
+ // supposed to be empty at this point, unless there are no leaves, but
+ // that case was already handled.
+ assert(!queue.empty());
+ Node* right = queue.top();
+ queue.pop();
+
+ // If the queue is empty at this point, then the last node is
+ // the root of the complete Huffman tree.
+ if (queue.empty()) {
+ root_ = right;
+ break;
+ }
+
+ Node* left = queue.top();
+ queue.pop();
+
+ // Combine left and right into a new tree and push it into the queue.
+ Node* parent = CreateNode();
+ parent->weight = right->weight + left->weight;
+ parent->left = left;
+ parent->right = right;
+ queue.push(parent);
+ }
+
+ // Traverse the tree and form encoding table.
+ CreateEncodingTable();
+ }
+
+ // Prints the Huffman tree in the following format:
+ // w------w------'x'
+ // w------'y'
+ // Where w stands for the weight of the node.
+ // Right tree branches appear above left branches. Taking the right path
+ // adds 1 to the code, taking the left adds 0.
+ void PrintTree(std::ostream& out) {
+ PrintTreeInternal(out, root_, 0);
+ }
+
+ // Traverses the tree and prints the Huffman table: value, code
+ // and optionally node weight for every leaf.
+ void PrintTable(std::ostream& out, bool print_weights = true) {
+ std::queue<std::pair<Node*, std::string>> queue;
+ queue.emplace(root_, "");
+
+ while (!queue.empty()) {
+ const Node* node = queue.front().first;
+ const std::string code = queue.front().second;
+ queue.pop();
+ if (!node->right && !node->left) {
+ out << node->val;
+ if (print_weights)
+ out << " " << node->weight;
+ out << " " << code << std::endl;
+ } else {
+ if (node->left)
+ queue.emplace(node->left, code + "0");
+
+ if (node->right)
+ queue.emplace(node->right, code + "1");
+ }
+ }
+ }
+
+ // Returns the Huffman table. The table was built at at construction time,
+ // this function just returns a const reference.
+ const std::unordered_map<Val, std::pair<uint64_t, size_t>>&
+ GetEncodingTable() const {
+ return encoding_table_;
+ }
+
+ // Encodes |val| and stores its Huffman code in the lower |num_bits| of
+ // |bits|. Returns false of |val| is not in the Huffman table.
+ bool Encode(const Val& val, uint64_t* bits, size_t* num_bits) {
+ auto it = encoding_table_.find(val);
+ if (it == encoding_table_.end())
+ return false;
+ *bits = it->second.first;
+ *num_bits = it->second.second;
+ return true;
+ }
+
+ // Reads bits one-by-one using callback |read_bit| until a match is found.
+ // Matching value is stored in |val|. Returns false if |read_bit| terminates
+ // before a code was mathced.
+ // |read_bit| has type bool func(bool* bit). When called, the next bit is
+ // stored in |bit|. |read_bit| returns false if the stream terminates
+ // prematurely.
+ bool DecodeFromStream(const std::function<bool(bool*)>& read_bit, Val* val) {
+ Node* node = root_;
+ while (true) {
+ assert(node);
+
+ if (node->left == nullptr && node->right == nullptr) {
+ *val = node->val;
+ return true;
+ }
+
+ bool go_right;
+ if (!read_bit(&go_right))
+ return false;
+
+ if (go_right)
+ node = node->right;
+ else
+ node = node->left;
+ }
+
+ assert (0);
+ return false;
+ }
+
+ private:
+ // Huffman tree node.
+ struct Node {
+ Val val = Val();
+ uint32_t weight = 0;
+ // Ids are issued sequentially starting from 1. Ids are used as an ordering
+ // tie-breaker, to make sure that the ordering (and resulting coding scheme)
+ // are consistent accross multiple platforms.
+ uint32_t id = 0;
+ Node* left = nullptr;
+ Node* right = nullptr;
+ };
+
+ // Returns true if |left| has bigger weight than |right|. Node ids are
+ // used as tie-breaker.
+ static bool LeftIsBigger(const Node* left, const Node* right) {
+ if (left->weight == right->weight) {
+ assert (left->id != right->id);
+ return left->id > right->id;
+ }
+ return left->weight > right->weight;
+ }
+
+ // Prints subtree (helper function used by PrintTree).
+ static void PrintTreeInternal(std::ostream& out, Node* node, size_t depth) {
+ if (!node)
+ return;
+
+ const size_t kTextFieldWidth = 7;
+
+ if (!node->right && !node->left) {
+ out << node->val << std::endl;
+ } else {
+ if (node->right) {
+ std::stringstream label;
+ label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
+ << node->right->weight;
+ out << label.str();
+ PrintTreeInternal(out, node->right, depth + 1);
+ }
+
+ if (node->left) {
+ out << std::string(depth * kTextFieldWidth, ' ');
+ std::stringstream label;
+ label << std::setfill('-') << std::left << std::setw(kTextFieldWidth)
+ << node->left->weight;
+ out << label.str();
+ PrintTreeInternal(out, node->left, depth + 1);
+ }
+ }
+ }
+
+ // Traverses the Huffman tree and saves paths to the leaves as bit
+ // sequences to encoding_table_.
+ void CreateEncodingTable() {
+ struct Context {
+ Context(Node* in_node, uint64_t in_bits, size_t in_depth)
+ : node(in_node), bits(in_bits), depth(in_depth) {}
+ Node* node;
+ // Huffman tree depth cannot exceed 64 as histogramm counts are expected
+ // to be positive and limited by numeric_limits<uint32_t>::max().
+ // For practical applications tree depth would be much smaller than 64.
+ uint64_t bits;
+ size_t depth;
+ };
+
+ std::queue<Context> queue;
+ queue.emplace(root_, 0, 0);
+
+ while (!queue.empty()) {
+ const Context& context = queue.front();
+ const Node* node = context.node;
+ const uint64_t bits = context.bits;
+ const size_t depth = context.depth;
+ queue.pop();
+
+ if (!node->right && !node->left) {
+ auto insertion_result = encoding_table_.emplace(
+ node->val, std::pair<uint64_t, size_t>(bits, depth));
+ assert(insertion_result.second);
+ (void)insertion_result;
+ } else {
+ if (node->left)
+ queue.emplace(node->left, bits, depth + 1);
+
+ if (node->right)
+ queue.emplace(node->right, bits | (1ULL << depth), depth + 1);
+ }
+ }
+ }
+
+ // Creates new Huffman tree node and stores it in the deleter array.
+ Node* CreateNode() {
+ all_nodes_.emplace_back(new Node());
+ all_nodes_.back()->id = next_node_id_++;
+ return all_nodes_.back().get();
+ }
+
+ // Huffman tree root.
+ Node* root_ = nullptr;
+
+ // Huffman tree deleter.
+ std::vector<std::unique_ptr<Node>> all_nodes_;
+
+ // Encoding table value -> {bits, num_bits}.
+ // Huffman codes are expected to never exceed 64 bit length (this is in fact
+ // impossible if frequencies are stored as uint32_t).
+ std::unordered_map<Val, std::pair<uint64_t, size_t>> encoding_table_;
+
+ // Next node id issued by CreateNode();
+ uint32_t next_node_id_ = 1;
+};
+
+} // namespace spvutils
+
+#endif // LIBSPIRV_UTIL_HUFFMAN_CODEC_H_
diff --git a/source/util/move_to_front.h b/source/util/move_to_front.h
new file mode 100644
index 00000000..dc1430f1
--- /dev/null
+++ b/source/util/move_to_front.h
@@ -0,0 +1,649 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef LIBSPIRV_UTIL_MOVE_TO_FRONT_H_
+#define LIBSPIRV_UTIL_MOVE_TO_FRONT_H_
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <iomanip>
+#include <iostream>
+#include <ostream>
+#include <sstream>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace spvutils {
+
+// Log(n) move-to-front implementation. Implements the following functions:
+// Insert - pushes value to the front of the mtf sequence
+// (only unique values allowed).
+// Remove - remove value from the sequence.
+// ValueFromRank - access value by its 1-indexed rank in the sequence.
+// RankFromValue - get the rank of the given value in the sequence.
+// Accessing a value with ValueFromRank or RankFromValue moves the value to the
+// front of the sequence (rank of 1).
+//
+// The implementation is based on an AVL-based order statistic tree. The tree
+// is ordered by timestamps issued when values are inserted or accessed (recent
+// values go to the left side of the tree, old values are gradually rotated to
+// the right side).
+//
+// Terminology
+// rank: 1-indexed rank showing how recently the value was inserted or accessed.
+// node: handle used internally to access node data.
+// size: size of the subtree of a node (including the node).
+// height: distance from a node to the farthest leaf.
template <typename Val>
class MoveToFront {
 public:
  // |reserve_capacity| is a hint for the expected number of unique values;
  // node storage is reserved upfront to reduce reallocation.
  explicit MoveToFront(size_t reserve_capacity = 128) {
    nodes_.reserve(reserve_capacity);

    // Create NIL node (handle 0), shared sentinel for all leaves and the
    // root's parent.
    nodes_.emplace_back(Node());
  }

  virtual ~MoveToFront() {}

  // Inserts value in the move-to-front sequence. Does nothing if the value is
  // already in the sequence. Returns true if insertion was successful.
  // The inserted value is placed at the front of the sequence (rank 1).
  bool Insert(const Val& value);

  // Removes value from move-to-front sequence. Returns false iff the value
  // was not found.
  bool Remove(const Val& value);

  // Computes 1-indexed rank of value in the move-to-front sequence and moves
  // the value to the front. Example:
  // Before the call: 4 8 2 1 7
  // RankFromValue(8) returns 1
  // After the call: 8 4 2 1 7
  // Returns true iff the value was found in the sequence.
  bool RankFromValue(const Val& value, size_t* rank);

  // Returns value corresponding to a 1-indexed rank in the move-to-front
  // sequence and moves the value to the front. Example:
  // Before the call: 4 8 2 1 7
  // ValueFromRank(1) returns 8
  // After the call: 8 4 2 1 7
  // Returns true iff the rank is within bounds [1, GetSize()].
  bool ValueFromRank(size_t rank, Val* value);

  // Returns the number of elements in the move-to-front sequence.
  // The NIL node (handle 0) has size 0, so an empty tree reports 0.
  size_t GetSize() const {
    return SizeOf(root_);
  }

 protected:
  // Internal tree data structure uses handles instead of pointers. Leaves and
  // root parent reference a singleton under handle 0. Although dereferencing
  // a null pointer is not possible, inappropriate access to handle 0 would
  // cause an assertion. Handles are not garbage collected when a value is
  // removed with Remove(); the orphaned node is kept and recycled if the same
  // value is inserted again (see CreateNode()).

  // Internal tree data structure node.
  struct Node {
    // Timestamp from a logical clock which updates every time the element is
    // accessed through ValueFromRank or RankFromValue.
    uint32_t timestamp = 0;
    // The size of the node's subtree, including the node.
    // SizeOf(LeftOf(node)) + SizeOf(RightOf(node)) + 1.
    uint32_t size = 0;
    // Handles to connected nodes.
    uint32_t left = 0;
    uint32_t right = 0;
    uint32_t parent = 0;
    // Distance to the farthest leaf.
    // Leaves have height 0, real nodes at least 1.
    uint32_t height = 0;
    // Stored value.
    Val value = Val();
  };

  // Creates node and sets correct values. Non-NIL nodes should be created only
  // through this function. If the node with this value has been created previously
  // and since orphaned, reuses the old node instead of creating a new one.
  uint32_t CreateNode(uint32_t timestamp, const Val& value) {
    uint32_t handle = static_cast<uint32_t>(nodes_.size());
    const auto result = value_to_node_.emplace(value, handle);
    if (result.second) {
      // Create new node.
      nodes_.emplace_back(Node());
      Node& node = nodes_.back();
      node.timestamp = timestamp;
      node.value = value;
      node.size = 1;
      // Non-NIL nodes start with height 1 because their NIL children are leaves.
      node.height = 1;
    } else {
      // Reuse old node. An orphaned node keeps size 1 / height 1 because
      // RemoveNode() resets its links and calls UpdateNode().
      handle = result.first->second;
      assert(!IsInTree(handle));
      assert(ValueOf(handle) == value);
      assert(SizeOf(handle) == 1);
      assert(HeightOf(handle) == 1);
      MutableTimestampOf(handle) = timestamp;
    }

    return handle;
  }

  // Node accessor methods. Naming is designed to be similar to natural
  // language as these functions tend to be used in sequences, for example:
  // ParentOf(LeftestDescendantOf(RightOf(node)))

  // Returns value of the node referenced by |handle|.
  Val ValueOf(uint32_t node) const {
    return nodes_.at(node).value;
  }

  // Returns left child of |node|.
  uint32_t LeftOf(uint32_t node) const {
    return nodes_.at(node).left;
  }

  // Returns right child of |node|.
  uint32_t RightOf(uint32_t node) const {
    return nodes_.at(node).right;
  }

  // Returns parent of |node|.
  uint32_t ParentOf(uint32_t node) const {
    return nodes_.at(node).parent;
  }

  // Returns timestamp of |node|. The NIL node has no timestamp.
  uint32_t TimestampOf(uint32_t node) const {
    assert(node);
    return nodes_.at(node).timestamp;
  }

  // Returns size of |node|. Accepts the NIL node (size 0).
  uint32_t SizeOf(uint32_t node) const {
    return nodes_.at(node).size;
  }

  // Returns height of |node|. Accepts the NIL node (height 0).
  uint32_t HeightOf(uint32_t node) const {
    return nodes_.at(node).height;
  }

  // Returns mutable reference to value of |node|.
  Val& MutableValueOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).value;
  }

  // Returns mutable reference to handle of left child of |node|.
  uint32_t& MutableLeftOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).left;
  }

  // Returns mutable reference to handle of right child of |node|.
  uint32_t& MutableRightOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).right;
  }

  // Returns mutable reference to handle of parent of |node|.
  uint32_t& MutableParentOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).parent;
  }

  // Returns mutable reference to timestamp of |node|.
  uint32_t& MutableTimestampOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).timestamp;
  }

  // Returns mutable reference to size of |node|.
  uint32_t& MutableSizeOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).size;
  }

  // Returns mutable reference to height of |node|.
  uint32_t& MutableHeightOf(uint32_t node) {
    assert(node);
    return nodes_.at(node).height;
  }

  // Returns true iff |node| is left child of its parent.
  // For the root, ParentOf is NIL whose children are both NIL, so this
  // returns false.
  bool IsLeftChild(uint32_t node) const {
    assert(node);
    return LeftOf(ParentOf(node)) == node;
  }

  // Returns true iff |node| is right child of its parent.
  bool IsRightChild(uint32_t node) const {
    assert(node);
    return RightOf(ParentOf(node)) == node;
  }

  // Returns true iff |node| has no relatives.
  bool IsOrphan(uint32_t node) const {
    assert(node);
    return !ParentOf(node) && !LeftOf(node) && !RightOf(node);
  }

  // Returns true iff |node| is in the tree.
  // The root has no parent, hence the explicit check.
  bool IsInTree(uint32_t node) const {
    assert(node);
    return node == root_ || !IsOrphan(node);
  }

  // Returns the height difference between right and left subtrees.
  // AVL invariant: the result stays within [-1, 1] for every tree node.
  int BalanceOf(uint32_t node) const {
    return int(HeightOf(RightOf(node))) - int(HeightOf(LeftOf(node)));
  }

  // Updates size and height of the node, assuming that the children have
  // correct values.
  void UpdateNode(uint32_t node);

  // Returns the most LeftOf(LeftOf(... descendant which is not leaf.
  uint32_t LeftestDescendantOf(uint32_t node) const {
    uint32_t parent = 0;
    while (node) {
      parent = node;
      node = LeftOf(node);
    }
    return parent;
  }

  // Returns the most RightOf(RightOf(... descendant which is not leaf.
  uint32_t RightestDescendantOf(uint32_t node) const {
    uint32_t parent = 0;
    while (node) {
      parent = node;
      node = RightOf(node);
    }
    return parent;
  }

  // Inserts node in the tree. The node must be an orphan.
  void InsertNode(uint32_t node);

  // Removes node from the tree. May change value_to_node_ if removal uses a
  // scapegoat. Returns the removed (orphaned) handle for recycling. The
  // returned handle may not be equal to |node| if scapegoat was used.
  uint32_t RemoveNode(uint32_t node);

  // Rotates |node| left, reassigns all connections and returns the node
  // which takes place of the |node|.
  uint32_t RotateLeft(const uint32_t node);

  // Rotates |node| right, reassigns all connections and returns the node
  // which takes place of the |node|.
  uint32_t RotateRight(const uint32_t node);

  // Root node handle. The tree is empty if root_ is 0.
  uint32_t root_ = 0;

  // Logical clock: the timestamp to be issued to the next inserted or
  // accessed node.
  uint32_t next_timestamp_ = 1;

  // Holds all tree nodes. Indices of this vector are node handles.
  std::vector<Node> nodes_;

  // Maps values to node handles.
  std::unordered_map<Val, uint32_t> value_to_node_;
};
+
+template <typename Val>
+bool MoveToFront<Val>::Insert(const Val& value) {
+ auto it = value_to_node_.find(value);
+ if (it != value_to_node_.end() && IsInTree(it->second))
+ return false;
+
+ const size_t old_size = GetSize();
+ (void)old_size;
+
+ InsertNode(CreateNode(next_timestamp_++, value));
+
+ assert(value_to_node_.count(value));
+ assert(old_size + 1 == GetSize());
+ return true;
+}
+
+template <typename Val>
+bool MoveToFront<Val>::Remove(const Val& value) {
+ auto it = value_to_node_.find(value);
+ if (it == value_to_node_.end())
+ return false;
+
+ if (!IsInTree(it->second))
+ return false;
+
+ const uint32_t orphan = RemoveNode(it->second);
+ (void)orphan;
+ // The node of |value| is still alive but it's orphaned now. Can still be
+ // reused later.
+ assert(!IsInTree(orphan));
+ assert(ValueOf(orphan) == value);
+ return true;
+}
+
+template <typename Val>
+bool MoveToFront<Val>::RankFromValue(const Val& value, size_t* rank) {
+ const size_t old_size = GetSize();
+ (void)old_size;
+ const auto it = value_to_node_.find(value);
+ if (it == value_to_node_.end()) {
+ return false;
+ }
+
+ uint32_t target = it->second;
+
+ if (!IsInTree(target)) {
+ return false;
+ }
+
+ uint32_t node = target;
+ *rank = 1 + SizeOf(LeftOf(node));
+ while (node) {
+ if (IsRightChild(node))
+ *rank += 1 + SizeOf(LeftOf(ParentOf(node)));
+ node = ParentOf(node);
+ }
+
+ // Update timestamp and reposition the node.
+ target = RemoveNode(target);
+ assert(ValueOf(target) == value);
+ assert(old_size == GetSize() + 1);
+ MutableTimestampOf(target) = next_timestamp_++;
+ InsertNode(target);
+ assert(old_size == GetSize());
+ return true;
+}
+
+template <typename Val>
+bool MoveToFront<Val>::ValueFromRank(size_t rank, Val* value) {
+ const size_t old_size = GetSize();
+ if (rank <= 0 || rank > old_size) {
+ return false;
+ }
+
+ uint32_t node = root_;
+ while (node) {
+ const size_t left_subtree_num_nodes = SizeOf(LeftOf(node));
+ if (rank == left_subtree_num_nodes + 1) {
+ // This is the node we are looking for.
+ node = RemoveNode(node);
+ assert(old_size == GetSize() + 1);
+ MutableTimestampOf(node) = next_timestamp_++;
+ InsertNode(node);
+ assert(old_size == GetSize());
+ *value = ValueOf(node);
+ return true;
+ }
+
+ if (rank < left_subtree_num_nodes + 1) {
+ // Descend into the left subtree. The rank is still valid.
+ node = LeftOf(node);
+ } else {
+ // Descend into the right subtree. We leave behind the left subtree and
+ // the current node, adjust the |rank| accordingly.
+ rank -= left_subtree_num_nodes + 1;
+ node = RightOf(node);
+ }
+ }
+
+ assert(0);
+ return false;
+}
+
+template <typename Val>
+void MoveToFront<Val>::InsertNode(uint32_t node) {
+ assert(!IsInTree(node));
+ assert(SizeOf(node) == 1);
+ assert(HeightOf(node) == 1);
+ assert(TimestampOf(node));
+
+ if (!root_) {
+ root_ = node;
+ return;
+ }
+
+ uint32_t iter = root_;
+ uint32_t parent = 0;
+
+ // Will determine if |node| will become the right or left child after
+ // insertion (but before balancing).
+ bool right_child;
+
+ // Find the node which will become |node|'s parent after insertion
+ // (but before balancing).
+ while (iter) {
+ parent = iter;
+ assert(TimestampOf(iter) != TimestampOf(node));
+ right_child = TimestampOf(iter) > TimestampOf(node);
+ iter = right_child ? RightOf(iter) : LeftOf(iter);
+ }
+
+ assert(parent);
+
+ // Connect node and parent.
+ MutableParentOf(node) = parent;
+ if (right_child)
+ MutableRightOf(parent) = node;
+ else
+ MutableLeftOf(parent) = node;
+
+ // Insertion is finished. Start the balancing process.
+ bool needs_rebalancing = true;
+ parent = ParentOf(node);
+
+ while (parent) {
+ UpdateNode(parent);
+
+ if (needs_rebalancing) {
+ const int parent_balance = BalanceOf(parent);
+
+ if (RightOf(parent) == node) {
+ // Added node to the right subtree.
+ if (parent_balance > 1) {
+ // Parent is right heavy, rotate left.
+ if (BalanceOf(node) < 0)
+ RotateRight(node);
+ parent = RotateLeft(parent);
+ } else if (parent_balance == 0 || parent_balance == -1) {
+ // Parent is balanced or left heavy, no need to balance further.
+ needs_rebalancing = false;
+ }
+ } else {
+ // Added node to the left subtree.
+ if (parent_balance < -1) {
+ // Parent is left heavy, rotate right.
+ if (BalanceOf(node) > 0)
+ RotateLeft(node);
+ parent = RotateRight(parent);
+ } else if (parent_balance == 0 || parent_balance == 1) {
+ // Parent is balanced or right heavy, no need to balance further.
+ needs_rebalancing = false;
+ }
+ }
+ }
+
+ assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1));
+
+ node = parent;
+ parent = ParentOf(parent);
+ }
+}
+
// Detaches |node| from the tree and rebalances the path back to the root.
// Returns the handle actually orphaned (the scapegoat, if one was used).
template <typename Val>
uint32_t MoveToFront<Val>::RemoveNode(uint32_t node) {
  if (LeftOf(node) && RightOf(node)) {
    // If |node| has two children, then use another node as scapegoat and swap
    // their contents. We pick the scapegoat on the side of the tree which has more nodes.
    // The scapegoat is an extreme descendant and thus has at most one child,
    // so the single-child removal below applies to it.
    const uint32_t scapegoat = SizeOf(LeftOf(node)) >= SizeOf(RightOf(node)) ?
        RightestDescendantOf(LeftOf(node)) : LeftestDescendantOf(RightOf(node));
    assert(scapegoat);
    std::swap(MutableValueOf(node), MutableValueOf(scapegoat));
    std::swap(MutableTimestampOf(node), MutableTimestampOf(scapegoat));
    // The values changed owners, so the value->handle map must follow.
    value_to_node_[ValueOf(node)] = node;
    value_to_node_[ValueOf(scapegoat)] = scapegoat;
    node = scapegoat;
  }

  // |node| may have only one child at this point.
  assert(!RightOf(node) || !LeftOf(node));

  uint32_t parent = ParentOf(node);
  uint32_t child = RightOf(node) ? RightOf(node) : LeftOf(node);

  // Orphan |node| and reconnect parent and child.
  if (child)
    MutableParentOf(child) = parent;

  if (parent) {
    if (LeftOf(parent) == node)
      MutableLeftOf(parent) = child;
    else
      MutableRightOf(parent) = child;
  }

  // Reset the orphan's bookkeeping to size 1 / height 1 so that CreateNode()
  // can recycle it later.
  MutableParentOf(node) = 0;
  MutableLeftOf(node) = 0;
  MutableRightOf(node) = 0;
  UpdateNode(node);
  const uint32_t orphan = node;

  if (root_ == node)
    root_ = child;

  // Removal is finished. Start the balancing process.
  bool needs_rebalancing = true;
  node = child;

  while (parent) {
    UpdateNode(parent);

    if (needs_rebalancing) {
      const int parent_balance = BalanceOf(parent);

      if (parent_balance == 1 || parent_balance == -1) {
        // The height of the subtree was not changed.
        needs_rebalancing = false;
      } else {
        if (RightOf(parent) == node) {
          // Removed node from the right subtree.
          if (parent_balance < -1) {
            // Parent is left heavy, rotate right.
            const uint32_t sibling = LeftOf(parent);
            if (BalanceOf(sibling) > 0)
              RotateLeft(sibling);  // Double rotation (left-right case).
            parent = RotateRight(parent);
          }
        } else {
          // Removed node from the left subtree.
          if (parent_balance > 1) {
            // Parent is right heavy, rotate left.
            const uint32_t sibling = RightOf(parent);
            if (BalanceOf(sibling) < 0)
              RotateRight(sibling);  // Double rotation (right-left case).
            parent = RotateLeft(parent);
          }
        }
      }
    }

    assert(BalanceOf(parent) >= -1 && (BalanceOf(parent) <= 1));

    node = parent;
    parent = ParentOf(parent);
  }

  return orphan;
}
+
+template <typename Val>
+uint32_t MoveToFront<Val>::RotateLeft(const uint32_t node) {
+ const uint32_t pivot = RightOf(node);
+ assert(pivot);
+
+ // LeftOf(pivot) gets attached to node in place of pivot.
+ MutableRightOf(node) = LeftOf(pivot);
+ if (RightOf(node))
+ MutableParentOf(RightOf(node)) = node;
+
+ // Pivot gets attached to ParentOf(node) in place of node.
+ MutableParentOf(pivot) = ParentOf(node);
+ if (!ParentOf(node))
+ root_ = pivot;
+ else if (IsLeftChild(node))
+ MutableLeftOf(ParentOf(node)) = pivot;
+ else
+ MutableRightOf(ParentOf(node)) = pivot;
+
+ // Node is child of pivot.
+ MutableLeftOf(pivot) = node;
+ MutableParentOf(node) = pivot;
+
+ // Update both node and pivot. Pivot is the new parent of node, so node should
+ // be updated first.
+ UpdateNode(node);
+ UpdateNode(pivot);
+
+ return pivot;
+}
+
+template <typename Val>
+uint32_t MoveToFront<Val>::RotateRight(const uint32_t node) {
+ const uint32_t pivot = LeftOf(node);
+ assert(pivot);
+
+ // RightOf(pivot) gets attached to node in place of pivot.
+ MutableLeftOf(node) = RightOf(pivot);
+ if (LeftOf(node))
+ MutableParentOf(LeftOf(node)) = node;
+
+ // Pivot gets attached to ParentOf(node) in place of node.
+ MutableParentOf(pivot) = ParentOf(node);
+ if (!ParentOf(node))
+ root_ = pivot;
+ else if (IsLeftChild(node))
+ MutableLeftOf(ParentOf(node)) = pivot;
+ else
+ MutableRightOf(ParentOf(node)) = pivot;
+
+ // Node is child of pivot.
+ MutableRightOf(pivot) = node;
+ MutableParentOf(node) = pivot;
+
+ // Update both node and pivot. Pivot is the new parent of node, so node should
+ // be updated first.
+ UpdateNode(node);
+ UpdateNode(pivot);
+
+ return pivot;
+}
+
+template <typename Val>
+void MoveToFront<Val>::UpdateNode(uint32_t node) {
+ MutableSizeOf(node) = 1 + SizeOf(LeftOf(node)) + SizeOf(RightOf(node));
+ MutableHeightOf(node) =
+ 1 + std::max(HeightOf(LeftOf(node)), HeightOf(RightOf(node)));
+}
+
+} // namespace spvutils
+
+#endif // LIBSPIRV_UTIL_MOVE_TO_FRONT_H_
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 350b09df..850f15ee 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -209,25 +209,17 @@ bool ValidationState_t::IsDefinedId(uint32_t id) const {
}
// Returns the instruction defining |id|, or nullptr if |id| has no
// definition. Single map lookup; safe in a const context.
const Instruction* ValidationState_t::FindDef(uint32_t id) const {
  const auto it = all_definitions_.find(id);
  return it == all_definitions_.end() ? nullptr : it->second;
}
// Mutable counterpart of the const FindDef overload: returns the instruction
// defining |id|, or nullptr if |id| has no definition.
Instruction* ValidationState_t::FindDef(uint32_t id) {
  const auto it = all_definitions_.find(id);
  return it == all_definitions_.end() ? nullptr : it->second;
}
// Increments the instruction count. Used for diagnostic
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 87a80ce4..d94093b1 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -268,7 +268,7 @@ class ValidationState_t {
Instruction* FindDef(uint32_t id);
/// Returns a deque of instructions in the order they appear in the binary
- const std::deque<Instruction>& ordered_instructions() {
+ const std::deque<Instruction>& ordered_instructions() const {
return ordered_instructions_;
}
diff --git a/source/validate.cpp b/source/validate.cpp
index 03bc6d9a..ad0fbb98 100644
--- a/source/validate.cpp
+++ b/source/validate.cpp
@@ -317,6 +317,8 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
// NOTE: Copy each instruction for easier processing
std::vector<spv_instruction_t> instructions;
+ // Expect average instruction length to be a bit over 2 words.
+ instructions.reserve(binary->wordCount / 2);
uint64_t index = SPV_INDEX_INSTRUCTION;
while (index < binary->wordCount) {
uint16_t wordCount;
@@ -326,7 +328,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
spv_instruction_t inst;
spvInstructionCopy(&binary->code[index], static_cast<SpvOp>(opcode),
wordCount, endian, &inst);
- instructions.push_back(inst);
+ instructions.emplace_back(std::move(inst));
index += wordCount;
}
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 926dadef..8ebcd93b 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -169,6 +169,17 @@ add_spvtools_unittest(
SRCS bit_stream.cpp
LIBS ${SPIRV_TOOLS})
+add_spvtools_unittest(
+ TARGET huffman_codec
+ SRCS huffman_codec.cpp
+ LIBS ${SPIRV_TOOLS})
+
+add_spvtools_unittest(
+ TARGET move_to_front
+ SRCS move_to_front_test.cpp
+ LIBS ${SPIRV_TOOLS})
+
+add_subdirectory(comp)
add_subdirectory(opt)
-add_subdirectory(val)
add_subdirectory(stats)
+add_subdirectory(val)
diff --git a/test/bit_stream.cpp b/test/bit_stream.cpp
index 8deeb4e8..3fcdc147 100644
--- a/test/bit_stream.cpp
+++ b/test/bit_stream.cpp
@@ -36,6 +36,7 @@ using spvutils::StreamToBits;
using spvutils::GetLowerBits;
using spvutils::EncodeZigZag;
using spvutils::DecodeZigZag;
+using spvutils::Log2U64;
// A simple and inefficient implementatition of BitWriterInterface,
// using std::stringstream. Intended for tests only.
@@ -88,6 +89,10 @@ class BitReaderFromString : public BitReaderInterface {
return sub.length();
}
+ size_t GetNumReadBits() const override {
+ return pos_;
+ }
+
bool ReachedEnd() const override {
return pos_ >= str_.length();
}
@@ -101,6 +106,45 @@ class BitReaderFromString : public BitReaderInterface {
size_t pos_;
};
+TEST(Log2U16, Test) {
+ EXPECT_EQ(0u, Log2U64(0));
+ EXPECT_EQ(0u, Log2U64(1));
+ EXPECT_EQ(1u, Log2U64(2));
+ EXPECT_EQ(1u, Log2U64(3));
+ EXPECT_EQ(2u, Log2U64(4));
+ EXPECT_EQ(2u, Log2U64(5));
+ EXPECT_EQ(2u, Log2U64(6));
+ EXPECT_EQ(2u, Log2U64(7));
+ EXPECT_EQ(3u, Log2U64(8));
+ EXPECT_EQ(3u, Log2U64(9));
+ EXPECT_EQ(3u, Log2U64(10));
+ EXPECT_EQ(3u, Log2U64(11));
+ EXPECT_EQ(3u, Log2U64(12));
+ EXPECT_EQ(3u, Log2U64(13));
+ EXPECT_EQ(3u, Log2U64(14));
+ EXPECT_EQ(3u, Log2U64(15));
+ EXPECT_EQ(4u, Log2U64(16));
+ EXPECT_EQ(4u, Log2U64(17));
+ EXPECT_EQ(5u, Log2U64(35));
+ EXPECT_EQ(6u, Log2U64(72));
+ EXPECT_EQ(7u, Log2U64(255));
+ EXPECT_EQ(8u, Log2U64(256));
+ EXPECT_EQ(15u, Log2U64(65535));
+ EXPECT_EQ(16u, Log2U64(65536));
+ EXPECT_EQ(19u, Log2U64(0xFFFFF));
+ EXPECT_EQ(23u, Log2U64(0xFFFFFF));
+ EXPECT_EQ(27u, Log2U64(0xFFFFFFF));
+ EXPECT_EQ(31u, Log2U64(0xFFFFFFFF));
+ EXPECT_EQ(35u, Log2U64(0xFFFFFFFFF));
+ EXPECT_EQ(39u, Log2U64(0xFFFFFFFFFF));
+ EXPECT_EQ(43u, Log2U64(0xFFFFFFFFFFF));
+ EXPECT_EQ(47u, Log2U64(0xFFFFFFFFFFFF));
+ EXPECT_EQ(51u, Log2U64(0xFFFFFFFFFFFFF));
+ EXPECT_EQ(55u, Log2U64(0xFFFFFFFFFFFFFF));
+ EXPECT_EQ(59u, Log2U64(0xFFFFFFFFFFFFFFF));
+ EXPECT_EQ(63u, Log2U64(0xFFFFFFFFFFFFFFFF));
+}
+
TEST(NumBitsToNumWords, Word8) {
EXPECT_EQ(0u, NumBitsToNumWords<8>(0));
EXPECT_EQ(1u, NumBitsToNumWords<8>(1));
@@ -424,6 +468,23 @@ TEST(BitWriterStringStream, WriteBits) {
EXPECT_EQ("11001", writer.GetStreamRaw());
}
+TEST(BitWriterStringStream, WriteUnencodedU8) {
+ BitWriterStringStream writer;
+ const uint8_t bits = 127;
+ writer.WriteUnencoded(bits);
+ EXPECT_EQ(8u, writer.GetNumBits());
+ EXPECT_EQ("11111110", writer.GetStreamRaw());
+}
+
+TEST(BitWriterStringStream, WriteUnencodedS64) {
+ BitWriterStringStream writer;
+ const int64_t bits = std::numeric_limits<int64_t>::min() + 7;
+ writer.WriteUnencoded(bits);
+ EXPECT_EQ(64u, writer.GetNumBits());
+ EXPECT_EQ("1110000000000000000000000000000000000000000000000000000000000001",
+ writer.GetStreamRaw());
+}
+
TEST(BitWriterStringStream, WriteMultiple) {
BitWriterStringStream writer;
@@ -715,6 +776,29 @@ TEST(BitReaderWord64, ReadBitsTwoWords) {
EXPECT_TRUE(reader.ReachedEnd());
}
+TEST(BitReaderFromString, ReadUnencodedU8) {
+ BitReaderFromString reader("11111110");
+ uint8_t val = 0;
+ ASSERT_TRUE(reader.ReadUnencoded(&val));
+ EXPECT_EQ(8u, reader.GetNumReadBits());
+ EXPECT_EQ(127, val);
+}
+
+TEST(BitReaderFromString, ReadUnencodedU16Fail) {
+ BitReaderFromString reader("11111110");
+ uint16_t val = 0;
+ ASSERT_FALSE(reader.ReadUnencoded(&val));
+}
+
+TEST(BitReaderFromString, ReadUnencodedS64) {
+ BitReaderFromString reader(
+ "1110000000000000000000000000000000000000000000000000000000000001");
+ int64_t val = 0;
+ ASSERT_TRUE(reader.ReadUnencoded(&val));
+ EXPECT_EQ(64u, reader.GetNumReadBits());
+ EXPECT_EQ(std::numeric_limits<int64_t>::min() + 7, val);
+}
+
TEST(BitReaderWord64, FromU8) {
std::vector<uint8_t> buffer = {
0xAA, 0xBB, 0xCC, 0xDD,
@@ -1135,4 +1219,68 @@ TEST(VariableWidthWriteRead, VariedNumbersChunkLength8) {
EXPECT_EQ(expected_values, actual_values);
}
+TEST(FixedWidthWrite, Val0Max3) {
+ BitWriterStringStream writer;
+ writer.WriteFixedWidth(0, 3);
+ EXPECT_EQ("00", writer.GetStreamRaw());
+}
+
+TEST(FixedWidthWrite, Val0Max5) {
+ BitWriterStringStream writer;
+ writer.WriteFixedWidth(0, 5);
+ EXPECT_EQ("000", writer.GetStreamRaw());
+}
+
+TEST(FixedWidthWrite, Val0Max255) {
+ BitWriterStringStream writer;
+ writer.WriteFixedWidth(0, 255);
+ EXPECT_EQ("00000000", writer.GetStreamRaw());
+}
+
+TEST(FixedWidthWrite, Val3Max8) {
+ BitWriterStringStream writer;
+ writer.WriteFixedWidth(3, 8);
+ EXPECT_EQ("1100", writer.GetStreamRaw());
+}
+
+TEST(FixedWidthWrite, Val15Max127) {
+ BitWriterStringStream writer;
+ writer.WriteFixedWidth(15, 127);
+ EXPECT_EQ("1111000", writer.GetStreamRaw());
+}
+
+TEST(FixedWidthRead, Val0Max3) {
+ BitReaderFromString reader("0011111");
+ uint64_t val = 0;
+ ASSERT_TRUE(reader.ReadFixedWidth(&val, 3));
+ EXPECT_EQ(0u, val);
+}
+
+TEST(FixedWidthRead, Val0Max5) {
+ BitReaderFromString reader("0001010101");
+ uint64_t val = 0;
+ ASSERT_TRUE(reader.ReadFixedWidth(&val, 5));
+ EXPECT_EQ(0u, val);
+}
+
+TEST(FixedWidthRead, Val3Max8) {
+ BitReaderFromString reader("11001010101");
+ uint64_t val = 0;
+ ASSERT_TRUE(reader.ReadFixedWidth(&val, 8));
+ EXPECT_EQ(3u, val);
+}
+
+TEST(FixedWidthRead, Val15Max127) {
+ BitReaderFromString reader("111100010101");
+ uint64_t val = 0;
+ ASSERT_TRUE(reader.ReadFixedWidth(&val, 127));
+ EXPECT_EQ(15u, val);
+}
+
+TEST(FixedWidthRead, Fail) {
+ BitReaderFromString reader("111100");
+ uint64_t val = 0;
+ ASSERT_FALSE(reader.ReadFixedWidth(&val, 127));
+}
+
} // anonymous namespace
diff --git a/test/comp/CMakeLists.txt b/test/comp/CMakeLists.txt
new file mode 100644
index 00000000..f769b8c9
--- /dev/null
+++ b/test/comp/CMakeLists.txt
@@ -0,0 +1,23 @@
+# Copyright (c) 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set(VAL_TEST_COMMON_SRCS
+ ${CMAKE_CURRENT_SOURCE_DIR}/../test_fixture.h
+ ${CMAKE_CURRENT_SOURCE_DIR}/../unit_spirv.h
+)
+
+add_spvtools_unittest(TARGET markv_codec
+ SRCS markv_codec_test.cpp ${VAL_TEST_COMMON_SRCS}
+ LIBS SPIRV-Tools-comp ${SPIRV_TOOLS}
+)
diff --git a/test/comp/markv_codec_test.cpp b/test/comp/markv_codec_test.cpp
new file mode 100644
index 00000000..c43cc77a
--- /dev/null
+++ b/test/comp/markv_codec_test.cpp
@@ -0,0 +1,433 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for the MARK-V codec (SPIR-V <-> MARK-V encode/decode round trips).
+
+#include <functional>
+#include <memory>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "spirv-tools/markv.h"
+#include "test_fixture.h"
+#include "unit_spirv.h"
+
+namespace {
+
+using spvtest::ScopedContext;
+
+void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
+ const spv_position_t& position,
+ const char* message) {
+ switch (level) {
+ case SPV_MSG_FATAL:
+ case SPV_MSG_INTERNAL_ERROR:
+ case SPV_MSG_ERROR:
+ std::cerr << "error: " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_WARNING:
+ std::cout << "warning: " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_INFO:
+ std::cout << "info: " << position.index << ": " << message << std::endl;
+ break;
+ default:
+ break;
+ }
+}
+
+// Compiles |code| to SPIR-V |words|.
+void Compile(const std::string& code, std::vector<uint32_t>* words,
+ uint32_t options = SPV_TEXT_TO_BINARY_OPTION_NONE,
+ spv_target_env env = SPV_ENV_UNIVERSAL_1_2) {
+ ScopedContext ctx(env);
+ SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler);
+
+ spv_binary spirv_binary;
+ ASSERT_EQ(SPV_SUCCESS, spvTextToBinaryWithOptions(
+ ctx.context, code.c_str(), code.size(), options, &spirv_binary, nullptr));
+
+ *words = std::vector<uint32_t>(
+ spirv_binary->code, spirv_binary->code + spirv_binary->wordCount);
+
+ spvBinaryDestroy(spirv_binary);
+}
+
+// Disassembles SPIR-V |words| to |out_text|.
+void Disassemble(const std::vector<uint32_t>& words,
+ std::string* out_text,
+ spv_target_env env = SPV_ENV_UNIVERSAL_1_2) {
+ ScopedContext ctx(env);
+ SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler);
+
+ spv_text text = nullptr;
+ ASSERT_EQ(SPV_SUCCESS, spvBinaryToText(ctx.context, words.data(),
+ words.size(), 0, &text, nullptr));
+ assert(text);
+
+ *out_text = std::string(text->str, text->length);
+ spvTextDestroy(text);
+}
+
+// Encodes SPIR-V |words| to |markv_binary|. |comments| receives snippets of
+// disassembly and bit sequences for debugging.
+void Encode(const std::vector<uint32_t>& words,
+ spv_markv_binary* markv_binary,
+ std::string* comments,
+ spv_target_env env = SPV_ENV_UNIVERSAL_1_2) {
+ ScopedContext ctx(env);
+ SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler);
+
+ std::unique_ptr<spv_markv_encoder_options_t,
+ std::function<void(spv_markv_encoder_options_t*)>> options(
+ spvMarkvEncoderOptionsCreate(), &spvMarkvEncoderOptionsDestroy);
+ spv_text spv_text_comments;
+ ASSERT_EQ(SPV_SUCCESS, spvSpirvToMarkv(ctx.context, words.data(),
+ words.size(), options.get(),
+ markv_binary, &spv_text_comments,
+ nullptr));
+
+ *comments = std::string(spv_text_comments->str, spv_text_comments->length);
+ spvTextDestroy(spv_text_comments);
+}
+
+// Decodes |markv_binary| to SPIR-V |words|.
+void Decode(const spv_markv_binary markv_binary,
+ std::vector<uint32_t>* words,
+ spv_target_env env = SPV_ENV_UNIVERSAL_1_2) {
+ ScopedContext ctx(env);
+ SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler);
+
+ spv_binary spirv_binary = nullptr;
+ std::unique_ptr<spv_markv_decoder_options_t,
+ std::function<void(spv_markv_decoder_options_t*)>> options(
+ spvMarkvDecoderOptionsCreate(), &spvMarkvDecoderOptionsDestroy);
+ ASSERT_EQ(SPV_SUCCESS, spvMarkvToSpirv(ctx.context, markv_binary->data,
+ markv_binary->length, options.get(),
+ &spirv_binary, nullptr, nullptr));
+
+ *words = std::vector<uint32_t>(
+ spirv_binary->code, spirv_binary->code + spirv_binary->wordCount);
+
+ spvBinaryDestroy(spirv_binary);
+}
+
+// Encodes/decodes |original|, assembles/disassembles |original|, then compares
+// the results of the two operations.
+void TestEncodeDecode(const std::string& original_text) {
+ std::vector<uint32_t> expected_binary;
+ Compile(original_text, &expected_binary);
+ ASSERT_FALSE(expected_binary.empty());
+
+ std::string expected_text;
+ Disassemble(expected_binary, &expected_text);
+ ASSERT_FALSE(expected_text.empty());
+
+ std::vector<uint32_t> binary_to_encode;
+ Compile(original_text, &binary_to_encode,
+ SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
+ ASSERT_FALSE(binary_to_encode.empty());
+
+ spv_markv_binary markv_binary = nullptr;
+ std::string encoder_comments;
+ Encode(binary_to_encode, &markv_binary, &encoder_comments);
+ ASSERT_NE(nullptr, markv_binary);
+
+ // std::cerr << encoder_comments << std::endl;
+ // std::cerr << "SPIR-V size: " << expected_binary.size() * 4 << std::endl;
+ // std::cerr << "MARK-V size: " << markv_binary->length << std::endl;
+
+ std::vector<uint32_t> decoded_binary;
+ Decode(markv_binary, &decoded_binary);
+ ASSERT_FALSE(decoded_binary.empty());
+
+ EXPECT_EQ(expected_binary, decoded_binary) << encoder_comments;
+
+ std::string decoded_text;
+ Disassemble(decoded_binary, &decoded_text);
+ ASSERT_FALSE(decoded_text.empty());
+
+ EXPECT_EQ(expected_text, decoded_text) << encoder_comments;
+
+ spvMarkvBinaryDestroy(markv_binary);
+}
+
+TEST(Markv, U32Literal) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%u32 = OpTypeInt 32 0
+%100 = OpConstant %u32 0
+%200 = OpConstant %u32 1
+%300 = OpConstant %u32 4294967295
+)");
+}
+
+TEST(Markv, S32Literal) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%s32 = OpTypeInt 32 1
+%100 = OpConstant %s32 0
+%200 = OpConstant %s32 1
+%300 = OpConstant %s32 -1
+%400 = OpConstant %s32 2147483647
+%500 = OpConstant %s32 -2147483648
+)");
+}
+
+TEST(Markv, U64Literal) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Int64
+OpMemoryModel Logical GLSL450
+%u64 = OpTypeInt 64 0
+%100 = OpConstant %u64 0
+%200 = OpConstant %u64 1
+%300 = OpConstant %u64 18446744073709551615
+)");
+}
+
+TEST(Markv, S64Literal) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Int64
+OpMemoryModel Logical GLSL450
+%s64 = OpTypeInt 64 1
+%100 = OpConstant %s64 0
+%200 = OpConstant %s64 1
+%300 = OpConstant %s64 -1
+%400 = OpConstant %s64 9223372036854775807
+%500 = OpConstant %s64 -9223372036854775808
+)");
+}
+
+TEST(Markv, U16Literal) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Int16
+OpMemoryModel Logical GLSL450
+%u16 = OpTypeInt 16 0
+%100 = OpConstant %u16 0
+%200 = OpConstant %u16 1
+%300 = OpConstant %u16 65535
+)");
+}
+
+TEST(Markv, S16Literal) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Int16
+OpMemoryModel Logical GLSL450
+%s16 = OpTypeInt 16 1
+%100 = OpConstant %s16 0
+%200 = OpConstant %s16 1
+%300 = OpConstant %s16 -1
+%400 = OpConstant %s16 32767
+%500 = OpConstant %s16 -32768
+)");
+}
+
+TEST(Markv, F32Literal) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+%f32 = OpTypeFloat 32
+%100 = OpConstant %f32 0
+%200 = OpConstant %f32 1
+%300 = OpConstant %f32 0.1
+%400 = OpConstant %f32 -0.1
+)");
+}
+
+TEST(Markv, F64Literal) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Float64
+OpMemoryModel Logical GLSL450
+%f64 = OpTypeFloat 64
+%100 = OpConstant %f64 0
+%200 = OpConstant %f64 1
+%300 = OpConstant %f64 0.1
+%400 = OpConstant %f64 -0.1
+)");
+}
+
+TEST(Markv, F16Literal) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpCapability Float16
+OpMemoryModel Logical GLSL450
+%f16 = OpTypeFloat 16
+%100 = OpConstant %f16 0
+%200 = OpConstant %f16 1
+%300 = OpConstant %f16 0.1
+%400 = OpConstant %f16 -0.1
+)");
+}
+
+TEST(Markv, StringLiteral) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpExtension "SPV_KHR_16bit_storage"
+OpExtension "xxx"
+OpExtension "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+OpExtension ""
+OpMemoryModel Logical GLSL450
+)");
+}
+
+TEST(Markv, WithFunction) {
+ TestEncodeDecode(R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpExtension "SPV_KHR_16bit_storage"
+OpMemoryModel Physical32 OpenCL
+%f32 = OpTypeFloat 32
+%u32 = OpTypeInt 32 0
+%void = OpTypeVoid
+%void_func = OpTypeFunction %void
+%100 = OpConstant %u32 1
+%200 = OpConstant %u32 2
+%main = OpFunction %void None %void_func
+%entry_main = OpLabel
+%300 = OpIAdd %u32 %100 %200
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST(Markv, ForwardDeclaredId) {
+ TestEncodeDecode(R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+OpEntryPoint Kernel %1 "simple_kernel"
+%2 = OpTypeInt 32 0
+%3 = OpTypeVector %2 2
+%4 = OpConstant %2 2
+%5 = OpTypeArray %2 %4
+%6 = OpTypeVoid
+%7 = OpTypeFunction %6
+%1 = OpFunction %6 None %7
+%8 = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST(Markv, WithSwitch) {
+ TestEncodeDecode(R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Int64
+OpMemoryModel Physical32 OpenCL
+%u64 = OpTypeInt 64 0
+%void = OpTypeVoid
+%void_func = OpTypeFunction %void
+%val = OpConstant %u64 1
+%main = OpFunction %void None %void_func
+%entry_main = OpLabel
+OpSwitch %val %default 1 %case1 1000000000000 %case2
+%case1 = OpLabel
+OpNop
+OpBranch %after_switch
+%case2 = OpLabel
+OpNop
+OpBranch %after_switch
+%default = OpLabel
+OpNop
+OpBranch %after_switch
+%after_switch = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST(Markv, WithLoop) {
+ TestEncodeDecode(R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+OpMemoryModel Physical32 OpenCL
+%void = OpTypeVoid
+%void_func = OpTypeFunction %void
+%main = OpFunction %void None %void_func
+%entry_main = OpLabel
+OpLoopMerge %merge %continue DontUnroll|DependencyLength 10
+OpBranch %begin_loop
+%begin_loop = OpLabel
+OpNop
+OpBranch %continue
+%continue = OpLabel
+OpNop
+OpBranch %begin_loop
+%merge = OpLabel
+OpReturn
+OpFunctionEnd
+)");
+}
+
+TEST(Markv, WithDecorate) {
+ TestEncodeDecode(R"(
+OpCapability Shader
+OpCapability Linkage
+OpMemoryModel Logical GLSL450
+OpDecorate %1 ArrayStride 4
+OpDecorate %1 Uniform
+%2 = OpTypeFloat 32
+%1 = OpTypeRuntimeArray %2
+)");
+}
+
+TEST(Markv, WithExtInst) {
+ TestEncodeDecode(R"(
+OpCapability Addresses
+OpCapability Kernel
+OpCapability GenericPointer
+OpCapability Linkage
+%opencl = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical32 OpenCL
+%f32 = OpTypeFloat 32
+%void = OpTypeVoid
+%void_func = OpTypeFunction %void
+%100 = OpConstant %f32 1.1
+%main = OpFunction %void None %void_func
+%entry_main = OpLabel
+%200 = OpExtInst %f32 %opencl cos %100
+OpReturn
+OpFunctionEnd
+)");
+}
+
+} // namespace
diff --git a/test/huffman_codec.cpp b/test/huffman_codec.cpp
new file mode 100644
index 00000000..80f7d8f8
--- /dev/null
+++ b/test/huffman_codec.cpp
@@ -0,0 +1,220 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for the Huffman codec utilities.
+
+#include <map>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+
+#include "util/huffman_codec.h"
+#include "util/bit_stream.h"
+#include "gmock/gmock.h"
+
+namespace {
+
+using spvutils::HuffmanCodec;
+using spvutils::BitsToStream;
+
+const std::map<std::string, uint32_t>& GetTestSet() {
+ static const std::map<std::string, uint32_t> hist = {
+ {"a", 4},
+ {"e", 7},
+ {"f", 3},
+ {"h", 2},
+ {"i", 3},
+ {"m", 2},
+ {"n", 2},
+ {"s", 2},
+ {"t", 2},
+ {"l", 1},
+ {"o", 2},
+ {"p", 1},
+ {"r", 1},
+ {"u", 1},
+ {"x", 1},
+ };
+
+ return hist;
+}
+
+class TestBitReader {
+ public:
+ TestBitReader(const std::string& bits) : bits_(bits) {}
+
+ bool ReadBit(bool* bit) {
+ if (pos_ < bits_.length()) {
+ *bit = bits_[pos_++] == '0' ? false : true;
+ return true;
+ }
+ return false;
+ }
+
+ private:
+ std::string bits_;
+ size_t pos_ = 0;
+};
+
+TEST(Huffman, PrintTree) {
+ HuffmanCodec<std::string> huffman(GetTestSet());
+ std::stringstream ss;
+ huffman.PrintTree(ss);
+
+ const std::string expected = std::string(R"(
+15-----7------e
+ 8------4------a
+ 4------2------m
+ 2------n
+19-----8------4------2------o
+ 2------s
+ 4------2------t
+ 2------1------l
+ 1------p
+ 11-----5------2------1------r
+ 1------u
+ 3------f
+ 6------3------i
+ 3------1------x
+ 2------h
+)").substr(1);
+
+ EXPECT_EQ(expected, ss.str());
+}
+
+TEST(Huffman, PrintTable) {
+ HuffmanCodec<std::string> huffman(GetTestSet());
+ std::stringstream ss;
+ huffman.PrintTable(ss);
+
+ const std::string expected = std::string(R"(
+e 7 11
+a 4 101
+i 3 0001
+f 3 0010
+t 2 0101
+s 2 0110
+o 2 0111
+n 2 1000
+m 2 1001
+h 2 00000
+x 1 00001
+u 1 00110
+r 1 00111
+p 1 01000
+l 1 01001
+)").substr(1);
+
+ EXPECT_EQ(expected, ss.str());
+}
+
+TEST(Huffman, TestValidity) {
+ HuffmanCodec<std::string> huffman(GetTestSet());
+ const auto& encoding_table = huffman.GetEncodingTable();
+ std::vector<std::string> codes;
+ for (const auto& entry : encoding_table) {
+ codes.push_back(BitsToStream(entry.second.first, entry.second.second));
+ }
+
+ std::sort(codes.begin(), codes.end());
+
+ ASSERT_LT(codes.size(), 20u) << "Inefficient test ahead";
+
+ for (size_t i = 0; i < codes.size(); ++i) {
+ for (size_t j = i + 1; j < codes.size(); ++j) {
+ ASSERT_FALSE(codes[i] == codes[j].substr(0, codes[i].length()))
+ << codes[i] << " is prefix of " << codes[j];
+ }
+ }
+}
+
+TEST(Huffman, TestEncode) {
+ HuffmanCodec<std::string> huffman(GetTestSet());
+
+ uint64_t bits = 0;
+ size_t num_bits = 0;
+
+ EXPECT_TRUE(huffman.Encode("e", &bits, &num_bits));
+ EXPECT_EQ(2u, num_bits);
+ EXPECT_EQ("11", BitsToStream(bits, num_bits));
+
+ EXPECT_TRUE(huffman.Encode("a", &bits, &num_bits));
+ EXPECT_EQ(3u, num_bits);
+ EXPECT_EQ("101", BitsToStream(bits, num_bits));
+
+ EXPECT_TRUE(huffman.Encode("x", &bits, &num_bits));
+ EXPECT_EQ(5u, num_bits);
+ EXPECT_EQ("00001", BitsToStream(bits, num_bits));
+
+ EXPECT_FALSE(huffman.Encode("y", &bits, &num_bits));
+}
+
+TEST(Huffman, TestDecode) {
+ HuffmanCodec<std::string> huffman(GetTestSet());
+ TestBitReader bit_reader("01001""0001""1000""00110""00001""00");
+ auto read_bit = [&bit_reader](bool* bit) {
+ return bit_reader.ReadBit(bit);
+ };
+
+ std::string decoded;
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ("l", decoded);
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ("i", decoded);
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ("n", decoded);
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ("u", decoded);
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ("x", decoded);
+
+ ASSERT_FALSE(huffman.DecodeFromStream(read_bit, &decoded));
+}
+
+TEST(Huffman, TestDecodeNumbers) {
+ const std::map<uint32_t, uint32_t> hist = { {1, 10}, {2, 5}, {3, 15} };
+ HuffmanCodec<uint32_t> huffman(hist);
+
+ TestBitReader bit_reader("1""1""01""00""01""1");
+ auto read_bit = [&bit_reader](bool* bit) {
+ return bit_reader.ReadBit(bit);
+ };
+
+ uint32_t decoded;
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ(3u, decoded);
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ(3u, decoded);
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ(2u, decoded);
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ(1u, decoded);
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ(2u, decoded);
+
+ ASSERT_TRUE(huffman.DecodeFromStream(read_bit, &decoded));
+ EXPECT_EQ(3u, decoded);
+}
+
+} // anonymous namespace
diff --git a/test/move_to_front_test.cpp b/test/move_to_front_test.cpp
new file mode 100644
index 00000000..89fc3a8f
--- /dev/null
+++ b/test/move_to_front_test.cpp
@@ -0,0 +1,785 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+#include <iostream>
+#include <set>
+
+#include "gmock/gmock.h"
+#include "util/move_to_front.h"
+
+namespace {
+
+using spvutils::MoveToFront;
+
+// Class used to test the inner workings of MoveToFront.
+class MoveToFrontTester : public MoveToFront<uint32_t> {
+ public:
+ // Inserts the value in the internal tree data structure. For testing only.
+ void TestInsert(uint32_t val) {
+ InsertNode(CreateNode(val, val));
+ }
+
+ // Removes the value from the internal tree data structure. For testing only.
+ void TestRemove(uint32_t val) {
+ const auto it = value_to_node_.find(val);
+ assert(it != value_to_node_.end());
+ RemoveNode(it->second);
+ }
+
+ // Prints the internal tree data structure to |out|. For testing only.
+ void PrintTree(std::ostream& out, bool print_timestamp = false) const {
+ if (root_)
+ PrintTreeInternal(out, root_, 1, print_timestamp);
+ }
+
+ // Returns node handle corresponding to the value. The value may not be in the tree.
+ uint32_t GetNodeHandle(uint32_t value) const {
+ const auto it = value_to_node_.find(value);
+ if (it == value_to_node_.end())
+ return 0;
+
+ return it->second;
+ }
+
+ // Returns total node count (both those in the tree and removed,
+ // but not the NIL singleton).
+ size_t GetTotalNodeCount() const {
+ assert(nodes_.size());
+ return nodes_.size() - 1;
+ }
+
+ private:
+ // Prints the internal tree data structure for debug purposes in the following
+ // format:
+ // 10H3S4----5H1S1-----D2
+ // 15H2S2----12H1S1----D3
+ // Right links are horizontal, left links step down one line.
+ // 5H1S1 is read as value 5, height 1, size 1. Optionally node label can also
+ // contain timestamp (5H1S1T15). D3 stands for depth 3.
+ void PrintTreeInternal(std::ostream& out, uint32_t node, size_t depth,
+ bool print_timestamp) const;
+};
+
+void MoveToFrontTester::PrintTreeInternal(
+ std::ostream& out, uint32_t node,
+ size_t depth, bool print_timestamp) const {
+ if (!node) {
+ out << "D" << depth - 1 << std::endl;
+ return;
+ }
+
+ const size_t kTextFieldWvaluethWithoutTimestamp = 10;
+ const size_t kTextFieldWvaluethWithTimestamp = 14;
+ const size_t text_field_wvalueth = print_timestamp ?
+ kTextFieldWvaluethWithTimestamp : kTextFieldWvaluethWithoutTimestamp;
+
+ std::stringstream label;
+ label << ValueOf(node) << "H" << HeightOf(node) << "S" << SizeOf(node);
+ if (print_timestamp)
+ label << "T" << TimestampOf(node);
+ const size_t label_length = label.str().length();
+ if (label_length < text_field_wvalueth)
+ label << std::string(text_field_wvalueth - label_length, '-');
+
+ out << label.str();
+
+ PrintTreeInternal(out, RightOf(node), depth + 1, print_timestamp);
+
+ if (LeftOf(node)) {
+ out << std::string(depth * text_field_wvalueth, ' ');
+ PrintTreeInternal(out, LeftOf(node), depth + 1, print_timestamp);
+ }
+}
+
+void CheckTree(const MoveToFrontTester& mtf, const std::string& expected,
+ bool print_timestamp = false) {
+ std::stringstream ss;
+ mtf.PrintTree(ss, print_timestamp);
+ EXPECT_EQ(expected, ss.str());
+}
+
+TEST(MoveToFront, EmptyTree) {
+ MoveToFrontTester mtf;
+ CheckTree(mtf, std::string());
+}
+
+TEST(MoveToFront, InsertLeftRotation) {
+ MoveToFrontTester mtf;
+
+ mtf.TestInsert(30);
+ mtf.TestInsert(20);
+
+ CheckTree(mtf, std::string(R"(
+30H2S2----20H1S1----D2
+)").substr(1));
+
+ mtf.TestInsert(10);
+ CheckTree(mtf, std::string(R"(
+20H2S3----10H1S1----D2
+ 30H1S1----D2
+)").substr(1));
+}
+
+TEST(MoveToFront, InsertRightRotation) {
+ MoveToFrontTester mtf;
+
+ mtf.TestInsert(10);
+ mtf.TestInsert(20);
+
+ CheckTree(mtf, std::string(R"(
+10H2S2----D1
+ 20H1S1----D2
+)").substr(1));
+
+ mtf.TestInsert(30);
+ CheckTree(mtf, std::string(R"(
+20H2S3----10H1S1----D2
+ 30H1S1----D2
+)").substr(1));
+}
+
+TEST(MoveToFront, InsertRightLeftRotation) {
+ MoveToFrontTester mtf;
+
+ mtf.TestInsert(30);
+ mtf.TestInsert(20);
+
+ CheckTree(mtf, std::string(R"(
+30H2S2----20H1S1----D2
+)").substr(1));
+
+ mtf.TestInsert(25);
+ CheckTree(mtf, std::string(R"(
+25H2S3----20H1S1----D2
+ 30H1S1----D2
+)").substr(1));
+}
+
+TEST(MoveToFront, InsertLeftRightRotation) {
+ MoveToFrontTester mtf;
+
+ mtf.TestInsert(10);
+ mtf.TestInsert(20);
+
+ CheckTree(mtf, std::string(R"(
+10H2S2----D1
+ 20H1S1----D2
+)").substr(1));
+
+ mtf.TestInsert(15);
+ CheckTree(mtf, std::string(R"(
+15H2S3----10H1S1----D2
+ 20H1S1----D2
+)").substr(1));
+}
+
+TEST(MoveToFront, RemoveSingleton) {
+ MoveToFrontTester mtf;
+
+ mtf.TestInsert(10);
+ CheckTree(mtf, std::string(R"(
+10H1S1----D1
+)").substr(1));
+
+ mtf.TestRemove(10);
+ CheckTree(mtf, "");
+}
+
+TEST(MoveToFront, RemoveRootWithScapegoat) {
+ MoveToFrontTester mtf;
+
+ mtf.TestInsert(10);
+ mtf.TestInsert(5);
+ mtf.TestInsert(15);
+ CheckTree(mtf, std::string(R"(
+10H2S3----5H1S1-----D2
+ 15H1S1----D2
+)").substr(1));
+
+ mtf.TestRemove(10);
+ CheckTree(mtf, std::string(R"(
+15H2S2----5H1S1-----D2
+)").substr(1));
+}
+
+TEST(MoveToFront, RemoveRightRotation) {
+ MoveToFrontTester mtf;
+
+ mtf.TestInsert(10);
+ mtf.TestInsert(5);
+ mtf.TestInsert(15);
+ mtf.TestInsert(20);
+ CheckTree(mtf, std::string(R"(
+10H3S4----5H1S1-----D2
+ 15H2S2----D2
+ 20H1S1----D3
+)").substr(1));
+
+ mtf.TestRemove(5);
+
+ CheckTree(mtf, std::string(R"(
+15H2S3----10H1S1----D2
+ 20H1S1----D2
+)").substr(1));
+}
+
+TEST(MoveToFront, RemoveLeftRotation) {
+ MoveToFrontTester mtf;
+
+ mtf.TestInsert(10);
+ mtf.TestInsert(15);
+ mtf.TestInsert(5);
+ mtf.TestInsert(1);
+ CheckTree(mtf, std::string(R"(
+10H3S4----5H2S2-----1H1S1-----D3
+ 15H1S1----D2
+)").substr(1));
+
+ mtf.TestRemove(15);
+
+ CheckTree(mtf, std::string(R"(
+5H2S3-----1H1S1-----D2
+ 10H1S1----D2
+)").substr(1));
+}
+
+TEST(MoveToFront, RemoveLeftRightRotation) {
+ MoveToFrontTester mtf;
+
+ mtf.TestInsert(10);
+ mtf.TestInsert(15);
+ mtf.TestInsert(5);
+ mtf.TestInsert(12);
+ CheckTree(mtf, std::string(R"(
+10H3S4----5H1S1-----D2
+ 15H2S2----12H1S1----D3
+)").substr(1));
+
+ mtf.TestRemove(5);
+
+ CheckTree(mtf, std::string(R"(
+12H2S3----10H1S1----D2
+ 15H1S1----D2
+)").substr(1));
+}
+
+TEST(MoveToFront, RemoveRightLeftRotation) {
+ MoveToFrontTester mtf;
+
+ mtf.TestInsert(10);
+ mtf.TestInsert(15);
+ mtf.TestInsert(5);
+ mtf.TestInsert(8);
+ CheckTree(mtf, std::string(R"(
+10H3S4----5H2S2-----D2
+ 8H1S1-----D3
+ 15H1S1----D2
+)").substr(1));
+
+ mtf.TestRemove(15);
+
+ CheckTree(mtf, std::string(R"(
+8H2S3-----5H1S1-----D2
+ 10H1S1----D2
+)").substr(1));
+}
+
+TEST(MoveToFront, MultipleOperations) {
+ MoveToFrontTester mtf;
+ std::vector<uint32_t> vals =
+ { 5, 11, 12, 16, 15, 6, 14, 2, 7, 10, 4, 8, 9, 3, 1, 13 };
+
+ for (uint32_t i : vals) {
+ mtf.TestInsert(i);
+ }
+
+ CheckTree(mtf, std::string(R"(
+11H5S16---5H4S10----3H3S4-----2H2S2-----1H1S1-----D5
+ 4H1S1-----D4
+ 7H3S5-----6H1S1-----D4
+ 9H2S3-----8H1S1-----D5
+ 10H1S1----D5
+ 15H3S5----13H2S3----12H1S1----D4
+ 14H1S1----D4
+ 16H1S1----D3
+)").substr(1));
+
+ mtf.TestRemove(11);
+
+ CheckTree(mtf, std::string(R"(
+10H5S15---5H4S9-----3H3S4-----2H2S2-----1H1S1-----D5
+ 4H1S1-----D4
+ 7H3S4-----6H1S1-----D4
+ 9H2S2-----8H1S1-----D5
+ 15H3S5----13H2S3----12H1S1----D4
+ 14H1S1----D4
+ 16H1S1----D3
+)").substr(1));
+
+ mtf.TestInsert(11);
+
+ CheckTree(mtf, std::string(R"(
+10H5S16---5H4S9-----3H3S4-----2H2S2-----1H1S1-----D5
+ 4H1S1-----D4
+ 7H3S4-----6H1S1-----D4
+ 9H2S2-----8H1S1-----D5
+ 13H3S6----12H2S2----11H1S1----D4
+ 15H2S3----14H1S1----D4
+ 16H1S1----D4
+)").substr(1));
+
+ mtf.TestRemove(5);
+
+ CheckTree(mtf, std::string(R"(
+10H5S15---6H4S8-----3H3S4-----2H2S2-----1H1S1-----D5
+ 4H1S1-----D4
+ 8H2S3-----7H1S1-----D4
+ 9H1S1-----D4
+ 13H3S6----12H2S2----11H1S1----D4
+ 15H2S3----14H1S1----D4
+ 16H1S1----D4
+)").substr(1));
+
+ mtf.TestInsert(5);
+
+ CheckTree(mtf, std::string(R"(
+10H5S16---6H4S9-----3H3S5-----2H2S2-----1H1S1-----D5
+ 4H2S2-----D4
+ 5H1S1-----D5
+ 8H2S3-----7H1S1-----D4
+ 9H1S1-----D4
+ 13H3S6----12H2S2----11H1S1----D4
+ 15H2S3----14H1S1----D4
+ 16H1S1----D4
+)").substr(1));
+
+ mtf.TestRemove(2);
+ mtf.TestRemove(1);
+ mtf.TestRemove(4);
+ mtf.TestRemove(3);
+ mtf.TestRemove(6);
+ mtf.TestRemove(5);
+ mtf.TestRemove(7);
+ mtf.TestRemove(9);
+
+ CheckTree(mtf, std::string(R"(
+13H4S8----10H3S4----8H1S1-----D3
+ 12H2S2----11H1S1----D4
+ 15H2S3----14H1S1----D3
+ 16H1S1----D3
+)").substr(1));
+}
+
+TEST(MoveToFront, BiggerScaleTreeTest) {
+ MoveToFrontTester mtf;
+ std::set<uint32_t> all_vals;
+
+ const uint32_t kMagic1 = 2654435761;
+ const uint32_t kMagic2 = 10000;
+
+ for (uint32_t i = 1; i < 1000; ++i) {
+ const uint32_t val = (i * kMagic1) % kMagic2;
+ if (!all_vals.count(val)) {
+ mtf.TestInsert(val);
+ all_vals.insert(val);
+ }
+ }
+
+ for (uint32_t i = 1; i < 1000; ++i) {
+ const uint32_t val = (i * kMagic1) % kMagic2;
+ if (val % 2 == 0) {
+ mtf.TestRemove(val);
+ all_vals.erase(val);
+ }
+ }
+
+ for (uint32_t i = 1000; i < 2000; ++i) {
+ const uint32_t val = (i * kMagic1) % kMagic2;
+ if (!all_vals.count(val)) {
+ mtf.TestInsert(val);
+ all_vals.insert(val);
+ }
+ }
+
+ for (uint32_t i = 1; i < 2000; ++i) {
+ const uint32_t val = (i * kMagic1) % kMagic2;
+ if (val > 50) {
+ mtf.TestRemove(val);
+ all_vals.erase(val);
+ }
+ }
+
+ EXPECT_EQ(all_vals, std::set<uint32_t>({2, 4, 11, 13, 24, 33, 35, 37, 46}));
+
+ CheckTree(mtf, std::string(R"(
+33H4S9----11H3S5----2H2S2-----D3
+ 4H1S1-----D4
+ 13H2S2----D3
+ 24H1S1----D4
+ 37H2S3----35H1S1----D3
+ 46H1S1----D3
+)").substr(1));
+}
+
+TEST(MoveToFront, RankFromValue) {
+ MoveToFrontTester mtf;
+
+ size_t rank = 0;
+ EXPECT_FALSE(mtf.RankFromValue(1, &rank));
+
+ EXPECT_TRUE(mtf.Insert(1));
+ EXPECT_TRUE(mtf.Insert(2));
+ EXPECT_TRUE(mtf.Insert(3));
+ EXPECT_FALSE(mtf.Insert(2));
+ CheckTree(mtf, std::string(R"(
+2H2S3T2-------1H1S1T1-------D2
+ 3H1S1T3-------D2
+)").substr(1), /* print_timestamp = */ true);
+
+ EXPECT_FALSE(mtf.RankFromValue(4, &rank));
+
+ EXPECT_TRUE(mtf.RankFromValue(1, &rank));
+ EXPECT_EQ(3u, rank);
+
+ CheckTree(mtf, std::string(R"(
+3H2S3T3-------2H1S1T2-------D2
+ 1H1S1T4-------D2
+)").substr(1), /* print_timestamp = */ true);
+
+ EXPECT_TRUE(mtf.RankFromValue(1, &rank));
+ EXPECT_EQ(1u, rank);
+
+ EXPECT_TRUE(mtf.RankFromValue(3, &rank));
+ EXPECT_EQ(2u, rank);
+
+ EXPECT_TRUE(mtf.RankFromValue(2, &rank));
+ EXPECT_EQ(3u, rank);
+
+ EXPECT_TRUE(mtf.Insert(40));
+
+ EXPECT_TRUE(mtf.RankFromValue(1, &rank));
+ EXPECT_EQ(4u, rank);
+
+ EXPECT_TRUE(mtf.Insert(50));
+
+ EXPECT_TRUE(mtf.RankFromValue(1, &rank));
+ EXPECT_EQ(2u, rank);
+
+ CheckTree(mtf, std::string(R"(
+2H3S5T7-------3H1S1T6-------D2
+ 50H2S3T10-----40H1S1T8------D3
+ 1H1S1T11------D3
+)").substr(1), /* print_timestamp = */ true);
+
+ EXPECT_TRUE(mtf.RankFromValue(50, &rank));
+ EXPECT_EQ(2u, rank);
+
+ EXPECT_EQ(5u, mtf.GetSize());
+ CheckTree(mtf, std::string(R"(
+2H3S5T7-------3H1S1T6-------D2
+ 1H2S3T11------40H1S1T8------D3
+ 50H1S1T12-----D3
+)").substr(1), /* print_timestamp = */ true);
+
+ EXPECT_FALSE(mtf.RankFromValue(0, &rank));
+ EXPECT_FALSE(mtf.RankFromValue(20, &rank));
+}
+
+TEST(MoveToFront, ValueFromRank) {
+ MoveToFrontTester mtf;
+
+ uint32_t value = 0;
+ EXPECT_FALSE(mtf.ValueFromRank(0, &value));
+ EXPECT_FALSE(mtf.ValueFromRank(1, &value));
+
+ EXPECT_TRUE(mtf.Insert(1));
+ EXPECT_TRUE(mtf.Insert(2));
+ EXPECT_TRUE(mtf.Insert(3));
+
+ EXPECT_TRUE(mtf.ValueFromRank(3, &value));
+ EXPECT_EQ(1u, value);
+
+ EXPECT_TRUE(mtf.ValueFromRank(1, &value));
+ EXPECT_EQ(1u, value);
+
+ EXPECT_TRUE(mtf.ValueFromRank(2, &value));
+ EXPECT_EQ(3u, value);
+
+ EXPECT_EQ(3u, mtf.GetSize());
+
+ CheckTree(mtf, std::string(R"(
+1H2S3T5-------2H1S1T2-------D2
+ 3H1S1T6-------D2
+)").substr(1), /* print_timestamp = */ true);
+
+ EXPECT_TRUE(mtf.ValueFromRank(3, &value));
+ EXPECT_EQ(2u, value);
+
+ CheckTree(mtf, std::string(R"(
+3H2S3T6-------1H1S1T5-------D2
+ 2H1S1T7-------D2
+)").substr(1), /* print_timestamp = */ true);
+
+ EXPECT_TRUE(mtf.Insert(10));
+ CheckTree(mtf, std::string(R"(
+3H3S4T6-------1H1S1T5-------D2
+ 2H2S2T7-------D2
+ 10H1S1T8------D3
+)").substr(1), /* print_timestamp = */ true);
+
+ EXPECT_TRUE(mtf.ValueFromRank(1, &value));
+ EXPECT_EQ(10u, value);
+}
+
+TEST(MoveToFront, Remove) {
+ MoveToFrontTester mtf;
+
+ EXPECT_FALSE(mtf.Remove(1));
+ EXPECT_EQ(0u, mtf.GetTotalNodeCount());
+
+ EXPECT_TRUE(mtf.Insert(1));
+ EXPECT_TRUE(mtf.Insert(2));
+ EXPECT_TRUE(mtf.Insert(3));
+
+ CheckTree(mtf, std::string(R"(
+2H2S3T2-------1H1S1T1-------D2
+ 3H1S1T3-------D2
+)").substr(1), /* print_timestamp = */ true);
+
+ EXPECT_EQ(1u, mtf.GetNodeHandle(1));
+ EXPECT_EQ(3u, mtf.GetTotalNodeCount());
+ EXPECT_TRUE(mtf.Remove(1));
+ EXPECT_EQ(3u, mtf.GetTotalNodeCount());
+
+ CheckTree(mtf, std::string(R"(
+2H2S2T2-------D1
+ 3H1S1T3-------D2
+)").substr(1), /* print_timestamp = */ true);
+
+ uint32_t value = 0;
+ EXPECT_TRUE(mtf.ValueFromRank(2, &value));
+ EXPECT_EQ(2u, value);
+
+ CheckTree(mtf, std::string(R"(
+3H2S2T3-------D1
+ 2H1S1T4-------D2
+)").substr(1), /* print_timestamp = */ true);
+
+ EXPECT_TRUE(mtf.Insert(1));
+ EXPECT_EQ(1u, mtf.GetNodeHandle(1));
+ EXPECT_EQ(3u, mtf.GetTotalNodeCount());
+}
+
+TEST(MoveToFront, LargerScale) {
+ MoveToFrontTester mtf;
+ uint32_t value = 0;
+ size_t rank = 0;
+
+ for (uint32_t i = 1; i < 1000; ++i) {
+ ASSERT_TRUE(mtf.Insert(i));
+ ASSERT_EQ(i, mtf.GetSize());
+
+ ASSERT_TRUE(mtf.RankFromValue(i, &rank));
+ ASSERT_EQ(1u, rank);
+
+ ASSERT_TRUE(mtf.ValueFromRank(1, &value));
+ ASSERT_EQ(i, value);
+ }
+
+ ASSERT_TRUE(mtf.ValueFromRank(999, &value));
+ ASSERT_EQ(1u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(999, &value));
+ ASSERT_EQ(2u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(999, &value));
+ ASSERT_EQ(3u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(999, &value));
+ ASSERT_EQ(4u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(999, &value));
+ ASSERT_EQ(5u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(999, &value));
+ ASSERT_EQ(6u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(101, &value));
+ ASSERT_EQ(905u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(101, &value));
+ ASSERT_EQ(906u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(101, &value));
+ ASSERT_EQ(907u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(201, &value));
+ ASSERT_EQ(805u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(201, &value));
+ ASSERT_EQ(806u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(201, &value));
+ ASSERT_EQ(807u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(301, &value));
+ ASSERT_EQ(705u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(301, &value));
+ ASSERT_EQ(706u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(301, &value));
+ ASSERT_EQ(707u, value);
+
+ ASSERT_TRUE(mtf.RankFromValue(605, &rank));
+ ASSERT_EQ(401u, rank);
+
+ ASSERT_TRUE(mtf.RankFromValue(606, &rank));
+ ASSERT_EQ(401u, rank);
+
+ ASSERT_TRUE(mtf.RankFromValue(607, &rank));
+ ASSERT_EQ(401u, rank);
+
+ ASSERT_TRUE(mtf.ValueFromRank(1, &value));
+ ASSERT_EQ(607u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(2, &value));
+ ASSERT_EQ(606u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(3, &value));
+ ASSERT_EQ(605u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(4, &value));
+ ASSERT_EQ(707u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(5, &value));
+ ASSERT_EQ(706u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(6, &value));
+ ASSERT_EQ(705u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(7, &value));
+ ASSERT_EQ(807u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(8, &value));
+ ASSERT_EQ(806u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(9, &value));
+ ASSERT_EQ(805u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(10, &value));
+ ASSERT_EQ(907u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(11, &value));
+ ASSERT_EQ(906u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(12, &value));
+ ASSERT_EQ(905u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(13, &value));
+ ASSERT_EQ(6u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(14, &value));
+ ASSERT_EQ(5u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(15, &value));
+ ASSERT_EQ(4u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(16, &value));
+ ASSERT_EQ(3u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(17, &value));
+ ASSERT_EQ(2u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(18, &value));
+ ASSERT_EQ(1u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(19, &value));
+ ASSERT_EQ(999u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(20, &value));
+ ASSERT_EQ(998u, value);
+
+ ASSERT_TRUE(mtf.ValueFromRank(21, &value));
+ ASSERT_EQ(997u, value);
+
+ ASSERT_TRUE(mtf.RankFromValue(997, &rank));
+ ASSERT_EQ(1u, rank);
+
+ ASSERT_TRUE(mtf.RankFromValue(998, &rank));
+ ASSERT_EQ(2u, rank);
+
+ ASSERT_TRUE(mtf.RankFromValue(996, &rank));
+ ASSERT_EQ(22u, rank);
+
+ ASSERT_TRUE(mtf.Remove(995));
+
+ ASSERT_TRUE(mtf.RankFromValue(994, &rank));
+ ASSERT_EQ(23u, rank);
+
+ for (uint32_t i = 10; i < 1000; ++i) {
+ if (i != 995) {
+ ASSERT_TRUE(mtf.Remove(i));
+ } else {
+ ASSERT_FALSE(mtf.Remove(i));
+ }
+ }
+
+ CheckTree(mtf, std::string(R"(
+6H4S9T3028----8H2S3T24------7H1S1T21------D3
+ 9H1S1T27------D3
+ 2H3S5T3032----4H2S3T3030----5H1S1T3029----D4
+ 3H1S1T3031----D4
+ 1H1S1T3033----D3
+)").substr(1), /* print_timestamp = */ true);
+
+ ASSERT_TRUE(mtf.Insert(1000));
+ ASSERT_TRUE(mtf.ValueFromRank(1, &value));
+ ASSERT_EQ(1000u, value);
+}
+
+TEST(MoveToFront, String) {
+ MoveToFront<std::string> mtf;
+
+ EXPECT_TRUE(mtf.Insert("AAA"));
+ EXPECT_TRUE(mtf.Insert("BBB"));
+ EXPECT_TRUE(mtf.Insert("CCC"));
+ EXPECT_FALSE(mtf.Insert("AAA"));
+
+ std::string value;
+ EXPECT_TRUE(mtf.ValueFromRank(2, &value));
+ EXPECT_EQ("BBB", value);
+
+ EXPECT_TRUE(mtf.ValueFromRank(2, &value));
+ EXPECT_EQ("CCC", value);
+
+ size_t rank = 0;
+ EXPECT_TRUE(mtf.RankFromValue("AAA", &rank));
+ EXPECT_EQ(3u, rank);
+
+ EXPECT_FALSE(mtf.ValueFromRank(0, &value));
+ EXPECT_FALSE(mtf.RankFromValue("ABC", &rank));
+ EXPECT_FALSE(mtf.Remove("ABC"));
+
+ EXPECT_TRUE(mtf.Remove("AAA"));
+ EXPECT_FALSE(mtf.Remove("AAA"));
+ EXPECT_FALSE(mtf.RankFromValue("AAA", &rank));
+
+ EXPECT_TRUE(mtf.Insert("AAA"));
+ EXPECT_TRUE(mtf.RankFromValue("AAA", &rank));
+ EXPECT_EQ(1u, rank);
+}
+
+} // anonymous namespace
diff --git a/test/operand_pattern_test.cpp b/test/operand_pattern_test.cpp
index be4fffdc..358671c5 100644
--- a/test/operand_pattern_test.cpp
+++ b/test/operand_pattern_test.cpp
@@ -28,31 +28,31 @@ TEST(OperandPattern, InitiallyEmpty) {
EXPECT_TRUE(empty.empty());
}
-TEST(OperandPattern, PushFrontsAreOnTheLeft) {
+TEST(OperandPattern, PushBacksAreOnTheRight) {
spv_operand_pattern_t pattern;
- pattern.push_front(SPV_OPERAND_TYPE_ID);
+ pattern.push_back(SPV_OPERAND_TYPE_ID);
EXPECT_THAT(pattern, Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_ID}));
EXPECT_EQ(1u, pattern.size());
EXPECT_TRUE(!pattern.empty());
- EXPECT_EQ(SPV_OPERAND_TYPE_ID, pattern.front());
+ EXPECT_EQ(SPV_OPERAND_TYPE_ID, pattern.back());
- pattern.push_front(SPV_OPERAND_TYPE_NONE);
- EXPECT_THAT(pattern, Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_NONE,
- SPV_OPERAND_TYPE_ID}));
+ pattern.push_back(SPV_OPERAND_TYPE_NONE);
+ EXPECT_THAT(pattern, Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_ID,
+ SPV_OPERAND_TYPE_NONE}));
EXPECT_EQ(2u, pattern.size());
EXPECT_TRUE(!pattern.empty());
- EXPECT_EQ(SPV_OPERAND_TYPE_NONE, pattern.front());
+ EXPECT_EQ(SPV_OPERAND_TYPE_NONE, pattern.back());
}
-TEST(OperandPattern, PopFrontsAreOnTheLeft) {
- spv_operand_pattern_t pattern{SPV_OPERAND_TYPE_LITERAL_INTEGER,
- SPV_OPERAND_TYPE_ID};
+TEST(OperandPattern, PopBacksAreOnTheRight) {
+ spv_operand_pattern_t pattern{SPV_OPERAND_TYPE_ID,
+ SPV_OPERAND_TYPE_LITERAL_INTEGER};
- pattern.pop_front();
+ pattern.pop_back();
EXPECT_THAT(pattern, Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_ID}));
- pattern.pop_front();
+ pattern.pop_back();
EXPECT_THAT(pattern, Eq(spv_operand_pattern_t{}));
}
@@ -72,44 +72,44 @@ TEST_P(MaskExpansionTest, Sample) {
spvOperandTableGet(&operandTable, SPV_ENV_UNIVERSAL_1_0));
spv_operand_pattern_t pattern(GetParam().initial);
- spvPrependOperandTypesForMask(operandTable, GetParam().type, GetParam().mask,
+ spvPushOperandTypesForMask(operandTable, GetParam().type, GetParam().mask,
&pattern);
EXPECT_THAT(pattern, Eq(GetParam().expected));
}
// These macros let us write non-trivial examples without too much text.
-#define SUFFIX0 SPV_OPERAND_TYPE_NONE, SPV_OPERAND_TYPE_ID
-#define SUFFIX1 \
- SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE, \
- SPV_OPERAND_TYPE_STORAGE_CLASS
+#define PREFIX0 SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_NONE
+#define PREFIX1 \
+ SPV_OPERAND_TYPE_STORAGE_CLASS, SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE, \
+ SPV_OPERAND_TYPE_ID
INSTANTIATE_TEST_CASE_P(
OperandPattern, MaskExpansionTest,
::testing::ValuesIn(std::vector<MaskExpansionCase>{
// No bits means no change.
- {SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, 0, {SUFFIX0}, {SUFFIX0}},
+ {SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, 0, {PREFIX0}, {PREFIX0}},
// Unknown bits means no change.
{SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS,
0xfffffffc,
- {SUFFIX1},
- {SUFFIX1}},
+ {PREFIX1},
+ {PREFIX1}},
// Volatile has no operands.
{SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS,
SpvMemoryAccessVolatileMask,
- {SUFFIX0},
- {SUFFIX0}},
+ {PREFIX0},
+ {PREFIX0}},
// Aligned has one literal number operand.
{SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS,
SpvMemoryAccessAlignedMask,
- {SUFFIX1},
- {SPV_OPERAND_TYPE_LITERAL_INTEGER, SUFFIX1}},
+ {PREFIX1},
+ {PREFIX1, SPV_OPERAND_TYPE_LITERAL_INTEGER}},
// Volatile with Aligned still has just one literal number operand.
{SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS,
SpvMemoryAccessVolatileMask | SpvMemoryAccessAlignedMask,
- {SUFFIX1},
- {SPV_OPERAND_TYPE_LITERAL_INTEGER, SUFFIX1}},
+ {PREFIX1},
+ {PREFIX1, SPV_OPERAND_TYPE_LITERAL_INTEGER}},
}), );
-#undef SUFFIX0
-#undef SUFFIX1
+#undef PREFIX0
+#undef PREFIX1
// Returns a vector of all operand types that can be used in a pattern.
std::vector<spv_operand_type_t> allOperandTypes() {
@@ -149,7 +149,7 @@ TEST_P(VariableOperandExpansionTest, NonMatchableOperandsExpand) {
EXPECT_FALSE(pattern.empty());
// For the existing rules, the first expansion of a zero-or-more operand
// type yields a matchable operand type. This isn't strictly necessary.
- EXPECT_FALSE(spvOperandIsVariable(pattern.front()));
+ EXPECT_FALSE(spvOperandIsVariable(pattern.back()));
}
}
@@ -183,8 +183,8 @@ TEST(AlternatePatternFollowingImmediate, SingleElement) {
TEST(AlternatePatternFollowingImmediate, SingleResultId) {
EXPECT_THAT(
spvAlternatePatternFollowingImmediate({SPV_OPERAND_TYPE_RESULT_ID}),
- Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_RESULT_ID,
- SPV_OPERAND_TYPE_OPTIONAL_CIV}));
+ Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV,
+ SPV_OPERAND_TYPE_RESULT_ID}));
}
TEST(AlternatePatternFollowingImmediate, MultipleNonResultIds) {
@@ -199,12 +199,15 @@ TEST(AlternatePatternFollowingImmediate, MultipleNonResultIds) {
TEST(AlternatePatternFollowingImmediate, ResultIdFront) {
EXPECT_THAT(spvAlternatePatternFollowingImmediate(
{SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_ID}),
- Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_RESULT_ID,
+ Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV,
+ SPV_OPERAND_TYPE_RESULT_ID,
SPV_OPERAND_TYPE_OPTIONAL_CIV}));
EXPECT_THAT(spvAlternatePatternFollowingImmediate(
{SPV_OPERAND_TYPE_RESULT_ID,
SPV_OPERAND_TYPE_FP_ROUNDING_MODE, SPV_OPERAND_TYPE_ID}),
- Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_RESULT_ID,
+ Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV,
+ SPV_OPERAND_TYPE_RESULT_ID,
+ SPV_OPERAND_TYPE_OPTIONAL_CIV,
SPV_OPERAND_TYPE_OPTIONAL_CIV}));
EXPECT_THAT(spvAlternatePatternFollowingImmediate(
{SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_DIMENSIONALITY,
@@ -212,7 +215,13 @@ TEST(AlternatePatternFollowingImmediate, ResultIdFront) {
SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE,
SPV_OPERAND_TYPE_FP_ROUNDING_MODE, SPV_OPERAND_TYPE_ID,
SPV_OPERAND_TYPE_VARIABLE_ID}),
- Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_RESULT_ID,
+ Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV,
+ SPV_OPERAND_TYPE_RESULT_ID,
+ SPV_OPERAND_TYPE_OPTIONAL_CIV,
+ SPV_OPERAND_TYPE_OPTIONAL_CIV,
+ SPV_OPERAND_TYPE_OPTIONAL_CIV,
+ SPV_OPERAND_TYPE_OPTIONAL_CIV,
+ SPV_OPERAND_TYPE_OPTIONAL_CIV,
SPV_OPERAND_TYPE_OPTIONAL_CIV}));
}
@@ -230,8 +239,8 @@ TEST(AlternatePatternFollowingImmediate, ResultIdMiddle) {
SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_FP_ROUNDING_MODE,
SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_VARIABLE_ID}),
Eq(spv_operand_pattern_t{
- SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV,
SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_RESULT_ID,
+ SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV,
SPV_OPERAND_TYPE_OPTIONAL_CIV}));
}
@@ -239,14 +248,12 @@ TEST(AlternatePatternFollowingImmediate, ResultIdBack) {
EXPECT_THAT(spvAlternatePatternFollowingImmediate(
{SPV_OPERAND_TYPE_ID, SPV_OPERAND_TYPE_RESULT_ID}),
Eq(spv_operand_pattern_t{SPV_OPERAND_TYPE_OPTIONAL_CIV,
- SPV_OPERAND_TYPE_RESULT_ID,
- SPV_OPERAND_TYPE_OPTIONAL_CIV}));
+ SPV_OPERAND_TYPE_RESULT_ID}));
EXPECT_THAT(spvAlternatePatternFollowingImmediate(
{SPV_OPERAND_TYPE_FP_ROUNDING_MODE, SPV_OPERAND_TYPE_ID,
SPV_OPERAND_TYPE_RESULT_ID}),
Eq(spv_operand_pattern_t{
- SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV,
- SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_OPTIONAL_CIV}));
+ SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_RESULT_ID}));
EXPECT_THAT(
spvAlternatePatternFollowingImmediate(
{SPV_OPERAND_TYPE_DIMENSIONALITY, SPV_OPERAND_TYPE_LINKAGE_TYPE,
@@ -254,10 +261,7 @@ TEST(AlternatePatternFollowingImmediate, ResultIdBack) {
SPV_OPERAND_TYPE_FP_ROUNDING_MODE, SPV_OPERAND_TYPE_ID,
SPV_OPERAND_TYPE_VARIABLE_ID, SPV_OPERAND_TYPE_RESULT_ID}),
Eq(spv_operand_pattern_t{
- SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV,
- SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV,
- SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_OPTIONAL_CIV,
- SPV_OPERAND_TYPE_RESULT_ID, SPV_OPERAND_TYPE_OPTIONAL_CIV}));
+ SPV_OPERAND_TYPE_OPTIONAL_CIV, SPV_OPERAND_TYPE_RESULT_ID}));
}
} // anonymous namespace
diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt
index fcaefe26..75583270 100644
--- a/test/opt/CMakeLists.txt
+++ b/test/opt/CMakeLists.txt
@@ -53,11 +53,26 @@ add_spvtools_unittest(TARGET pass_freeze_spec_const
LIBS SPIRV-Tools-opt
)
+add_spvtools_unittest(TARGET pass_block_merge
+ SRCS block_merge_test.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
add_spvtools_unittest(TARGET pass_inline
SRCS inline_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
)
+add_spvtools_unittest(TARGET pass_insert_extract_elim
+ SRCS insert_extract_elim_test.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
+add_spvtools_unittest(TARGET pass_local_ssa_elim
+ SRCS local_ssa_elim_test.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
add_spvtools_unittest(TARGET pass_local_single_block_elim
SRCS local_single_block_elim.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
@@ -68,6 +83,21 @@ add_spvtools_unittest(TARGET pass_local_access_chain_convert
LIBS SPIRV-Tools-opt
)
+add_spvtools_unittest(TARGET pass_local_single_store_elim
+ SRCS local_single_store_elim_test.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
+add_spvtools_unittest(TARGET pass_dead_branch_elim
+ SRCS dead_branch_elim_test.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
+add_spvtools_unittest(TARGET pass_aggressive_dce
+ SRCS aggressive_dead_code_elim_test.cpp pass_utils.cpp
+ LIBS SPIRV-Tools-opt
+)
+
add_spvtools_unittest(TARGET pass_eliminate_dead_const
SRCS eliminate_dead_const_test.cpp pass_utils.cpp
LIBS SPIRV-Tools-opt
diff --git a/test/opt/aggressive_dead_code_elim_test.cpp b/test/opt/aggressive_dead_code_elim_test.cpp
new file mode 100644
index 00000000..49ed29fc
--- /dev/null
+++ b/test/opt/aggressive_dead_code_elim_test.cpp
@@ -0,0 +1,689 @@
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+using AggressiveDCETest = PassTest<::testing::Test>;
+
+TEST_F(AggressiveDCETest, EliminateExtendedInst) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in vec4 Dead;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // vec4 dv = sqrt(Dead);
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs1 =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %Dead %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %dv "dv"
+OpName %Dead "Dead"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %Dead "Dead"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%Dead = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string func_before =
+ R"(%main = OpFunction %void None %9
+%15 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%dv = OpVariable %_ptr_Function_v4float Function
+%16 = OpLoad %v4float %BaseColor
+OpStore %v %16
+%17 = OpLoad %v4float %Dead
+%18 = OpExtInst %v4float %1 Sqrt %17
+OpStore %dv %18
+%19 = OpLoad %v4float %v
+OpStore %gl_FragColor %19
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string func_after =
+ R"(%main = OpFunction %void None %9
+%15 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%16 = OpLoad %v4float %BaseColor
+OpStore %v %16
+%19 = OpLoad %v4float %v
+OpStore %gl_FragColor %19
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::AggressiveDCEPass>(
+ predefs1 + names_before + predefs2 + func_before,
+ predefs1 + names_after + predefs2 + func_after,
+ true, true);
+}
+
+TEST_F(AggressiveDCETest, NoEliminateFrexp) {
+ // Note: SPIR-V hand-edited to utilize Frexp
+ //
+ // #version 450
+ //
+ // in vec4 BaseColor;
+ // in vec4 Dead;
+ // out vec4 Color;
+ // out ivec4 iv2;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // vec4 dv = frexp(Dead, iv2);
+ // Color = v;
+ // }
+
+ const std::string predefs1 =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %Dead %iv2 %Color
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 450
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %dv "dv"
+OpName %Dead "Dead"
+OpName %iv2 "iv2"
+OpName %ResType "ResType"
+OpName %Color "Color"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %Dead "Dead"
+OpName %iv2 "iv2"
+OpName %ResType "ResType"
+OpName %Color "Color"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%11 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%Dead = OpVariable %_ptr_Input_v4float Input
+%int = OpTypeInt 32 1
+%v4int = OpTypeVector %int 4
+%_ptr_Output_v4int = OpTypePointer Output %v4int
+%iv2 = OpVariable %_ptr_Output_v4int Output
+%ResType = OpTypeStruct %v4float %v4int
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%Color = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string func_before =
+ R"(%main = OpFunction %void None %11
+%20 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%dv = OpVariable %_ptr_Function_v4float Function
+%21 = OpLoad %v4float %BaseColor
+OpStore %v %21
+%22 = OpLoad %v4float %Dead
+%23 = OpExtInst %v4float %1 Frexp %22 %iv2
+OpStore %dv %23
+%24 = OpLoad %v4float %v
+OpStore %Color %24
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string func_after =
+ R"(%main = OpFunction %void None %11
+%20 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%21 = OpLoad %v4float %BaseColor
+OpStore %v %21
+%22 = OpLoad %v4float %Dead
+%23 = OpExtInst %v4float %1 Frexp %22 %iv2
+%24 = OpLoad %v4float %v
+OpStore %Color %24
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::AggressiveDCEPass>(
+ predefs1 + names_before + predefs2 + func_before,
+ predefs1 + names_after + predefs2 + func_after,
+ true, true);
+}
+
+TEST_F(AggressiveDCETest, EliminateDecorate) {
+ // Note: The SPIR-V was hand-edited to add the OpDecorate
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in vec4 Dead;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // vec4 dv = Dead * 0.5;
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs1 =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %Dead %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %dv "dv"
+OpName %Dead "Dead"
+OpName %gl_FragColor "gl_FragColor"
+OpDecorate %8 RelaxedPrecision
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %Dead "Dead"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%10 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%Dead = OpVariable %_ptr_Input_v4float Input
+%float_0_5 = OpConstant %float 0.5
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string func_before =
+ R"(%main = OpFunction %void None %10
+%17 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%dv = OpVariable %_ptr_Function_v4float Function
+%18 = OpLoad %v4float %BaseColor
+OpStore %v %18
+%19 = OpLoad %v4float %Dead
+%8 = OpVectorTimesScalar %v4float %19 %float_0_5
+OpStore %dv %8
+%20 = OpLoad %v4float %v
+OpStore %gl_FragColor %20
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string func_after =
+ R"(%main = OpFunction %void None %10
+%17 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%18 = OpLoad %v4float %BaseColor
+OpStore %v %18
+%20 = OpLoad %v4float %v
+OpStore %gl_FragColor %20
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::AggressiveDCEPass>(
+ predefs1 + names_before + predefs2 + func_before,
+ predefs1 + names_after + predefs2 + func_after,
+ true, true);
+}
+
+TEST_F(AggressiveDCETest, Simple) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in vec4 Dead;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // vec4 dv = Dead;
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs1 =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %Dead %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %dv "dv"
+OpName %Dead "Dead"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %Dead "Dead"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%Dead = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string func_before =
+ R"(%main = OpFunction %void None %9
+%15 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%dv = OpVariable %_ptr_Function_v4float Function
+%16 = OpLoad %v4float %BaseColor
+OpStore %v %16
+%17 = OpLoad %v4float %Dead
+OpStore %dv %17
+%18 = OpLoad %v4float %v
+OpStore %gl_FragColor %18
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string func_after =
+ R"(%main = OpFunction %void None %9
+%15 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%16 = OpLoad %v4float %BaseColor
+OpStore %v %16
+%18 = OpLoad %v4float %v
+OpStore %gl_FragColor %18
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::AggressiveDCEPass>(
+ predefs1 + names_before + predefs2 + func_before,
+ predefs1 + names_after + predefs2 + func_after,
+ true, true);
+}
+
+TEST_F(AggressiveDCETest, DeadCycle) {
+ // #version 140
+ // in vec4 BaseColor;
+ //
+ // layout(std140) uniform U_t
+ // {
+ // int g_I ;
+ // } ;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // float df = 0.0;
+ // int i = 0;
+ // while (i < g_I) {
+ // df = df * 0.5;
+ // i = i + 1;
+ // }
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs1 =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %df "df"
+OpName %i "i"
+OpName %U_t "U_t"
+OpMemberName %U_t 0 "g_I"
+OpName %_ ""
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %i "i"
+OpName %U_t "U_t"
+OpMemberName %U_t 0 "g_I"
+OpName %_ ""
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string predefs2 =
+ R"(OpMemberDecorate %U_t 0 Offset 0
+OpDecorate %U_t Block
+OpDecorate %_ DescriptorSet 0
+%void = OpTypeVoid
+%11 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_float = OpTypePointer Function %float
+%float_0 = OpConstant %float 0
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_0 = OpConstant %int 0
+%U_t = OpTypeStruct %int
+%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t
+%_ = OpVariable %_ptr_Uniform_U_t Uniform
+%_ptr_Uniform_int = OpTypePointer Uniform %int
+%bool = OpTypeBool
+%float_0_5 = OpConstant %float 0.5
+%int_1 = OpConstant %int 1
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string func_before =
+ R"(%main = OpFunction %void None %11
+%27 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%df = OpVariable %_ptr_Function_float Function
+%i = OpVariable %_ptr_Function_int Function
+%28 = OpLoad %v4float %BaseColor
+OpStore %v %28
+OpStore %df %float_0
+OpStore %i %int_0
+OpBranch %29
+%29 = OpLabel
+OpLoopMerge %30 %31 None
+OpBranch %32
+%32 = OpLabel
+%33 = OpLoad %int %i
+%34 = OpAccessChain %_ptr_Uniform_int %_ %int_0
+%35 = OpLoad %int %34
+%36 = OpSLessThan %bool %33 %35
+OpBranchConditional %36 %37 %30
+%37 = OpLabel
+%38 = OpLoad %float %df
+%39 = OpFMul %float %38 %float_0_5
+OpStore %df %39
+%40 = OpLoad %int %i
+%41 = OpIAdd %int %40 %int_1
+OpStore %i %41
+OpBranch %31
+%31 = OpLabel
+OpBranch %29
+%30 = OpLabel
+%42 = OpLoad %v4float %v
+OpStore %gl_FragColor %42
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string func_after =
+ R"(%main = OpFunction %void None %11
+%27 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%i = OpVariable %_ptr_Function_int Function
+%28 = OpLoad %v4float %BaseColor
+OpStore %v %28
+OpStore %i %int_0
+OpBranch %29
+%29 = OpLabel
+OpLoopMerge %30 %31 None
+OpBranch %32
+%32 = OpLabel
+%33 = OpLoad %int %i
+%34 = OpAccessChain %_ptr_Uniform_int %_ %int_0
+%35 = OpLoad %int %34
+%36 = OpSLessThan %bool %33 %35
+OpBranchConditional %36 %37 %30
+%37 = OpLabel
+%40 = OpLoad %int %i
+%41 = OpIAdd %int %40 %int_1
+OpStore %i %41
+OpBranch %31
+%31 = OpLabel
+OpBranch %29
+%30 = OpLabel
+%42 = OpLoad %v4float %v
+OpStore %gl_FragColor %42
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::AggressiveDCEPass>(
+ predefs1 + names_before + predefs2 + func_before,
+ predefs1 + names_after + predefs2 + func_after,
+ true, true);
+}
+
+TEST_F(AggressiveDCETest, OptWhitelistExtension) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in vec4 Dead;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // vec4 dv = Dead;
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs1 =
+ R"(OpCapability Shader
+OpExtension "SPV_AMD_gpu_shader_int16"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %Dead %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %dv "dv"
+OpName %Dead "Dead"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %Dead "Dead"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%Dead = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string func_before =
+ R"(%main = OpFunction %void None %9
+%15 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%dv = OpVariable %_ptr_Function_v4float Function
+%16 = OpLoad %v4float %BaseColor
+OpStore %v %16
+%17 = OpLoad %v4float %Dead
+OpStore %dv %17
+%18 = OpLoad %v4float %v
+OpStore %gl_FragColor %18
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string func_after =
+ R"(%main = OpFunction %void None %9
+%15 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%16 = OpLoad %v4float %BaseColor
+OpStore %v %16
+%18 = OpLoad %v4float %v
+OpStore %gl_FragColor %18
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::AggressiveDCEPass>(
+ predefs1 + names_before + predefs2 + func_before,
+ predefs1 + names_after + predefs2 + func_after,
+ true, true);
+}
+
+TEST_F(AggressiveDCETest, NoOptBlacklistExtension) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in vec4 Dead;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // vec4 dv = Dead;
+ // gl_FragColor = v;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+OpExtension "SPV_KHR_variable_pointers"
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %Dead %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %dv "dv"
+OpName %Dead "Dead"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%Dead = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%main = OpFunction %void None %9
+%15 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%dv = OpVariable %_ptr_Function_v4float Function
+%16 = OpLoad %v4float %BaseColor
+OpStore %v %16
+%17 = OpLoad %v4float %Dead
+OpStore %dv %17
+%18 = OpLoad %v4float %v
+OpStore %gl_FragColor %18
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::AggressiveDCEPass>(
+ assembly, assembly, true, true);
+}
+
+// TODO(greg-lunarg): Add tests to verify handling of these cases:
+//
+// Check that logical addressing required
+// Check that function calls inhibit optimization
+// Others?
+
+} // anonymous namespace
diff --git a/test/opt/block_merge_test.cpp b/test/opt/block_merge_test.cpp
new file mode 100644
index 00000000..2133feee
--- /dev/null
+++ b/test/opt/block_merge_test.cpp
@@ -0,0 +1,337 @@
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+using BlockMergeTest = PassTest<::testing::Test>;
+
+TEST_F(BlockMergeTest, Simple) {
+ // Note: SPIR-V hand edited to insert block boundary
+ // between two statements in main.
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %7
+%13 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%14 = OpLoad %v4float %BaseColor
+OpStore %v %14
+OpBranch %15
+%15 = OpLabel
+%16 = OpLoad %v4float %v
+OpStore %gl_FragColor %16
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %7
+%13 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%14 = OpLoad %v4float %BaseColor
+OpStore %v %14
+%16 = OpLoad %v4float %v
+OpStore %gl_FragColor %16
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::BlockMergePass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(BlockMergeTest, EmptyBlock) {
+ // Note: SPIR-V hand edited to insert empty block
+ // after two statements in main.
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %7
+%13 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%14 = OpLoad %v4float %BaseColor
+OpStore %v %14
+OpBranch %15
+%15 = OpLabel
+%16 = OpLoad %v4float %v
+OpStore %gl_FragColor %16
+OpBranch %17
+%17 = OpLabel
+OpBranch %18
+%18 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %7
+%13 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%14 = OpLoad %v4float %BaseColor
+OpStore %v %14
+%16 = OpLoad %v4float %v
+OpStore %gl_FragColor %16
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::BlockMergePass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(BlockMergeTest, NoOptOfMergeOrContinueBlock) {
+ // Note: SPIR-V hand edited remove dead branch and add block
+ // before continue block
+ //
+ // #version 140
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // while (true) {
+ // break;
+ // }
+ // gl_FragColor = BaseColor;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %gl_FragColor "gl_FragColor"
+OpName %BaseColor "BaseColor"
+%void = OpTypeVoid
+%6 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%main = OpFunction %void None %6
+%13 = OpLabel
+OpBranch %14
+%14 = OpLabel
+OpLoopMerge %15 %16 None
+OpBranch %17
+%17 = OpLabel
+OpBranch %15
+%18 = OpLabel
+OpBranch %16
+%16 = OpLabel
+OpBranch %14
+%15 = OpLabel
+%19 = OpLoad %v4float %BaseColor
+OpStore %gl_FragColor %19
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::BlockMergePass>(
+ assembly, assembly, true, true);
+}
+
+TEST_F(BlockMergeTest, NestedInControlFlow) {
+ // Note: SPIR-V hand edited to insert block boundary
+ // between OpFMul and OpStore in then-part.
+ //
+ // #version 140
+ // in vec4 BaseColor;
+ //
+ // layout(std140) uniform U_t
+ // {
+ // bool g_B ;
+ // } ;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // if (g_B)
+ // vec4 v = v * 0.25;
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %U_t "U_t"
+OpMemberName %U_t 0 "g_B"
+OpName %_ ""
+OpName %v_0 "v"
+OpName %gl_FragColor "gl_FragColor"
+OpMemberDecorate %U_t 0 Offset 0
+OpDecorate %U_t Block
+OpDecorate %_ DescriptorSet 0
+%void = OpTypeVoid
+%10 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%uint = OpTypeInt 32 0
+%U_t = OpTypeStruct %uint
+%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t
+%_ = OpVariable %_ptr_Uniform_U_t Uniform
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%_ptr_Uniform_uint = OpTypePointer Uniform %uint
+%bool = OpTypeBool
+%uint_0 = OpConstant %uint 0
+%float_0_25 = OpConstant %float 0.25
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %10
+%24 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%v_0 = OpVariable %_ptr_Function_v4float Function
+%25 = OpLoad %v4float %BaseColor
+OpStore %v %25
+%26 = OpAccessChain %_ptr_Uniform_uint %_ %int_0
+%27 = OpLoad %uint %26
+%28 = OpINotEqual %bool %27 %uint_0
+OpSelectionMerge %29 None
+OpBranchConditional %28 %30 %29
+%30 = OpLabel
+%31 = OpLoad %v4float %v
+%32 = OpVectorTimesScalar %v4float %31 %float_0_25
+OpBranch %33
+%33 = OpLabel
+OpStore %v_0 %32
+OpBranch %29
+%29 = OpLabel
+%34 = OpLoad %v4float %v
+OpStore %gl_FragColor %34
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %10
+%24 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%v_0 = OpVariable %_ptr_Function_v4float Function
+%25 = OpLoad %v4float %BaseColor
+OpStore %v %25
+%26 = OpAccessChain %_ptr_Uniform_uint %_ %int_0
+%27 = OpLoad %uint %26
+%28 = OpINotEqual %bool %27 %uint_0
+OpSelectionMerge %29 None
+OpBranchConditional %28 %30 %29
+%30 = OpLabel
+%31 = OpLoad %v4float %v
+%32 = OpVectorTimesScalar %v4float %31 %float_0_25
+OpStore %v_0 %32
+OpBranch %29
+%29 = OpLabel
+%34 = OpLoad %v4float %v
+OpStore %gl_FragColor %34
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::BlockMergePass>(
+ predefs + before, predefs + after, true, true);
+}
+
+// TODO(greg-lunarg): Add tests to verify handling of these cases:
+//
+// More complex control flow
+// Others?
+
+} // anonymous namespace
diff --git a/test/opt/dead_branch_elim_test.cpp b/test/opt/dead_branch_elim_test.cpp
new file mode 100644
index 00000000..46958fb2
--- /dev/null
+++ b/test/opt/dead_branch_elim_test.cpp
@@ -0,0 +1,905 @@
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+using DeadBranchElimTest = PassTest<::testing::Test>;
+
+TEST_F(DeadBranchElimTest, IfThenElseTrue) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v;
+ // if (true)
+ // v = vec4(0.0,0.0,0.0,0.0);
+ // else
+ // v = vec4(1.0,1.0,1.0,1.0);
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %gl_FragColor "gl_FragColor"
+OpName %BaseColor "BaseColor"
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%float_0 = OpConstant %float 0
+%14 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%float_1 = OpConstant %float 1
+%16 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %7
+%19 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+OpSelectionMerge %20 None
+OpBranchConditional %true %21 %22
+%21 = OpLabel
+OpStore %v %14
+OpBranch %20
+%22 = OpLabel
+OpStore %v %16
+OpBranch %20
+%20 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %7
+%19 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+OpBranch %21
+%21 = OpLabel
+OpStore %v %14
+OpBranch %20
+%20 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::DeadBranchElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(DeadBranchElimTest, IfThenElseFalse) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v;
+ // if (false)
+ // v = vec4(0.0,0.0,0.0,0.0);
+ // else
+ // v = vec4(1.0,1.0,1.0,1.0);
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %gl_FragColor "gl_FragColor"
+OpName %BaseColor "BaseColor"
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%bool = OpTypeBool
+%false = OpConstantFalse %bool
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%float_0 = OpConstant %float 0
+%14 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%float_1 = OpConstant %float 1
+%16 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %7
+%19 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+OpSelectionMerge %20 None
+OpBranchConditional %false %21 %22
+%21 = OpLabel
+OpStore %v %14
+OpBranch %20
+%22 = OpLabel
+OpStore %v %16
+OpBranch %20
+%20 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %7
+%19 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+OpBranch %22
+%22 = OpLabel
+OpStore %v %16
+OpBranch %20
+%20 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::DeadBranchElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(DeadBranchElimTest, IfThenTrue) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // if (true)
+ // v = v * vec4(0.5,0.5,0.5,0.5);
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%float_0_5 = OpConstant %float 0.5
+%15 = OpConstantComposite %v4float %float_0_5 %float_0_5 %float_0_5 %float_0_5
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %7
+%17 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%18 = OpLoad %v4float %BaseColor
+OpStore %v %18
+OpSelectionMerge %19 None
+OpBranchConditional %true %20 %19
+%20 = OpLabel
+%21 = OpLoad %v4float %v
+%22 = OpFMul %v4float %21 %15
+OpStore %v %22
+OpBranch %19
+%19 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %7
+%17 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%18 = OpLoad %v4float %BaseColor
+OpStore %v %18
+OpBranch %20
+%20 = OpLabel
+%21 = OpLoad %v4float %v
+%22 = OpFMul %v4float %21 %15
+OpStore %v %22
+OpBranch %19
+%19 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::DeadBranchElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(DeadBranchElimTest, IfThenFalse) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // if (false)
+ // v = v * vec4(0.5,0.5,0.5,0.5);
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%bool = OpTypeBool
+%false = OpConstantFalse %bool
+%float_0_5 = OpConstant %float 0.5
+%15 = OpConstantComposite %v4float %float_0_5 %float_0_5 %float_0_5 %float_0_5
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %7
+%17 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%18 = OpLoad %v4float %BaseColor
+OpStore %v %18
+OpSelectionMerge %19 None
+OpBranchConditional %false %20 %19
+%20 = OpLabel
+%21 = OpLoad %v4float %v
+%22 = OpFMul %v4float %21 %15
+OpStore %v %22
+OpBranch %19
+%19 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %7
+%17 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%18 = OpLoad %v4float %BaseColor
+OpStore %v %18
+OpBranch %19
+%19 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::DeadBranchElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(DeadBranchElimTest, IfThenElsePhiTrue) {
+ // Test handling of phi in merge block after dead branch elimination.
+ // Note: The SPIR-V has had store/load elimination and phi insertion
+ //
+ // #version 140
+ //
+ // void main()
+ // {
+ // vec4 v;
+ // if (true)
+ // v = vec4(0.0,0.0,0.0,0.0);
+ // else
+ // v = vec4(1.0,1.0,1.0,1.0);
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%5 = OpTypeFunction %void
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%float_0 = OpConstant %float 0
+%12 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%float_1 = OpConstant %float 1
+%14 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %5
+%17 = OpLabel
+OpSelectionMerge %18 None
+OpBranchConditional %true %19 %20
+%19 = OpLabel
+OpBranch %18
+%20 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%21 = OpPhi %v4float %12 %19 %14 %20
+OpStore %gl_FragColor %21
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %5
+%17 = OpLabel
+OpBranch %19
+%19 = OpLabel
+OpBranch %18
+%18 = OpLabel
+OpStore %gl_FragColor %12
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::DeadBranchElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(DeadBranchElimTest, IfThenElsePhiFalse) {
+ // Test handling of phi in merge block after dead branch elimination.
+ // Note: The SPIR-V has had store/load elimination and phi insertion
+ //
+ // #version 140
+ //
+ // void main()
+ // {
+ // vec4 v;
+ // if (true)
+ // v = vec4(0.0,0.0,0.0,0.0);
+ // else
+ // v = vec4(1.0,1.0,1.0,1.0);
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%5 = OpTypeFunction %void
+%bool = OpTypeBool
+%false = OpConstantFalse %bool
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%float_0 = OpConstant %float 0
+%12 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%float_1 = OpConstant %float 1
+%14 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %5
+%17 = OpLabel
+OpSelectionMerge %18 None
+OpBranchConditional %false %19 %20
+%19 = OpLabel
+OpBranch %18
+%20 = OpLabel
+OpBranch %18
+%18 = OpLabel
+%21 = OpPhi %v4float %12 %19 %14 %20
+OpStore %gl_FragColor %21
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %5
+%17 = OpLabel
+OpBranch %20
+%20 = OpLabel
+OpBranch %18
+%18 = OpLabel
+OpStore %gl_FragColor %14
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::DeadBranchElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(DeadBranchElimTest, CompoundIfThenElseFalse) {
+ // #version 140
+ //
+ // layout(std140) uniform U_t
+ // {
+ // bool g_B ;
+ // } ;
+ //
+ // void main()
+ // {
+ // vec4 v;
+ // if (false) {
+ // if (g_B)
+ // v = vec4(0.0,0.0,0.0,0.0);
+ // else
+ // v = vec4(1.0,1.0,1.0,1.0);
+ // } else {
+ // if (g_B)
+ // v = vec4(1.0,1.0,1.0,1.0);
+ // else
+ // v = vec4(0.0,0.0,0.0,0.0);
+ // }
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %U_t "U_t"
+OpMemberName %U_t 0 "g_B"
+OpName %_ ""
+OpName %v "v"
+OpName %gl_FragColor "gl_FragColor"
+OpMemberDecorate %U_t 0 Offset 0
+OpDecorate %U_t Block
+OpDecorate %_ DescriptorSet 0
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%bool = OpTypeBool
+%false = OpConstantFalse %bool
+%uint = OpTypeInt 32 0
+%U_t = OpTypeStruct %uint
+%_ptr_Uniform_U_t = OpTypePointer Uniform %U_t
+%_ = OpVariable %_ptr_Uniform_U_t Uniform
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%_ptr_Uniform_uint = OpTypePointer Uniform %uint
+%uint_0 = OpConstant %uint 0
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%float_0 = OpConstant %float 0
+%21 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%float_1 = OpConstant %float 1
+%23 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%25 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+OpSelectionMerge %26 None
+OpBranchConditional %false %27 %28
+%27 = OpLabel
+%29 = OpAccessChain %_ptr_Uniform_uint %_ %int_0
+%30 = OpLoad %uint %29
+%31 = OpINotEqual %bool %30 %uint_0
+OpSelectionMerge %32 None
+OpBranchConditional %31 %33 %34
+%33 = OpLabel
+OpStore %v %21
+OpBranch %32
+%34 = OpLabel
+OpStore %v %23
+OpBranch %32
+%32 = OpLabel
+OpBranch %26
+%28 = OpLabel
+%35 = OpAccessChain %_ptr_Uniform_uint %_ %int_0
+%36 = OpLoad %uint %35
+%37 = OpINotEqual %bool %36 %uint_0
+OpSelectionMerge %38 None
+OpBranchConditional %37 %39 %40
+%39 = OpLabel
+OpStore %v %23
+OpBranch %38
+%40 = OpLabel
+OpStore %v %21
+OpBranch %38
+%38 = OpLabel
+OpBranch %26
+%26 = OpLabel
+%41 = OpLoad %v4float %v
+OpStore %gl_FragColor %41
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%25 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+OpBranch %28
+%28 = OpLabel
+%35 = OpAccessChain %_ptr_Uniform_uint %_ %int_0
+%36 = OpLoad %uint %35
+%37 = OpINotEqual %bool %36 %uint_0
+OpSelectionMerge %38 None
+OpBranchConditional %37 %39 %40
+%39 = OpLabel
+OpStore %v %23
+OpBranch %38
+%40 = OpLabel
+OpStore %v %21
+OpBranch %38
+%38 = OpLabel
+OpBranch %26
+%26 = OpLabel
+%41 = OpLoad %v4float %v
+OpStore %gl_FragColor %41
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::DeadBranchElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(DeadBranchElimTest, NoOrphanMerge) {
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%bool = OpTypeBool
+%true = OpConstantTrue %bool
+%float_0_5 = OpConstant %float 0.5
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %7
+%16 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%17 = OpLoad %v4float %BaseColor
+OpStore %v %17
+OpSelectionMerge %18 None
+OpBranchConditional %true %19 %20
+%19 = OpLabel
+OpKill
+%20 = OpLabel
+%21 = OpLoad %v4float %v
+%22 = OpVectorTimesScalar %v4float %21 %float_0_5
+OpStore %v %22
+OpBranch %18
+%18 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %7
+%16 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%17 = OpLoad %v4float %BaseColor
+OpStore %v %17
+OpSelectionMerge %18 None
+OpBranchConditional %true %19 %18
+%19 = OpLabel
+OpKill
+%18 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::DeadBranchElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(DeadBranchElimTest, KeepContinueTargetWhenKillAfterMerge) {
+ // #version 450
+ // void main() {
+ // bool c;
+ // bool d;
+ // while(c) {
+ // if(d) {
+ // continue;
+ // }
+ // if(false) {
+ // continue;
+ // }
+ // discard;
+ // }
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main"
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 450
+OpName %main "main"
+OpName %c "c"
+OpName %d "d"
+%void = OpTypeVoid
+%6 = OpTypeFunction %void
+%bool = OpTypeBool
+%_ptr_Function_bool = OpTypePointer Function %bool
+%false = OpConstantFalse %bool
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %6
+%10 = OpLabel
+%c = OpVariable %_ptr_Function_bool Function
+%d = OpVariable %_ptr_Function_bool Function
+OpBranch %11
+%11 = OpLabel
+OpLoopMerge %12 %13 None
+OpBranch %14
+%14 = OpLabel
+%15 = OpLoad %bool %c
+OpBranchConditional %15 %16 %12
+%16 = OpLabel
+%17 = OpLoad %bool %d
+OpSelectionMerge %18 None
+OpBranchConditional %17 %19 %18
+%19 = OpLabel
+OpBranch %13
+%18 = OpLabel
+OpSelectionMerge %20 None
+OpBranchConditional %false %21 %20
+%21 = OpLabel
+OpBranch %13
+%20 = OpLabel
+OpKill
+%13 = OpLabel
+OpBranch %11
+%12 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %6
+%10 = OpLabel
+%c = OpVariable %_ptr_Function_bool Function
+%d = OpVariable %_ptr_Function_bool Function
+OpBranch %11
+%11 = OpLabel
+OpLoopMerge %12 %13 None
+OpBranch %14
+%14 = OpLabel
+%15 = OpLoad %bool %c
+OpBranchConditional %15 %16 %12
+%16 = OpLabel
+%17 = OpLoad %bool %d
+OpSelectionMerge %18 None
+OpBranchConditional %17 %19 %18
+%19 = OpLabel
+OpBranch %13
+%18 = OpLabel
+OpBranch %20
+%20 = OpLabel
+OpKill
+%13 = OpLabel
+OpBranch %11
+%12 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::DeadBranchElimPass>(
+ predefs + before, predefs + after, true, true);
+}
+
+TEST_F(DeadBranchElimTest, DecorateDeleted) {
+ // Note: SPIR-V hand-edited to add decoration
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // if (false)
+ // v = v * vec4(0.5,0.5,0.5,0.5);
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs_before =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+OpDecorate %22 RelaxedPrecision
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%bool = OpTypeBool
+%false = OpConstantFalse %bool
+%float_0_5 = OpConstant %float 0.5
+%15 = OpConstantComposite %v4float %float_0_5 %float_0_5 %float_0_5 %float_0_5
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string predefs_after =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%bool = OpTypeBool
+%false = OpConstantFalse %bool
+%float_0_5 = OpConstant %float 0.5
+%16 = OpConstantComposite %v4float %float_0_5 %float_0_5 %float_0_5 %float_0_5
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %7
+%17 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%18 = OpLoad %v4float %BaseColor
+OpStore %v %18
+OpSelectionMerge %19 None
+OpBranchConditional %false %20 %19
+%20 = OpLabel
+%21 = OpLoad %v4float %v
+%22 = OpFMul %v4float %21 %15
+OpStore %v %22
+OpBranch %19
+%19 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%18 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%19 = OpLoad %v4float %BaseColor
+OpStore %v %19
+OpBranch %20
+%20 = OpLabel
+%23 = OpLoad %v4float %v
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::DeadBranchElimPass>(
+ predefs_before + before, predefs_after + after, true, true);
+}
+
+// TODO(greg-lunarg): Add tests to verify handling of these cases:
+//
+// More complex control flow
+// Others?
+
+} // anonymous namespace
diff --git a/test/opt/insert_extract_elim_test.cpp b/test/opt/insert_extract_elim_test.cpp
new file mode 100644
index 00000000..18b31bb9
--- /dev/null
+++ b/test/opt/insert_extract_elim_test.cpp
@@ -0,0 +1,334 @@
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+using InsertExtractElimTest = PassTest<::testing::Test>;
+
+TEST_F(InsertExtractElimTest, Simple) {
+ // Note: The SPIR-V assembly has had store/load elimination
+ // performed to allow the inserts and extracts to directly
+ // reference each other.
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // struct S_t {
+ // vec4 v0;
+ // vec4 v1;
+ // };
+ //
+ // void main()
+ // {
+ // S_t s0;
+ // s0.v1 = BaseColor;
+ // gl_FragColor = s0.v1;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %S_t "S_t"
+OpMemberName %S_t 0 "v0"
+OpMemberName %S_t 1 "v1"
+OpName %s0 "s0"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%S_t = OpTypeStruct %v4float %v4float
+%_ptr_Function_S_t = OpTypePointer Function %S_t
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%17 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%18 = OpLoad %v4float %BaseColor
+%19 = OpLoad %S_t %s0
+%20 = OpCompositeInsert %S_t %18 %19 1
+OpStore %s0 %20
+%21 = OpCompositeExtract %v4float %20 1
+OpStore %gl_FragColor %21
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%17 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%18 = OpLoad %v4float %BaseColor
+%19 = OpLoad %S_t %s0
+%20 = OpCompositeInsert %S_t %18 %19 1
+OpStore %s0 %20
+OpStore %gl_FragColor %18
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::InsertExtractElimPass>(predefs + before,
+ predefs + after, true, true);
+}
+
+TEST_F(InsertExtractElimTest, OptimizeAcrossNonConflictingInsert) {
+ // Note: The SPIR-V assembly has had store/load elimination
+ // performed to allow the inserts and extracts to directly
+ // reference each other.
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // struct S_t {
+ // vec4 v0;
+ // vec4 v1;
+ // };
+ //
+ // void main()
+ // {
+ // S_t s0;
+ // s0.v1 = BaseColor;
+ // s0.v0[2] = 0.0;
+ // gl_FragColor = s0.v1;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %S_t "S_t"
+OpMemberName %S_t 0 "v0"
+OpMemberName %S_t 1 "v1"
+OpName %s0 "s0"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%S_t = OpTypeStruct %v4float %v4float
+%_ptr_Function_S_t = OpTypePointer Function %S_t
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%float_0 = OpConstant %float 0
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%18 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%19 = OpLoad %v4float %BaseColor
+%20 = OpLoad %S_t %s0
+%21 = OpCompositeInsert %S_t %19 %20 1
+%22 = OpCompositeInsert %S_t %float_0 %21 0 2
+OpStore %s0 %22
+%23 = OpCompositeExtract %v4float %22 1
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%18 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%19 = OpLoad %v4float %BaseColor
+%20 = OpLoad %S_t %s0
+%21 = OpCompositeInsert %S_t %19 %20 1
+%22 = OpCompositeInsert %S_t %float_0 %21 0 2
+OpStore %s0 %22
+OpStore %gl_FragColor %19
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::InsertExtractElimPass>(predefs + before,
+ predefs + after, true, true);
+}
+
+TEST_F(InsertExtractElimTest, ConflictingInsertPreventsOptimization) {
+ // Note: The SPIR-V assembly has had store/load elimination
+ // performed to allow the inserts and extracts to directly
+ // reference each other.
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // struct S_t {
+ // vec4 v0;
+ // vec4 v1;
+ // };
+ //
+ // void main()
+ // {
+ // S_t s0;
+ // s0.v1 = BaseColor;
+ // s0.v1[2] = 0.0;
+ // gl_FragColor = s0.v1;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %S_t "S_t"
+OpMemberName %S_t 0 "v0"
+OpMemberName %S_t 1 "v1"
+OpName %s0 "s0"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%S_t = OpTypeStruct %v4float %v4float
+%_ptr_Function_S_t = OpTypePointer Function %S_t
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%float_0 = OpConstant %float 0
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%main = OpFunction %void None %8
+%18 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%19 = OpLoad %v4float %BaseColor
+%20 = OpLoad %S_t %s0
+%21 = OpCompositeInsert %S_t %19 %20 1
+%22 = OpCompositeInsert %S_t %float_0 %21 1 2
+OpStore %s0 %22
+%23 = OpCompositeExtract %v4float %22 1
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::InsertExtractElimPass>(assembly,
+ assembly, true, true);
+}
+
+TEST_F(InsertExtractElimTest, ConflictingInsertPreventsOptimization2) {
+ // Note: The SPIR-V assembly has had store/load elimination
+ // performed to allow the inserts and extracts to directly
+ // reference each other.
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // struct S_t {
+ // vec4 v0;
+ // vec4 v1;
+ // };
+ //
+ // void main()
+ // {
+ // S_t s0;
+ // s0.v1[1] = 1.0;
+ // s0.v1 = BaseColor;
+ // gl_FragColor = vec4(s0.v1[1], 0.0, 0.0, 0.0);
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %S_t "S_t"
+OpMemberName %S_t 0 "v0"
+OpMemberName %S_t 1 "v1"
+OpName %s0 "s0"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%S_t = OpTypeStruct %v4float %v4float
+%_ptr_Function_S_t = OpTypePointer Function %S_t
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%float_1 = OpConstant %float 1
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%float_0 = OpConstant %float 0
+%main = OpFunction %void None %8
+%22 = OpLabel
+%s0 = OpVariable %_ptr_Function_S_t Function
+%23 = OpLoad %S_t %s0
+%24 = OpCompositeInsert %S_t %float_1 %23 1 1
+%25 = OpLoad %v4float %BaseColor
+%26 = OpCompositeInsert %S_t %25 %24 1
+%27 = OpCompositeExtract %float %26 1 1
+%28 = OpCompositeConstruct %v4float %27 %float_0 %float_0 %float_0
+OpStore %gl_FragColor %28
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::InsertExtractElimPass>(assembly,
+ assembly, true, true);
+}
+
+// TODO(greg-lunarg): Add tests to verify handling of these cases:
+//
+
+} // anonymous namespace
diff --git a/test/opt/instruction_test.cpp b/test/opt/instruction_test.cpp
index 608229f4..2db4ed2c 100644
--- a/test/opt/instruction_test.cpp
+++ b/test/opt/instruction_test.cpp
@@ -69,6 +69,55 @@ spv_parsed_instruction_t kSampleParsedInstruction = {kSampleInstructionWords,
44, // result id
kSampleParsedOperands,
3};
+
+// The words for an OpAccessChain instruction.
+uint32_t kSampleAccessChainInstructionWords[] = {
+ (7 << 16) | uint32_t(SpvOpAccessChain), 100, 101, 102, 103, 104, 105};
+
+// The operands that would be parsed from kSampleAccessChainInstructionWords.
+spv_parsed_operand_t kSampleAccessChainOperands[] = {
+ {1, 1, SPV_OPERAND_TYPE_RESULT_ID, SPV_NUMBER_NONE, 0},
+ {2, 1, SPV_OPERAND_TYPE_TYPE_ID, SPV_NUMBER_NONE, 0},
+ {3, 1, SPV_OPERAND_TYPE_ID, SPV_NUMBER_NONE, 0},
+ {4, 1, SPV_OPERAND_TYPE_ID, SPV_NUMBER_NONE, 0},
+ {5, 1, SPV_OPERAND_TYPE_ID, SPV_NUMBER_NONE, 0},
+ {6, 1, SPV_OPERAND_TYPE_ID, SPV_NUMBER_NONE, 0},
+};
+
+// A valid parse of kSampleAccessChainInstructionWords
+spv_parsed_instruction_t kSampleAccessChainInstruction = {
+ kSampleAccessChainInstructionWords,
+ uint16_t(7),
+ uint16_t(SpvOpAccessChain),
+ SPV_EXT_INST_TYPE_NONE,
+ 100, // type id
+ 101, // result id
+ kSampleAccessChainOperands,
+ 6};
+
+// The words for an OpControlBarrier instruction.
+uint32_t kSampleControlBarrierInstructionWords[] = {
+ (4 << 16) | uint32_t(SpvOpControlBarrier), 100, 101, 102};
+
+// The operands that would be parsed from kSampleControlBarrierInstructionWords.
+spv_parsed_operand_t kSampleControlBarrierOperands[] = {
+ {1, 1, SPV_OPERAND_TYPE_SCOPE_ID, SPV_NUMBER_NONE, 0}, // Execution
+ {2, 1, SPV_OPERAND_TYPE_SCOPE_ID, SPV_NUMBER_NONE, 0}, // Memory
+ {3, 1, SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID, SPV_NUMBER_NONE,
+ 0}, // Semantics
+};
+
+// A valid parse of kSampleControlBarrierInstructionWords
+spv_parsed_instruction_t kSampleControlBarrierInstruction = {
+ kSampleControlBarrierInstructionWords,
+ uint16_t(4),
+ uint16_t(SpvOpControlBarrier),
+ SPV_EXT_INST_TYPE_NONE,
+ 0, // type id
+ 0, // result id
+ kSampleControlBarrierOperands,
+ 3};
+
TEST(InstructionTest, CreateWithOpcodeAndOperands) {
Instruction inst(kSampleParsedInstruction);
EXPECT_EQ(SpvOpTypeInt, inst.opcode());
@@ -148,4 +197,28 @@ TEST(InstructionTest, OperandIterators) {
EXPECT_EQ(SPV_OPERAND_TYPE_TYPE_ID, (*(inst.cbegin() + 2)).type);
}
+TEST(InstructionTest, ForInIdStandardIdTypes) {
+ Instruction inst(kSampleAccessChainInstruction);
+
+ std::vector<uint32_t> ids;
+ inst.ForEachInId([&ids](const uint32_t* idptr) { ids.push_back(*idptr); });
+ EXPECT_THAT(ids, Eq(std::vector<uint32_t>{102, 103, 104, 105}));
+
+ ids.clear();
+ inst.ForEachInId([&ids](uint32_t* idptr) { ids.push_back(*idptr); });
+ EXPECT_THAT(ids, Eq(std::vector<uint32_t>{102, 103, 104, 105}));
+}
+
+TEST(InstructionTest, ForInIdNonstandardIdTypes) {
+ Instruction inst(kSampleControlBarrierInstruction);
+
+ std::vector<uint32_t> ids;
+ inst.ForEachInId([&ids](const uint32_t* idptr) { ids.push_back(*idptr); });
+ EXPECT_THAT(ids, Eq(std::vector<uint32_t>{100, 101, 102}));
+
+ ids.clear();
+ inst.ForEachInId([&ids](uint32_t* idptr) { ids.push_back(*idptr); });
+ EXPECT_THAT(ids, Eq(std::vector<uint32_t>{100, 101, 102}));
+}
+
} // anonymous namespace
diff --git a/test/opt/local_single_block_elim.cpp b/test/opt/local_single_block_elim.cpp
index 8c193bda..68c7d5c3 100644
--- a/test/opt/local_single_block_elim.cpp
+++ b/test/opt/local_single_block_elim.cpp
@@ -16,11 +16,12 @@
#include "pass_fixture.h"
#include "pass_utils.h"
-template <typename T> std::vector<T> concat(const std::vector<T> &a, const std::vector<T> &b) {
- std::vector<T> ret = std::vector<T>();
- std::copy(a.begin(), a.end(), back_inserter(ret));
- std::copy(b.begin(), b.end(), back_inserter(ret));
- return ret;
+template <typename T>
+std::vector<T> concat(const std::vector<T>& a, const std::vector<T>& b) {
+ std::vector<T> ret;
+ std::copy(a.begin(), a.end(), back_inserter(ret));
+ std::copy(b.begin(), b.end(), back_inserter(ret));
+ return ret;
}
namespace {
@@ -31,16 +32,16 @@ using LocalSingleBlockLoadStoreElimTest = PassTest<::testing::Test>;
TEST_F(LocalSingleBlockLoadStoreElimTest, SimpleStoreLoadElim) {
// #version 140
- //
+ //
// in vec4 BaseColor;
- //
+ //
// void main()
// {
// vec4 v = BaseColor;
// gl_FragColor = v;
// }
- const std::string predefs =
+ const std::string predefs_before =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
@@ -62,6 +63,27 @@ OpName %gl_FragColor "gl_FragColor"
%gl_FragColor = OpVariable %_ptr_Output_v4float Output
)";
+ const std::string predefs_after =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
const std::string before =
R"(%main = OpFunction %void None %7
%13 = OpLabel
@@ -77,7 +99,6 @@ OpFunctionEnd
const std::string after =
R"(%main = OpFunction %void None %7
%13 = OpLabel
-%v = OpVariable %_ptr_Function_v4float Function
%14 = OpLoad %v4float %BaseColor
OpStore %gl_FragColor %14
OpReturn
@@ -85,15 +106,15 @@ OpFunctionEnd
)";
SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
- predefs + before, predefs + after, true, true);
+ predefs_before + before, predefs_after + after, true, true);
}
TEST_F(LocalSingleBlockLoadStoreElimTest, SimpleLoadLoadElim) {
// #version 140
- //
+ //
// in vec4 BaseColor;
// in float fi;
- //
+ //
// void main()
// {
// vec4 v = BaseColor;
@@ -190,16 +211,16 @@ OpFunctionEnd
}
TEST_F(LocalSingleBlockLoadStoreElimTest,
- NoStoreElimIfInterveningAccessChainLoad) {
+ NoStoreElimIfInterveningAccessChainLoad) {
//
// Note that even though the Load to %v is eliminated, the Store to %v
// is not eliminated due to the following access chain reference.
//
// #version 140
- //
+ //
// in vec4 BaseColor;
// flat in int Idx;
- //
+ //
// void main()
// {
// vec4 v = BaseColor;
@@ -207,7 +228,7 @@ TEST_F(LocalSingleBlockLoadStoreElimTest,
// gl_FragColor = v/f;
// }
- const std::string predefs =
+ const std::string predefs_before =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
@@ -236,6 +257,34 @@ OpDecorate %Idx Flat
%gl_FragColor = OpVariable %_ptr_Output_v4float Output
)";
+ const std::string predefs_after =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %Idx %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %Idx "Idx"
+OpName %gl_FragColor "gl_FragColor"
+OpDecorate %Idx Flat
+%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_float = OpTypePointer Function %float
+%int = OpTypeInt 32 1
+%_ptr_Input_int = OpTypePointer Input %int
+%Idx = OpVariable %_ptr_Input_int Input
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
const std::string before =
R"(%main = OpFunction %void None %9
%18 = OpLabel
@@ -260,7 +309,6 @@ OpFunctionEnd
R"(%main = OpFunction %void None %9
%18 = OpLabel
%v = OpVariable %_ptr_Function_v4float Function
-%f = OpVariable %_ptr_Function_float Function
%19 = OpLoad %v4float %BaseColor
OpStore %v %19
%20 = OpLoad %int %Idx
@@ -274,15 +322,15 @@ OpFunctionEnd
)";
SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
- predefs + before, predefs + after, true, true);
+ predefs_before + before, predefs_after + after, true, true);
}
TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfInterveningAccessChainStore) {
// #version 140
- //
+ //
// in vec4 BaseColor;
// flat in int Idx;
- //
+ //
// void main()
// {
// vec4 v = BaseColor;
@@ -332,14 +380,14 @@ OpFunctionEnd
)";
SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
- assembly, assembly, false, true);
+ assembly, assembly, false, true);
}
TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfInterveningFunctionCall) {
// #version 140
- //
+ //
// in vec4 BaseColor;
- //
+ //
// void foo() {
// }
//
@@ -388,16 +436,16 @@ OpFunctionEnd
)";
SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
- assembly, assembly, false, true);
+ assembly, assembly, false, true);
}
-TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfCopyObjectInFunction) {
+TEST_F(LocalSingleBlockLoadStoreElimTest, ElimIfCopyObjectInFunction) {
// Note: SPIR-V hand edited to insert CopyObject
//
// #version 140
- //
+ //
// in vec4 BaseColor;
- //
+ //
// void main()
// {
// vec4 v1 = BaseColor;
@@ -406,7 +454,7 @@ TEST_F(LocalSingleBlockLoadStoreElimTest, NoElimIfCopyObjectInFunction) {
// gl_FragData[1] = v2;
// }
- const std::string assembly =
+ const std::string predefs_before =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
@@ -435,7 +483,39 @@ OpName %v2 "v2"
%_ptr_Output_v4float = OpTypePointer Output %v4float
%float_0_5 = OpConstant %float 0.5
%int_1 = OpConstant %int 1
-%main = OpFunction %void None %8
+)";
+
+ const std::string predefs_after =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragData
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragData "gl_FragData"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%uint = OpTypeInt 32 0
+%uint_32 = OpConstant %uint 32
+%_arr_v4float_uint_32 = OpTypeArray %v4float %uint_32
+%_ptr_Output__arr_v4float_uint_32 = OpTypePointer Output %_arr_v4float_uint_32
+%gl_FragData = OpVariable %_ptr_Output__arr_v4float_uint_32 Output
+%int = OpTypeInt 32 1
+%int_0 = OpConstant %int 0
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%float_0_5 = OpConstant %float 0.5
+%int_1 = OpConstant %int 1
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
%22 = OpLabel
%v1 = OpVariable %_ptr_Function_v4float Function
%v2 = OpVariable %_ptr_Function_v4float Function
@@ -455,8 +535,22 @@ OpReturn
OpFunctionEnd
)";
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%22 = OpLabel
+%23 = OpLoad %v4float %BaseColor
+%25 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_0
+OpStore %25 %23
+%26 = OpLoad %v4float %BaseColor
+%27 = OpVectorTimesScalar %v4float %26 %float_0_5
+%30 = OpAccessChain %_ptr_Output_v4float %gl_FragData %int_1
+OpStore %30 %27
+OpReturn
+OpFunctionEnd
+)";
+
SinglePassRunAndCheck<opt::LocalSingleBlockLoadStoreElimPass>(
- assembly, assembly, false, true);
+ predefs_before + before, predefs_after + after, true, true);
}
// TODO(greg-lunarg): Add tests to verify handling of these cases:
diff --git a/test/opt/local_single_store_elim_test.cpp b/test/opt/local_single_store_elim_test.cpp
new file mode 100644
index 00000000..cc74dab2
--- /dev/null
+++ b/test/opt/local_single_store_elim_test.cpp
@@ -0,0 +1,655 @@
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+using LocalSingleStoreElimTest = PassTest<::testing::Test>;
+
+TEST_F(LocalSingleStoreElimTest, PositiveAndNegative) {
+ // Single store to v is optimized. Multiple stores to
+ // f are not optimized.
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in float fi;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // float f = fi;
+ // if (f < 0)
+ // f = 0.0;
+ // gl_FragColor = v + f;
+ // }
+
+ const std::string predefs_before =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %f "f"
+OpName %fi "fi"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Input_float = OpTypePointer Input %float
+%fi = OpVariable %_ptr_Input_float Input
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string predefs_after =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %BaseColor "BaseColor"
+OpName %f "f"
+OpName %fi "fi"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Input_float = OpTypePointer Input %float
+%fi = OpVariable %_ptr_Input_float Input
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %9
+%19 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%f = OpVariable %_ptr_Function_float Function
+%20 = OpLoad %v4float %BaseColor
+OpStore %v %20
+%21 = OpLoad %float %fi
+OpStore %f %21
+%22 = OpLoad %float %f
+%23 = OpFOrdLessThan %bool %22 %float_0
+OpSelectionMerge %24 None
+OpBranchConditional %23 %25 %24
+%25 = OpLabel
+OpStore %f %float_0
+OpBranch %24
+%24 = OpLabel
+%26 = OpLoad %v4float %v
+%27 = OpLoad %float %f
+%28 = OpCompositeConstruct %v4float %27 %27 %27 %27
+%29 = OpFAdd %v4float %26 %28
+OpStore %gl_FragColor %29
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %9
+%19 = OpLabel
+%f = OpVariable %_ptr_Function_float Function
+%20 = OpLoad %v4float %BaseColor
+%21 = OpLoad %float %fi
+OpStore %f %21
+%22 = OpLoad %float %f
+%23 = OpFOrdLessThan %bool %22 %float_0
+OpSelectionMerge %24 None
+OpBranchConditional %23 %25 %24
+%25 = OpLabel
+OpStore %f %float_0
+OpBranch %24
+%24 = OpLabel
+%27 = OpLoad %float %f
+%28 = OpCompositeConstruct %v4float %27 %27 %27 %27
+%29 = OpFAdd %v4float %20 %28
+OpStore %gl_FragColor %29
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleStoreElimPass>(
+ predefs_before + before,
+ predefs_after + after, true, true);
+}
+
+TEST_F(LocalSingleStoreElimTest, MultipleLoads) {
+ // Single store to multiple loads of v is optimized.
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in float fi;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // vec4 r;
+ // if (fi < 0.0) r = v;
+ // else r = vec4(1.0) - v;
+ // gl_FragColor = r;
+ // }
+
+ const std::string predefs_before =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %fi "fi"
+OpName %r "r"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Input_float = OpTypePointer Input %float
+%fi = OpVariable %_ptr_Input_float Input
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%float_1 = OpConstant %float 1
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string predefs_after =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %BaseColor "BaseColor"
+OpName %fi "fi"
+OpName %r "r"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Input_float = OpTypePointer Input %float
+%fi = OpVariable %_ptr_Input_float Input
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%float_1 = OpConstant %float 1
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %9
+%19 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%r = OpVariable %_ptr_Function_v4float Function
+%20 = OpLoad %v4float %BaseColor
+OpStore %v %20
+%21 = OpLoad %float %fi
+%22 = OpFOrdLessThan %bool %21 %float_0
+OpSelectionMerge %23 None
+OpBranchConditional %22 %24 %25
+%24 = OpLabel
+%26 = OpLoad %v4float %v
+OpStore %r %26
+OpBranch %23
+%25 = OpLabel
+%27 = OpLoad %v4float %v
+%28 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1
+%29 = OpFSub %v4float %28 %27
+OpStore %r %29
+OpBranch %23
+%23 = OpLabel
+%30 = OpLoad %v4float %r
+OpStore %gl_FragColor %30
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %9
+%19 = OpLabel
+%r = OpVariable %_ptr_Function_v4float Function
+%20 = OpLoad %v4float %BaseColor
+%21 = OpLoad %float %fi
+%22 = OpFOrdLessThan %bool %21 %float_0
+OpSelectionMerge %23 None
+OpBranchConditional %22 %24 %25
+%24 = OpLabel
+OpStore %r %20
+OpBranch %23
+%25 = OpLabel
+%28 = OpCompositeConstruct %v4float %float_1 %float_1 %float_1 %float_1
+%29 = OpFSub %v4float %28 %20
+OpStore %r %29
+OpBranch %23
+%23 = OpLabel
+%30 = OpLoad %v4float %r
+OpStore %gl_FragColor %30
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleStoreElimPass>(
+ predefs_before + before,
+ predefs_after + after, true, true);
+}
+
+TEST_F(LocalSingleStoreElimTest, NoStoreElimWithInterveningAccessChainLoad) {
+ // Last load of v is eliminated, but access chain load and store of v isn't
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // float f = v[3];
+ // gl_FragColor = v * f;
+ // }
+
+ const std::string predefs_before =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %f "f"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_float = OpTypePointer Function %float
+%uint = OpTypeInt 32 0
+%uint_3 = OpConstant %uint 3
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string predefs_after =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_float = OpTypePointer Function %float
+%uint = OpTypeInt 32 0
+%uint_3 = OpConstant %uint 3
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%17 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%f = OpVariable %_ptr_Function_float Function
+%18 = OpLoad %v4float %BaseColor
+OpStore %v %18
+%19 = OpAccessChain %_ptr_Function_float %v %uint_3
+%20 = OpLoad %float %19
+OpStore %f %20
+%21 = OpLoad %v4float %v
+%22 = OpLoad %float %f
+%23 = OpVectorTimesScalar %v4float %21 %22
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%17 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%18 = OpLoad %v4float %BaseColor
+OpStore %v %18
+%19 = OpAccessChain %_ptr_Function_float %v %uint_3
+%20 = OpLoad %float %19
+%23 = OpVectorTimesScalar %v4float %18 %20
+OpStore %gl_FragColor %23
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleStoreElimPass>(
+ predefs_before + before,
+ predefs_after + after, true, true);
+}
+
+TEST_F(LocalSingleStoreElimTest, NoReplaceOfDominatingPartialStore) {
+ // Note: SPIR-V hand edited to initialize v to vec4(0.0)
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ //
+ // void main()
+ // {
+ // vec4 v;
+ // v[1] = 1.0;
+ // gl_FragColor = v;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %gl_FragColor %BaseColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %gl_FragColor "gl_FragColor"
+OpName %BaseColor "BaseColor"
+%void = OpTypeVoid
+%7 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%float_0 = OpConstant %float 0
+%12 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+%float_1 = OpConstant %float 1
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%main = OpFunction %void None %7
+%19 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function %12
+%20 = OpAccessChain %_ptr_Function_float %v %uint_1
+OpStore %20 %float_1
+%21 = OpLoad %v4float %v
+OpStore %gl_FragColor %21
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleStoreElimPass>(assembly, assembly, true,
+ true);
+}
+
+TEST_F(LocalSingleStoreElimTest, ElimIfCopyObjectInFunction) {
+ // Note: hand edited to insert OpCopyObject
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in float fi;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // float f = fi;
+ // if (f < 0)
+ // f = 0.0;
+ // gl_FragColor = v + f;
+ // }
+
+ const std::string predefs_before =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %f "f"
+OpName %fi "fi"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Input_float = OpTypePointer Input %float
+%fi = OpVariable %_ptr_Input_float Input
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string predefs_after =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %fi %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %BaseColor "BaseColor"
+OpName %f "f"
+OpName %fi "fi"
+OpName %gl_FragColor "gl_FragColor"
+%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Input_float = OpTypePointer Input %float
+%fi = OpVariable %_ptr_Input_float Input
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %9
+%19 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%f = OpVariable %_ptr_Function_float Function
+%20 = OpLoad %v4float %BaseColor
+OpStore %v %20
+%21 = OpLoad %float %fi
+OpStore %f %21
+%22 = OpLoad %float %f
+%23 = OpFOrdLessThan %bool %22 %float_0
+OpSelectionMerge %24 None
+OpBranchConditional %23 %25 %24
+%25 = OpLabel
+OpStore %f %float_0
+OpBranch %24
+%24 = OpLabel
+%26 = OpCopyObject %_ptr_Function_v4float %v
+%27 = OpLoad %v4float %26
+%28 = OpLoad %float %f
+%29 = OpCompositeConstruct %v4float %28 %28 %28 %28
+%30 = OpFAdd %v4float %27 %29
+OpStore %gl_FragColor %30
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %9
+%19 = OpLabel
+%f = OpVariable %_ptr_Function_float Function
+%20 = OpLoad %v4float %BaseColor
+%21 = OpLoad %float %fi
+OpStore %f %21
+%22 = OpLoad %float %f
+%23 = OpFOrdLessThan %bool %22 %float_0
+OpSelectionMerge %24 None
+OpBranchConditional %23 %25 %24
+%25 = OpLabel
+OpStore %f %float_0
+OpBranch %24
+%24 = OpLabel
+%28 = OpLoad %float %f
+%29 = OpCompositeConstruct %v4float %28 %28 %28 %28
+%30 = OpFAdd %v4float %20 %29
+OpStore %gl_FragColor %30
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleStoreElimPass>(
+ predefs_before + before, predefs_after + after, true, true);
+}
+
+TEST_F(LocalSingleStoreElimTest, NoOptIfStoreNotDominating) {
+ // Single store to f not optimized because it does not dominate
+ // the load.
+ //
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in float fi;
+ //
+ // void main()
+ // {
+ // float f;
+ // if (fi < 0)
+ // f = 0.5;
+ // if (fi < 0)
+ // gl_FragColor = BaseColor * f;
+ // else
+ // gl_FragColor = BaseColor;
+ // }
+
+ const std::string assembly =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %fi %gl_FragColor %BaseColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+OpName %main "main"
+OpName %fi "fi"
+OpName %f "f"
+OpName %gl_FragColor "gl_FragColor"
+OpName %BaseColor "BaseColor"
+%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_ptr_Input_float = OpTypePointer Input %float
+%fi = OpVariable %_ptr_Input_float Input
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%_ptr_Function_float = OpTypePointer Function %float
+%float_0_5 = OpConstant %float 0.5
+%v4float = OpTypeVector %float 4
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%main = OpFunction %void None %8
+%18 = OpLabel
+%f = OpVariable %_ptr_Function_float Function
+%19 = OpLoad %float %fi
+%20 = OpFOrdLessThan %bool %19 %float_0
+OpSelectionMerge %21 None
+OpBranchConditional %20 %22 %21
+%22 = OpLabel
+OpStore %f %float_0_5
+OpBranch %21
+%21 = OpLabel
+%23 = OpLoad %float %fi
+%24 = OpFOrdLessThan %bool %23 %float_0
+OpSelectionMerge %25 None
+OpBranchConditional %24 %26 %27
+%26 = OpLabel
+%28 = OpLoad %v4float %BaseColor
+%29 = OpLoad %float %f
+%30 = OpVectorTimesScalar %v4float %28 %29
+OpStore %gl_FragColor %30
+OpBranch %25
+%27 = OpLabel
+%31 = OpLoad %v4float %BaseColor
+OpStore %gl_FragColor %31
+OpBranch %25
+%25 = OpLabel
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalSingleStoreElimPass>(assembly, assembly, true,
+ true);
+}
+
+// TODO(greg-lunarg): Add tests to verify handling of these cases:
+//
+// Other types
+// Others?
+
+} // anonymous namespace
diff --git a/test/opt/local_ssa_elim_test.cpp b/test/opt/local_ssa_elim_test.cpp
new file mode 100644
index 00000000..bcee7ca3
--- /dev/null
+++ b/test/opt/local_ssa_elim_test.cpp
@@ -0,0 +1,1239 @@
+// Copyright (c) 2017 Valve Corporation
+// Copyright (c) 2017 LunarG Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "pass_fixture.h"
+#include "pass_utils.h"
+
+namespace {
+
+using namespace spvtools;
+
+using LocalSSAElimTest = PassTest<::testing::Test>;
+
+TEST_F(LocalSSAElimTest, ForLoop) {
+ // #version 140
+ //
+ // in vec4 BC;
+ // out float fo;
+ //
+ // void main()
+ // {
+ // float f = 0.0;
+ // for (int i=0; i<4; i++) {
+ // f = f + BC[i];
+ // }
+ // fo = f;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BC %fo
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %f "f"
+OpName %i "i"
+OpName %BC "BC"
+OpName %fo "fo"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %BC "BC"
+OpName %fo "fo"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%float_0 = OpConstant %float 0
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_0 = OpConstant %int 0
+%int_4 = OpConstant %int 4
+%bool = OpTypeBool
+%v4float = OpTypeVector %float 4
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BC = OpVariable %_ptr_Input_v4float Input
+%_ptr_Input_float = OpTypePointer Input %float
+%int_1 = OpConstant %int 1
+%_ptr_Output_float = OpTypePointer Output %float
+%fo = OpVariable %_ptr_Output_float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%22 = OpLabel
+%f = OpVariable %_ptr_Function_float Function
+%i = OpVariable %_ptr_Function_int Function
+OpStore %f %float_0
+OpStore %i %int_0
+OpBranch %23
+%23 = OpLabel
+OpLoopMerge %24 %25 None
+OpBranch %26
+%26 = OpLabel
+%27 = OpLoad %int %i
+%28 = OpSLessThan %bool %27 %int_4
+OpBranchConditional %28 %29 %24
+%29 = OpLabel
+%30 = OpLoad %float %f
+%31 = OpLoad %int %i
+%32 = OpAccessChain %_ptr_Input_float %BC %31
+%33 = OpLoad %float %32
+%34 = OpFAdd %float %30 %33
+OpStore %f %34
+OpBranch %25
+%25 = OpLabel
+%35 = OpLoad %int %i
+%36 = OpIAdd %int %35 %int_1
+OpStore %i %36
+OpBranch %23
+%24 = OpLabel
+%37 = OpLoad %float %f
+OpStore %fo %37
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%22 = OpLabel
+OpBranch %23
+%23 = OpLabel
+%38 = OpPhi %float %float_0 %22 %34 %25
+%39 = OpPhi %int %int_0 %22 %36 %25
+OpLoopMerge %24 %25 None
+OpBranch %26
+%26 = OpLabel
+%28 = OpSLessThan %bool %39 %int_4
+OpBranchConditional %28 %29 %24
+%29 = OpLabel
+%32 = OpAccessChain %_ptr_Input_float %BC %39
+%33 = OpLoad %float %32
+%34 = OpFAdd %float %38 %33
+OpBranch %25
+%25 = OpLabel
+%36 = OpIAdd %int %39 %int_1
+OpBranch %23
+%24 = OpLabel
+OpStore %fo %38
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalMultiStoreElimPass>(
+ predefs + names_before + predefs2 + before,
+ predefs + names_after + predefs2 + after,
+ true, true);
+}
+
+TEST_F(LocalSSAElimTest, ForLoopWithContinue) {
+ // #version 140
+ //
+ // in vec4 BC;
+ // out float fo;
+ //
+ // void main()
+ // {
+ // float f = 0.0;
+ // for (int i=0; i<4; i++) {
+ // float t = BC[i];
+ // if (t < 0.0)
+ // continue;
+ // f = f + t;
+ // }
+ // fo = f;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BC %fo
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %f "f"
+OpName %i "i"
+OpName %t "t"
+OpName %BC "BC"
+OpName %fo "fo"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %BC "BC"
+OpName %fo "fo"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%float_0 = OpConstant %float 0
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_0 = OpConstant %int 0
+%int_4 = OpConstant %int 4
+%bool = OpTypeBool
+%v4float = OpTypeVector %float 4
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BC = OpVariable %_ptr_Input_v4float Input
+%_ptr_Input_float = OpTypePointer Input %float
+%int_1 = OpConstant %int 1
+%_ptr_Output_float = OpTypePointer Output %float
+%fo = OpVariable %_ptr_Output_float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %9
+%23 = OpLabel
+%f = OpVariable %_ptr_Function_float Function
+%i = OpVariable %_ptr_Function_int Function
+%t = OpVariable %_ptr_Function_float Function
+OpStore %f %float_0
+OpStore %i %int_0
+OpBranch %24
+%24 = OpLabel
+OpLoopMerge %25 %26 None
+OpBranch %27
+%27 = OpLabel
+%28 = OpLoad %int %i
+%29 = OpSLessThan %bool %28 %int_4
+OpBranchConditional %29 %30 %25
+%30 = OpLabel
+%31 = OpLoad %int %i
+%32 = OpAccessChain %_ptr_Input_float %BC %31
+%33 = OpLoad %float %32
+OpStore %t %33
+%34 = OpLoad %float %t
+%35 = OpFOrdLessThan %bool %34 %float_0
+OpSelectionMerge %36 None
+OpBranchConditional %35 %37 %36
+%37 = OpLabel
+OpBranch %26
+%36 = OpLabel
+%38 = OpLoad %float %f
+%39 = OpLoad %float %t
+%40 = OpFAdd %float %38 %39
+OpStore %f %40
+OpBranch %26
+%26 = OpLabel
+%41 = OpLoad %int %i
+%42 = OpIAdd %int %41 %int_1
+OpStore %i %42
+OpBranch %24
+%25 = OpLabel
+%43 = OpLoad %float %f
+OpStore %fo %43
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%46 = OpUndef %float
+%main = OpFunction %void None %9
+%23 = OpLabel
+OpBranch %24
+%24 = OpLabel
+%44 = OpPhi %float %float_0 %23 %48 %26
+%45 = OpPhi %int %int_0 %23 %42 %26
+%47 = OpPhi %float %46 %23 %33 %26
+OpLoopMerge %25 %26 None
+OpBranch %27
+%27 = OpLabel
+%29 = OpSLessThan %bool %45 %int_4
+OpBranchConditional %29 %30 %25
+%30 = OpLabel
+%32 = OpAccessChain %_ptr_Input_float %BC %45
+%33 = OpLoad %float %32
+%35 = OpFOrdLessThan %bool %33 %float_0
+OpSelectionMerge %36 None
+OpBranchConditional %35 %37 %36
+%37 = OpLabel
+OpBranch %26
+%36 = OpLabel
+%40 = OpFAdd %float %44 %33
+OpBranch %26
+%26 = OpLabel
+%48 = OpPhi %float %44 %37 %40 %36
+%42 = OpIAdd %int %45 %int_1
+OpBranch %24
+%25 = OpLabel
+OpStore %fo %44
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalMultiStoreElimPass>(
+ predefs + names_before + predefs2 + before,
+ predefs + names_after + predefs2 + after,
+ true, true);
+}
+
+TEST_F(LocalSSAElimTest, ForLoopWithBreak) {
+ // #version 140
+ //
+ // in vec4 BC;
+ // out float fo;
+ //
+ // void main()
+ // {
+ // float f = 0.0;
+ // for (int i=0; i<4; i++) {
+ // float t = f + BC[i];
+ // if (t > 1.0)
+ // break;
+ // f = t;
+ // }
+ // fo = f;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BC %fo
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %f "f"
+OpName %i "i"
+OpName %t "t"
+OpName %BC "BC"
+OpName %fo "fo"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %BC "BC"
+OpName %fo "fo"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%float_0 = OpConstant %float 0
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_0 = OpConstant %int 0
+%int_4 = OpConstant %int 4
+%bool = OpTypeBool
+%v4float = OpTypeVector %float 4
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BC = OpVariable %_ptr_Input_v4float Input
+%_ptr_Input_float = OpTypePointer Input %float
+%float_1 = OpConstant %float 1
+%int_1 = OpConstant %int 1
+%_ptr_Output_float = OpTypePointer Output %float
+%fo = OpVariable %_ptr_Output_float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %9
+%24 = OpLabel
+%f = OpVariable %_ptr_Function_float Function
+%i = OpVariable %_ptr_Function_int Function
+%t = OpVariable %_ptr_Function_float Function
+OpStore %f %float_0
+OpStore %i %int_0
+OpBranch %25
+%25 = OpLabel
+OpLoopMerge %26 %27 None
+OpBranch %28
+%28 = OpLabel
+%29 = OpLoad %int %i
+%30 = OpSLessThan %bool %29 %int_4
+OpBranchConditional %30 %31 %26
+%31 = OpLabel
+%32 = OpLoad %float %f
+%33 = OpLoad %int %i
+%34 = OpAccessChain %_ptr_Input_float %BC %33
+%35 = OpLoad %float %34
+%36 = OpFAdd %float %32 %35
+OpStore %t %36
+%37 = OpLoad %float %t
+%38 = OpFOrdGreaterThan %bool %37 %float_1
+OpSelectionMerge %39 None
+OpBranchConditional %38 %40 %39
+%40 = OpLabel
+OpBranch %26
+%39 = OpLabel
+%41 = OpLoad %float %t
+OpStore %f %41
+OpBranch %27
+%27 = OpLabel
+%42 = OpLoad %int %i
+%43 = OpIAdd %int %42 %int_1
+OpStore %i %43
+OpBranch %25
+%26 = OpLabel
+%44 = OpLoad %float %f
+OpStore %fo %44
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%47 = OpUndef %float
+%main = OpFunction %void None %9
+%24 = OpLabel
+OpBranch %25
+%25 = OpLabel
+%45 = OpPhi %float %float_0 %24 %36 %27
+%46 = OpPhi %int %int_0 %24 %43 %27
+%48 = OpPhi %float %47 %24 %36 %27
+OpLoopMerge %26 %27 None
+OpBranch %28
+%28 = OpLabel
+%30 = OpSLessThan %bool %46 %int_4
+OpBranchConditional %30 %31 %26
+%31 = OpLabel
+%34 = OpAccessChain %_ptr_Input_float %BC %46
+%35 = OpLoad %float %34
+%36 = OpFAdd %float %45 %35
+%38 = OpFOrdGreaterThan %bool %36 %float_1
+OpSelectionMerge %39 None
+OpBranchConditional %38 %40 %39
+%40 = OpLabel
+OpBranch %26
+%39 = OpLabel
+OpBranch %27
+%27 = OpLabel
+%43 = OpIAdd %int %46 %int_1
+OpBranch %25
+%26 = OpLabel
+%49 = OpPhi %float %48 %28 %36 %40
+OpStore %fo %45
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalMultiStoreElimPass>(
+ predefs + names_before + predefs2 + before,
+ predefs + names_after + predefs2 + after,
+ true, true);
+}
+
+TEST_F(LocalSSAElimTest, SwapProblem) {
+ // #version 140
+ //
+ // in float fe;
+ // out float fo;
+ //
+ // void main()
+ // {
+ // float f1 = 0.0;
+ // float f2 = 1.0;
+ // int ie = int(fe);
+ // for (int i=0; i<ie; i++) {
+ // float t = f1;
+ // f1 = f2;
+ // f2 = t;
+ // }
+ // fo = f1;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %fe %fo
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %f1 "f1"
+OpName %f2 "f2"
+OpName %ie "ie"
+OpName %fe "fe"
+OpName %i "i"
+OpName %t "t"
+OpName %fo "fo"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %fe "fe"
+OpName %fo "fo"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%11 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%float_0 = OpConstant %float 0
+%float_1 = OpConstant %float 1
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%_ptr_Input_float = OpTypePointer Input %float
+%fe = OpVariable %_ptr_Input_float Input
+%int_0 = OpConstant %int 0
+%bool = OpTypeBool
+%int_1 = OpConstant %int 1
+%_ptr_Output_float = OpTypePointer Output %float
+%fo = OpVariable %_ptr_Output_float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %11
+%23 = OpLabel
+%f1 = OpVariable %_ptr_Function_float Function
+%f2 = OpVariable %_ptr_Function_float Function
+%ie = OpVariable %_ptr_Function_int Function
+%i = OpVariable %_ptr_Function_int Function
+%t = OpVariable %_ptr_Function_float Function
+OpStore %f1 %float_0
+OpStore %f2 %float_1
+%24 = OpLoad %float %fe
+%25 = OpConvertFToS %int %24
+OpStore %ie %25
+OpStore %i %int_0
+OpBranch %26
+%26 = OpLabel
+OpLoopMerge %27 %28 None
+OpBranch %29
+%29 = OpLabel
+%30 = OpLoad %int %i
+%31 = OpLoad %int %ie
+%32 = OpSLessThan %bool %30 %31
+OpBranchConditional %32 %33 %27
+%33 = OpLabel
+%34 = OpLoad %float %f1
+OpStore %t %34
+%35 = OpLoad %float %f2
+OpStore %f1 %35
+%36 = OpLoad %float %t
+OpStore %f2 %36
+OpBranch %28
+%28 = OpLabel
+%37 = OpLoad %int %i
+%38 = OpIAdd %int %37 %int_1
+OpStore %i %38
+OpBranch %26
+%27 = OpLabel
+%39 = OpLoad %float %f1
+OpStore %fo %39
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%43 = OpUndef %float
+%main = OpFunction %void None %11
+%23 = OpLabel
+%24 = OpLoad %float %fe
+%25 = OpConvertFToS %int %24
+OpBranch %26
+%26 = OpLabel
+%40 = OpPhi %float %float_0 %23 %41 %28
+%41 = OpPhi %float %float_1 %23 %40 %28
+%42 = OpPhi %int %int_0 %23 %38 %28
+%44 = OpPhi %float %43 %23 %40 %28
+OpLoopMerge %27 %28 None
+OpBranch %29
+%29 = OpLabel
+%32 = OpSLessThan %bool %42 %25
+OpBranchConditional %32 %33 %27
+%33 = OpLabel
+OpBranch %28
+%28 = OpLabel
+%38 = OpIAdd %int %42 %int_1
+OpBranch %26
+%27 = OpLabel
+OpStore %fo %40
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalMultiStoreElimPass>(
+ predefs + names_before + predefs2 + before,
+ predefs + names_after + predefs2 + after,
+ true, true);
+}
+
+TEST_F(LocalSSAElimTest, LostCopyProblem) {
+ // #version 140
+ //
+ // in vec4 BC;
+ // out float fo;
+ //
+ // void main()
+ // {
+ // float f = 0.0;
+ // float t;
+ // for (int i=0; i<4; i++) {
+ // t = f;
+ // f = f + BC[i];
+ // if (f > 1.0)
+ // break;
+ // }
+ // fo = t;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BC %fo
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %f "f"
+OpName %i "i"
+OpName %t "t"
+OpName %BC "BC"
+OpName %fo "fo"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %BC "BC"
+OpName %fo "fo"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%float_0 = OpConstant %float 0
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%int_0 = OpConstant %int 0
+%int_4 = OpConstant %int 4
+%bool = OpTypeBool
+%v4float = OpTypeVector %float 4
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BC = OpVariable %_ptr_Input_v4float Input
+%_ptr_Input_float = OpTypePointer Input %float
+%float_1 = OpConstant %float 1
+%int_1 = OpConstant %int 1
+%_ptr_Output_float = OpTypePointer Output %float
+%fo = OpVariable %_ptr_Output_float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %9
+%24 = OpLabel
+%f = OpVariable %_ptr_Function_float Function
+%i = OpVariable %_ptr_Function_int Function
+%t = OpVariable %_ptr_Function_float Function
+OpStore %f %float_0
+OpStore %i %int_0
+OpBranch %25
+%25 = OpLabel
+OpLoopMerge %26 %27 None
+OpBranch %28
+%28 = OpLabel
+%29 = OpLoad %int %i
+%30 = OpSLessThan %bool %29 %int_4
+OpBranchConditional %30 %31 %26
+%31 = OpLabel
+%32 = OpLoad %float %f
+OpStore %t %32
+%33 = OpLoad %float %f
+%34 = OpLoad %int %i
+%35 = OpAccessChain %_ptr_Input_float %BC %34
+%36 = OpLoad %float %35
+%37 = OpFAdd %float %33 %36
+OpStore %f %37
+%38 = OpLoad %float %f
+%39 = OpFOrdGreaterThan %bool %38 %float_1
+OpSelectionMerge %40 None
+OpBranchConditional %39 %41 %40
+%41 = OpLabel
+OpBranch %26
+%40 = OpLabel
+OpBranch %27
+%27 = OpLabel
+%42 = OpLoad %int %i
+%43 = OpIAdd %int %42 %int_1
+OpStore %i %43
+OpBranch %25
+%26 = OpLabel
+%44 = OpLoad %float %t
+OpStore %fo %44
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%47 = OpUndef %float
+%main = OpFunction %void None %9
+%24 = OpLabel
+OpBranch %25
+%25 = OpLabel
+%45 = OpPhi %float %float_0 %24 %37 %27
+%46 = OpPhi %int %int_0 %24 %43 %27
+%48 = OpPhi %float %47 %24 %45 %27
+OpLoopMerge %26 %27 None
+OpBranch %28
+%28 = OpLabel
+%30 = OpSLessThan %bool %46 %int_4
+OpBranchConditional %30 %31 %26
+%31 = OpLabel
+%35 = OpAccessChain %_ptr_Input_float %BC %46
+%36 = OpLoad %float %35
+%37 = OpFAdd %float %45 %36
+%39 = OpFOrdGreaterThan %bool %37 %float_1
+OpSelectionMerge %40 None
+OpBranchConditional %39 %41 %40
+%41 = OpLabel
+OpBranch %26
+%40 = OpLabel
+OpBranch %27
+%27 = OpLabel
+%43 = OpIAdd %int %46 %int_1
+OpBranch %25
+%26 = OpLabel
+%49 = OpPhi %float %45 %28 %37 %41
+%50 = OpPhi %float %48 %28 %45 %41
+OpStore %fo %50
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalMultiStoreElimPass>(
+ predefs + names_before + predefs2 + before,
+ predefs + names_after + predefs2 + after,
+ true, true);
+}
+
+TEST_F(LocalSSAElimTest, IfThenElse) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in float f;
+ //
+ // void main()
+ // {
+ // vec4 v;
+ // if (f >= 0)
+ // v = BaseColor * 0.5;
+ // else
+ // v = BaseColor + vec4(1.0,1.0,1.0,1.0);
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %f %BaseColor %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %f "f"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %f "f"
+OpName %BaseColor "BaseColor"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%_ptr_Input_float = OpTypePointer Input %float
+%f = OpVariable %_ptr_Input_float Input
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%float_0_5 = OpConstant %float 0.5
+%float_1 = OpConstant %float 1
+%18 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%20 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%21 = OpLoad %float %f
+%22 = OpFOrdGreaterThanEqual %bool %21 %float_0
+OpSelectionMerge %23 None
+OpBranchConditional %22 %24 %25
+%24 = OpLabel
+%26 = OpLoad %v4float %BaseColor
+%27 = OpVectorTimesScalar %v4float %26 %float_0_5
+OpStore %v %27
+OpBranch %23
+%25 = OpLabel
+%28 = OpLoad %v4float %BaseColor
+%29 = OpFAdd %v4float %28 %18
+OpStore %v %29
+OpBranch %23
+%23 = OpLabel
+%30 = OpLoad %v4float %v
+OpStore %gl_FragColor %30
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%20 = OpLabel
+%21 = OpLoad %float %f
+%22 = OpFOrdGreaterThanEqual %bool %21 %float_0
+OpSelectionMerge %23 None
+OpBranchConditional %22 %24 %25
+%24 = OpLabel
+%26 = OpLoad %v4float %BaseColor
+%27 = OpVectorTimesScalar %v4float %26 %float_0_5
+OpBranch %23
+%25 = OpLabel
+%28 = OpLoad %v4float %BaseColor
+%29 = OpFAdd %v4float %28 %18
+OpBranch %23
+%23 = OpLabel
+%31 = OpPhi %v4float %27 %24 %29 %25
+OpStore %gl_FragColor %31
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalMultiStoreElimPass>(
+ predefs + names_before + predefs2 + before,
+ predefs + names_after + predefs2 + after,
+ true, true);
+}
+
+TEST_F(LocalSSAElimTest, IfThen) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in float f;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // if (f <= 0)
+ // v = v * 0.5;
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %f %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %f "f"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %BaseColor "BaseColor"
+OpName %f "f"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%8 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%_ptr_Input_float = OpTypePointer Input %float
+%f = OpVariable %_ptr_Input_float Input
+%float_0 = OpConstant %float 0
+%bool = OpTypeBool
+%float_0_5 = OpConstant %float 0.5
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %8
+%18 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%19 = OpLoad %v4float %BaseColor
+OpStore %v %19
+%20 = OpLoad %float %f
+%21 = OpFOrdLessThanEqual %bool %20 %float_0
+OpSelectionMerge %22 None
+OpBranchConditional %21 %23 %22
+%23 = OpLabel
+%24 = OpLoad %v4float %v
+%25 = OpVectorTimesScalar %v4float %24 %float_0_5
+OpStore %v %25
+OpBranch %22
+%22 = OpLabel
+%26 = OpLoad %v4float %v
+OpStore %gl_FragColor %26
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %8
+%18 = OpLabel
+%19 = OpLoad %v4float %BaseColor
+%20 = OpLoad %float %f
+%21 = OpFOrdLessThanEqual %bool %20 %float_0
+OpSelectionMerge %22 None
+OpBranchConditional %21 %23 %22
+%23 = OpLabel
+%25 = OpVectorTimesScalar %v4float %19 %float_0_5
+OpBranch %22
+%22 = OpLabel
+%27 = OpPhi %v4float %19 %18 %25 %23
+OpStore %gl_FragColor %27
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalMultiStoreElimPass>(
+ predefs + names_before + predefs2 + before,
+ predefs + names_after + predefs2 + after,
+ true, true);
+}
+
+TEST_F(LocalSSAElimTest, Switch) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in float f;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // int i = int(f);
+ // switch (i) {
+ // case 0:
+ // v = v * 0.1;
+ // break;
+ // case 1:
+ // v = v * 0.3;
+ // break;
+ // case 2:
+ // v = v * 0.7;
+ // break;
+ // default:
+ // break;
+ // }
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %f %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %i "i"
+OpName %f "f"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %BaseColor "BaseColor"
+OpName %f "f"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%_ptr_Input_float = OpTypePointer Input %float
+%f = OpVariable %_ptr_Input_float Input
+%float_0_1 = OpConstant %float 0.1
+%float_0_3 = OpConstant %float 0.3
+%float_0_7 = OpConstant %float 0.7
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %9
+%21 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%i = OpVariable %_ptr_Function_int Function
+%22 = OpLoad %v4float %BaseColor
+OpStore %v %22
+%23 = OpLoad %float %f
+%24 = OpConvertFToS %int %23
+OpStore %i %24
+%25 = OpLoad %int %i
+OpSelectionMerge %26 None
+OpSwitch %25 %27 0 %28 1 %29 2 %30
+%27 = OpLabel
+OpBranch %26
+%28 = OpLabel
+%31 = OpLoad %v4float %v
+%32 = OpVectorTimesScalar %v4float %31 %float_0_1
+OpStore %v %32
+OpBranch %26
+%29 = OpLabel
+%33 = OpLoad %v4float %v
+%34 = OpVectorTimesScalar %v4float %33 %float_0_3
+OpStore %v %34
+OpBranch %26
+%30 = OpLabel
+%35 = OpLoad %v4float %v
+%36 = OpVectorTimesScalar %v4float %35 %float_0_7
+OpStore %v %36
+OpBranch %26
+%26 = OpLabel
+%37 = OpLoad %v4float %v
+OpStore %gl_FragColor %37
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %9
+%21 = OpLabel
+%22 = OpLoad %v4float %BaseColor
+%23 = OpLoad %float %f
+%24 = OpConvertFToS %int %23
+OpSelectionMerge %26 None
+OpSwitch %24 %27 0 %28 1 %29 2 %30
+%27 = OpLabel
+OpBranch %26
+%28 = OpLabel
+%32 = OpVectorTimesScalar %v4float %22 %float_0_1
+OpBranch %26
+%29 = OpLabel
+%34 = OpVectorTimesScalar %v4float %22 %float_0_3
+OpBranch %26
+%30 = OpLabel
+%36 = OpVectorTimesScalar %v4float %22 %float_0_7
+OpBranch %26
+%26 = OpLabel
+%38 = OpPhi %v4float %22 %27 %32 %28 %34 %29 %36 %30
+OpStore %gl_FragColor %38
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalMultiStoreElimPass>(
+ predefs + names_before + predefs2 + before,
+ predefs + names_after + predefs2 + after,
+ true, true);
+}
+
+TEST_F(LocalSSAElimTest, SwitchWithFallThrough) {
+ // #version 140
+ //
+ // in vec4 BaseColor;
+ // in float f;
+ //
+ // void main()
+ // {
+ // vec4 v = BaseColor;
+ // int i = int(f);
+ // switch (i) {
+ // case 0:
+ // v = v * 0.1;
+ // break;
+ // case 1:
+ // v = v + 0.1;
+ // case 2:
+ // v = v * 0.7;
+ // break;
+ // default:
+ // break;
+ // }
+ // gl_FragColor = v;
+ // }
+
+ const std::string predefs =
+ R"(OpCapability Shader
+%1 = OpExtInstImport "GLSL.std.450"
+OpMemoryModel Logical GLSL450
+OpEntryPoint Fragment %main "main" %BaseColor %f %gl_FragColor
+OpExecutionMode %main OriginUpperLeft
+OpSource GLSL 140
+)";
+
+ const std::string names_before =
+ R"(OpName %main "main"
+OpName %v "v"
+OpName %BaseColor "BaseColor"
+OpName %i "i"
+OpName %f "f"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string names_after =
+ R"(OpName %main "main"
+OpName %BaseColor "BaseColor"
+OpName %f "f"
+OpName %gl_FragColor "gl_FragColor"
+)";
+
+ const std::string predefs2 =
+ R"(%void = OpTypeVoid
+%9 = OpTypeFunction %void
+%float = OpTypeFloat 32
+%v4float = OpTypeVector %float 4
+%_ptr_Function_v4float = OpTypePointer Function %v4float
+%_ptr_Input_v4float = OpTypePointer Input %v4float
+%BaseColor = OpVariable %_ptr_Input_v4float Input
+%int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+%_ptr_Input_float = OpTypePointer Input %float
+%f = OpVariable %_ptr_Input_float Input
+%float_0_1 = OpConstant %float 0.1
+%float_0_7 = OpConstant %float 0.7
+%_ptr_Output_v4float = OpTypePointer Output %v4float
+%gl_FragColor = OpVariable %_ptr_Output_v4float Output
+)";
+
+ const std::string before =
+ R"(%main = OpFunction %void None %9
+%20 = OpLabel
+%v = OpVariable %_ptr_Function_v4float Function
+%i = OpVariable %_ptr_Function_int Function
+%21 = OpLoad %v4float %BaseColor
+OpStore %v %21
+%22 = OpLoad %float %f
+%23 = OpConvertFToS %int %22
+OpStore %i %23
+%24 = OpLoad %int %i
+OpSelectionMerge %25 None
+OpSwitch %24 %26 0 %27 1 %28 2 %29
+%26 = OpLabel
+OpBranch %25
+%27 = OpLabel
+%30 = OpLoad %v4float %v
+%31 = OpVectorTimesScalar %v4float %30 %float_0_1
+OpStore %v %31
+OpBranch %25
+%28 = OpLabel
+%32 = OpLoad %v4float %v
+%33 = OpCompositeConstruct %v4float %float_0_1 %float_0_1 %float_0_1 %float_0_1
+%34 = OpFAdd %v4float %32 %33
+OpStore %v %34
+OpBranch %29
+%29 = OpLabel
+%35 = OpLoad %v4float %v
+%36 = OpVectorTimesScalar %v4float %35 %float_0_7
+OpStore %v %36
+OpBranch %25
+%25 = OpLabel
+%37 = OpLoad %v4float %v
+OpStore %gl_FragColor %37
+OpReturn
+OpFunctionEnd
+)";
+
+ const std::string after =
+ R"(%main = OpFunction %void None %9
+%20 = OpLabel
+%21 = OpLoad %v4float %BaseColor
+%22 = OpLoad %float %f
+%23 = OpConvertFToS %int %22
+OpSelectionMerge %25 None
+OpSwitch %23 %26 0 %27 1 %28 2 %29
+%26 = OpLabel
+OpBranch %25
+%27 = OpLabel
+%31 = OpVectorTimesScalar %v4float %21 %float_0_1
+OpBranch %25
+%28 = OpLabel
+%33 = OpCompositeConstruct %v4float %float_0_1 %float_0_1 %float_0_1 %float_0_1
+%34 = OpFAdd %v4float %21 %33
+OpBranch %29
+%29 = OpLabel
+%38 = OpPhi %v4float %21 %20 %34 %28
+%36 = OpVectorTimesScalar %v4float %38 %float_0_7
+OpBranch %25
+%25 = OpLabel
+%39 = OpPhi %v4float %21 %26 %31 %27 %36 %29
+OpStore %gl_FragColor %39
+OpReturn
+OpFunctionEnd
+)";
+
+ SinglePassRunAndCheck<opt::LocalMultiStoreElimPass>(
+ predefs + names_before + predefs2 + before,
+ predefs + names_after + predefs2 + after,
+ true, true);
+}
+
+// TODO(greg-lunarg): Add tests to verify handling of these cases:
+//
+// No optimization in the presence of
+// access chains
+// function calls
+// OpCopyMemory?
+// unsupported extensions
+// Others?
+
+} // anonymous namespace
diff --git a/test/text_to_binary.extension_test.cpp b/test/text_to_binary.extension_test.cpp
index 45d1a23c..4b7c25c8 100644
--- a/test/text_to_binary.extension_test.cpp
+++ b/test/text_to_binary.extension_test.cpp
@@ -278,6 +278,74 @@ INSTANTIATE_TEST_CASE_P(
})), );
+// SPV_AMD_shader_explicit_vertex_parameter
+
+#define PREAMBLE "%1 = OpExtInstImport \"SPV_AMD_shader_explicit_vertex_parameter\"\n"
+INSTANTIATE_TEST_CASE_P(
+ SPV_AMD_shader_explicit_vertex_parameter, ExtensionRoundTripTest,
+ // We'll get coverage over operand tables by trying the universal
+ // environments, and at least one specific environment.
+ Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_VULKAN_1_0),
+ ValuesIn(std::vector<AssemblyCase>{
+ {PREAMBLE "%3 = OpExtInst %2 %1 InterpolateAtVertexAMD %4 %5\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_explicit_vertex_parameter")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 1, 4, 5})})},
+ })), );
+#undef PREAMBLE
+
+
+// SPV_AMD_shader_trinary_minmax
+
+#define PREAMBLE "%1 = OpExtInstImport \"SPV_AMD_shader_trinary_minmax\"\n"
+INSTANTIATE_TEST_CASE_P(
+ SPV_AMD_shader_trinary_minmax, ExtensionRoundTripTest,
+ // We'll get coverage over operand tables by trying the universal
+ // environments, and at least one specific environment.
+ Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_VULKAN_1_0),
+ ValuesIn(std::vector<AssemblyCase>{
+ {PREAMBLE "%3 = OpExtInst %2 %1 FMin3AMD %4 %5 %6\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_trinary_minmax")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 1, 4, 5, 6})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 UMin3AMD %4 %5 %6\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_trinary_minmax")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 2, 4, 5, 6})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 SMin3AMD %4 %5 %6\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_trinary_minmax")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 3, 4, 5, 6})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 FMax3AMD %4 %5 %6\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_trinary_minmax")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 4, 4, 5, 6})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 UMax3AMD %4 %5 %6\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_trinary_minmax")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 5, 4, 5, 6})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 SMax3AMD %4 %5 %6\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_trinary_minmax")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 6, 4, 5, 6})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 FMid3AMD %4 %5 %6\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_trinary_minmax")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 7, 4, 5, 6})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 UMid3AMD %4 %5 %6\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_trinary_minmax")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 8, 4, 5, 6})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 SMid3AMD %4 %5 %6\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_trinary_minmax")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 9, 4, 5, 6})})},
+ })), );
+#undef PREAMBLE
+
+
// SPV_AMD_gcn_shader
#define PREAMBLE "%1 = OpExtInstImport \"SPV_AMD_gcn_shader\"\n"
@@ -288,14 +356,14 @@ INSTANTIATE_TEST_CASE_P(
Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
SPV_ENV_VULKAN_1_0),
ValuesIn(std::vector<AssemblyCase>{
- {PREAMBLE "%3 = OpExtInst %2 %1 CubeFaceCoordAMD %4\n",
- Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
- MakeVector("SPV_AMD_gcn_shader")),
- MakeInstruction(SpvOpExtInst, {2, 3, 1, 2, 4})})},
{PREAMBLE "%3 = OpExtInst %2 %1 CubeFaceIndexAMD %4\n",
Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
MakeVector("SPV_AMD_gcn_shader")),
MakeInstruction(SpvOpExtInst, {2, 3, 1, 1, 4})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 CubeFaceCoordAMD %4\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_gcn_shader")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 2, 4})})},
{PREAMBLE "%3 = OpExtInst %2 %1 TimeAMD\n",
Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
MakeVector("SPV_AMD_gcn_shader")),
@@ -304,6 +372,36 @@ INSTANTIATE_TEST_CASE_P(
#undef PREAMBLE
+// SPV_AMD_shader_ballot
+
+#define PREAMBLE "%1 = OpExtInstImport \"SPV_AMD_shader_ballot\"\n"
+INSTANTIATE_TEST_CASE_P(
+ SPV_AMD_shader_ballot, ExtensionRoundTripTest,
+ // We'll get coverage over operand tables by trying the universal
+ // environments, and at least one specific environment.
+ Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1,
+ SPV_ENV_VULKAN_1_0),
+ ValuesIn(std::vector<AssemblyCase>{
+ {PREAMBLE "%3 = OpExtInst %2 %1 SwizzleInvocationsAMD %4 %5\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_ballot")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 1, 4, 5})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 SwizzleInvocationsMaskedAMD %4 %5\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_ballot")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 2, 4, 5})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 WriteInvocationAMD %4 %5 %6\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_ballot")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 3, 4, 5, 6})})},
+ {PREAMBLE "%3 = OpExtInst %2 %1 MbcntAMD %4\n",
+ Concatenate({MakeInstruction(SpvOpExtInstImport, {1},
+ MakeVector("SPV_AMD_shader_ballot")),
+ MakeInstruction(SpvOpExtInst, {2, 3, 1, 4, 4})})},
+ })), );
+#undef PREAMBLE
+
+
// SPV_KHR_variable_pointers
INSTANTIATE_TEST_CASE_P(
diff --git a/test/val/val_extensions_test.cpp b/test/val/val_extensions_test.cpp
index 1afc0c4a..dd668762 100644
--- a/test/val/val_extensions_test.cpp
+++ b/test/val/val_extensions_test.cpp
@@ -43,7 +43,13 @@ string GetErrorString(const std::string& extension) {
}
INSTANTIATE_TEST_CASE_P(ExpectSuccess, ValidateKnownExtensions, Values(
+ "SPV_AMD_shader_explicit_vertex_parameter",
+ "SPV_AMD_shader_trinary_minmax",
"SPV_AMD_gcn_shader",
+ "SPV_AMD_shader_ballot",
+ "SPV_AMD_gpu_shader_half_float",
+ "SPV_AMD_texture_gather_bias_lod",
+ "SPV_AMD_gpu_shader_int16",
"SPV_KHR_shader_ballot",
"SPV_KHR_shader_draw_parameters",
"SPV_KHR_subgroup_vote",
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index 43b9507d..4dc0f138 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -42,6 +42,8 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
add_spvtools_tool(TARGET spirv-dis SRCS dis/dis.cpp LIBS ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-val SRCS val/val.cpp LIBS ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-opt SRCS opt/opt.cpp LIBS SPIRV-Tools-opt ${SPIRV_TOOLS})
+ add_spvtools_tool(TARGET spirv-markv SRCS comp/markv.cpp
+ LIBS SPIRV-Tools-comp ${SPIRV_TOOLS})
add_spvtools_tool(TARGET spirv-stats
SRCS stats/stats.cpp
stats/stats_analyzer.cpp
@@ -55,10 +57,15 @@ if (NOT ${SPIRV_SKIP_EXECUTABLES})
${SPIRV_HEADER_INCLUDE_DIR})
target_include_directories(spirv-stats PRIVATE ${spirv-tools_SOURCE_DIR}
${SPIRV_HEADER_INCLUDE_DIR})
+ target_include_directories(spirv-markv PRIVATE ${spirv-tools_SOURCE_DIR}
+ ${SPIRV_HEADER_INCLUDE_DIR})
- set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt spirv-stats spirv-cfg)
- install(TARGETS ${SPIRV_INSTALL_TARGETS}
- RUNTIME DESTINATION bin
- LIBRARY DESTINATION lib
- ARCHIVE DESTINATION lib)
+ set(SPIRV_INSTALL_TARGETS spirv-as spirv-dis spirv-val spirv-opt spirv-stats spirv-cfg
+ spirv-markv)
+ if(ENABLE_SPIRV_TOOLS_INSTALL)
+ install(TARGETS ${SPIRV_INSTALL_TARGETS}
+ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
+ endif(ENABLE_SPIRV_TOOLS_INSTALL)
endif()
diff --git a/tools/comp/markv.cpp b/tools/comp/markv.cpp
new file mode 100644
index 00000000..f9df9ca6
--- /dev/null
+++ b/tools/comp/markv.cpp
@@ -0,0 +1,247 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cassert>
+#include <cstdio>
+#include <cstring>
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <vector>
+
+#include "source/spirv_target_env.h"
+#include "source/table.h"
+#include "spirv-tools/markv.h"
+#include "tools/io.h"
+
+namespace {
+
+enum Task {
+ kNoTask = 0,
+ kEncode,
+ kDecode,
+};
+
+struct ScopedContext {
+ ScopedContext(spv_target_env env) : context(spvContextCreate(env)) {}
+ ~ScopedContext() { spvContextDestroy(context); }
+ spv_context context;
+};
+
+void print_usage(char* argv0) {
+ printf(
+ R"(%s - Encodes or decodes a SPIR-V binary to or from a MARK-V binary.
+
+USAGE: %s [e|d] [options] [<filename>]
+
+The input binary is read from <filename>. If no file is specified,
+or if the filename is "-", then the binary is read from standard input.
+
+If no output is specified then the output is printed to stdout in a human
+readable format.
+
+WIP: MARK-V codec is in early stages of development. At the moment it only
+can encode and decode some SPIR-V files and only if exactly the same build of
+software is used (it doesn't write or handle version numbers yet).
+
+Tasks:
+ e Encode SPIR-V to MARK-V.
+ d Decode MARK-V to SPIR-V.
+
+Options:
+ -h, --help Print this help.
+ --comments Write codec comments to stdout.
+ --version Display MARK-V codec version.
+
+ -o <filename> Set the output filename.
+ Output goes to standard output if this option is
+ not specified, or if the filename is "-".
+)",
+ argv0, argv0);
+}
+
+void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
+ const spv_position_t& position,
+ const char* message) {
+ switch (level) {
+ case SPV_MSG_FATAL:
+ case SPV_MSG_INTERNAL_ERROR:
+ case SPV_MSG_ERROR:
+ std::cerr << "error: " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_WARNING:
+ std::cout << "warning: " << position.index << ": " << message
+ << std::endl;
+ break;
+ case SPV_MSG_INFO:
+ std::cout << "info: " << position.index << ": " << message << std::endl;
+ break;
+ default:
+ break;
+ }
+}
+
+} // namespace
+
+int main(int argc, char** argv) {
+ const char* input_filename = nullptr;
+ const char* output_filename = nullptr;
+
+ Task task = kNoTask;
+
+ if (argc < 3) {
+ print_usage(argv[0]);
+ return 0;
+ }
+
+ const char* task_char = argv[1];
+ if (0 == strcmp("e", task_char)) {
+ task = kEncode;
+ } else if (0 == strcmp("d", task_char)) {
+ task = kDecode;
+ }
+
+ if (task == kNoTask) {
+ print_usage(argv[0]);
+ return 1;
+ }
+
+ bool want_comments = false;
+
+ for (int argi = 2; argi < argc; ++argi) {
+ if ('-' == argv[argi][0]) {
+ switch (argv[argi][1]) {
+ case 'h':
+ print_usage(argv[0]);
+ return 0;
+ case 'o': {
+ if (!output_filename && argi + 1 < argc) {
+ output_filename = argv[++argi];
+ } else {
+ print_usage(argv[0]);
+ return 1;
+ }
+ } break;
+ case '-': {
+ if (0 == strcmp(argv[argi], "--help")) {
+ print_usage(argv[0]);
+ return 0;
+ } else if (0 == strcmp(argv[argi], "--comments")) {
+ want_comments = true;
+ } else if (0 == strcmp(argv[argi], "--version")) {
+ fprintf(stderr, "error: Not implemented\n");
+ return 1;
+ } else {
+ print_usage(argv[0]);
+ return 1;
+ }
+ } break;
+ case '\0': {
+ // Setting a filename of "-" to indicate stdin.
+ if (!input_filename) {
+ input_filename = argv[argi];
+ } else {
+ fprintf(stderr, "error: More than one input file specified\n");
+ return 1;
+ }
+ } break;
+ default:
+ print_usage(argv[0]);
+ return 1;
+ }
+ } else {
+ if (!input_filename) {
+ input_filename = argv[argi];
+ } else {
+ fprintf(stderr, "error: More than one input file specified\n");
+ return 1;
+ }
+ }
+ }
+
+ if (task == kDecode && want_comments) {
+ fprintf(stderr, "warning: Decoder comments not yet implemented\n");
+ want_comments = false;
+ }
+
+ const bool write_to_stdout = output_filename == nullptr ||
+ 0 == strcmp(output_filename, "-");
+
+ spv_text comments = nullptr;
+ spv_text* comments_ptr = want_comments ? &comments : nullptr;
+
+ ScopedContext ctx(SPV_ENV_UNIVERSAL_1_2);
+ SetContextMessageConsumer(ctx.context, DiagnosticsMessageHandler);
+
+ if (task == kEncode) {
+ std::vector<uint32_t> contents;
+ if (!ReadFile<uint32_t>(input_filename, "rb", &contents)) return 1;
+
+ std::unique_ptr<spv_markv_encoder_options_t,
+ std::function<void(spv_markv_encoder_options_t*)>> options(
+ spvMarkvEncoderOptionsCreate(), &spvMarkvEncoderOptionsDestroy);
+ spv_markv_binary markv_binary = nullptr;
+
+ if (SPV_SUCCESS !=
+ spvSpirvToMarkv(ctx.context, contents.data(), contents.size(),
+ options.get(), &markv_binary, comments_ptr, nullptr)) {
+ std::cerr << "error: Failed to encode " << input_filename << " to MARK-V "
+ << std::endl;
+ return 1;
+ }
+
+ if (want_comments) {
+ if (!WriteFile<char>(nullptr, "w", comments->str,
+ comments->length)) return 1;
+ }
+
+ if (!want_comments || !write_to_stdout) {
+ if (!WriteFile<uint8_t>(output_filename, "wb", markv_binary->data,
+ markv_binary->length)) return 1;
+ }
+ } else if (task == kDecode) {
+ std::vector<uint8_t> contents;
+ if (!ReadFile<uint8_t>(input_filename, "rb", &contents)) return 1;
+
+ std::unique_ptr<spv_markv_decoder_options_t,
+ std::function<void(spv_markv_decoder_options_t*)>> options(
+ spvMarkvDecoderOptionsCreate(), &spvMarkvDecoderOptionsDestroy);
+ spv_binary spirv_binary = nullptr;
+
+ if (SPV_SUCCESS !=
+ spvMarkvToSpirv(ctx.context, contents.data(), contents.size(),
+ options.get(), &spirv_binary, comments_ptr, nullptr)) {
+    std::cerr << "error: Failed to decode " << input_filename << " to SPIR-V "
+ << std::endl;
+ return 1;
+ }
+
+ if (want_comments) {
+ if (!WriteFile<char>(nullptr, "w", comments->str,
+ comments->length)) return 1;
+ }
+
+ if (!want_comments || !write_to_stdout) {
+ if (!WriteFile<uint32_t>(output_filename, "wb", spirv_binary->code,
+ spirv_binary->wordCount)) return 1;
+ }
+ } else {
+ assert(false && "Unknown task");
+ }
+
+ spvTextDestroy(comments);
+
+ return 0;
+}
diff --git a/tools/emacs/CMakeLists.txt b/tools/emacs/CMakeLists.txt
index 3785771f..ecd7c277 100644
--- a/tools/emacs/CMakeLists.txt
+++ b/tools/emacs/CMakeLists.txt
@@ -40,7 +40,9 @@ option(SPIRV_TOOLS_INSTALL_EMACS_HELPERS
${SPIRV_TOOLS_INSTALL_EMACS_HELPERS})
if (${SPIRV_TOOLS_INSTALL_EMACS_HELPERS})
if(EXISTS /etc/emacs/site-start.d)
- install(FILES 50spirv-tools.el DESTINATION /etc/emacs/site-start.d)
+ if(ENABLE_SPIRV_TOOLS_INSTALL)
+ install(FILES 50spirv-tools.el DESTINATION /etc/emacs/site-start.d)
+ endif(ENABLE_SPIRV_TOOLS_INSTALL)
endif()
endif()
diff --git a/tools/lesspipe/CMakeLists.txt b/tools/lesspipe/CMakeLists.txt
index 10b5df45..484e51e5 100644
--- a/tools/lesspipe/CMakeLists.txt
+++ b/tools/lesspipe/CMakeLists.txt
@@ -23,4 +23,6 @@
# permissions.
# We have a .sh extension because Windows users often configure
# executable settings via filename extension.
-install(PROGRAMS spirv-lesspipe.sh DESTINATION bin)
+if(ENABLE_SPIRV_TOOLS_INSTALL)
+ install(PROGRAMS spirv-lesspipe.sh DESTINATION ${CMAKE_INSTALL_BINDIR})
+endif(ENABLE_SPIRV_TOOLS_INSTALL)
diff --git a/tools/opt/opt.cpp b/tools/opt/opt.cpp
index 21dd6764..ced538cf 100644
--- a/tools/opt/opt.cpp
+++ b/tools/opt/opt.cpp
@@ -135,8 +135,20 @@ int main(int argc, char** argv) {
optimizer.RegisterPass(CreateInlinePass());
} else if (0 == strcmp(cur_arg, "--convert-local-access-chains")) {
optimizer.RegisterPass(CreateLocalAccessChainConvertPass());
+ } else if (0 == strcmp(cur_arg, "--eliminate-dead-code-aggressive")) {
+ optimizer.RegisterPass(CreateAggressiveDCEPass());
+ } else if (0 == strcmp(cur_arg, "--eliminate-insert-extract")) {
+ optimizer.RegisterPass(CreateInsertExtractElimPass());
} else if (0 == strcmp(cur_arg, "--eliminate-local-single-block")) {
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass());
+ } else if (0 == strcmp(cur_arg, "--eliminate-local-single-store")) {
+ optimizer.RegisterPass(CreateLocalSingleStoreElimPass());
+ } else if (0 == strcmp(cur_arg, "--merge-blocks")) {
+ optimizer.RegisterPass(CreateBlockMergePass());
+ } else if (0 == strcmp(cur_arg, "--eliminate-dead-branches")) {
+ optimizer.RegisterPass(CreateDeadBranchElimPass());
+ } else if (0 == strcmp(cur_arg, "--eliminate-local-multi-store")) {
+ optimizer.RegisterPass(CreateLocalMultiStoreElimPass());
} else if (0 == strcmp(cur_arg, "--eliminate-dead-const")) {
optimizer.RegisterPass(CreateEliminateDeadConstantPass());
} else if (0 == strcmp(cur_arg, "--fold-spec-const-op-composite")) {
diff --git a/utils/generate_grammar_tables.py b/utils/generate_grammar_tables.py
index c7741f33..67d52106 100755
--- a/utils/generate_grammar_tables.py
+++ b/utils/generate_grammar_tables.py
@@ -27,7 +27,16 @@ PYGEN_VARIABLE_PREFIX = 'pygen_variable'
# Extensions to recognize, but which don't come from the SPIRV-V core grammar.
NONSTANDARD_EXTENSIONS = [
+ # TODO(dneto): Vendor extension names should really be derived from the
+ # content of .json files.
'SPV_AMD_gcn_shader',
+ 'SPV_AMD_shader_ballot',
+ 'SPV_AMD_shader_explicit_vertex_parameter',
+ 'SPV_AMD_shader_trinary_minmax',
+ # The following don't have an extended instruction set grammar file.
+ 'SPV_AMD_gpu_shader_half_float',
+ 'SPV_AMD_texture_gather_bias_lod',
+ 'SPV_AMD_gpu_shader_int16',
# Validator would ignore type declaration unique check. Should only be used
# for legacy autogenerated test files containing multiple instances of the
# same type declaration, if fixing the test by other methods is too